1""" 2The torch package contains data structures for multi-dimensional 3tensors and defines mathematical operations over these tensors. 4Additionally, it provides many utilities for efficient serialization of 5Tensors and arbitrary types, and other useful utilities. 6 7It has a CUDA counterpart, that enables you to run your tensor computations 8on an NVIDIA GPU with compute capability >= 3.0. 9""" 10 11# mypy: allow-untyped-defs 12 13import builtins 14import ctypes 15import glob 16import importlib 17import inspect 18import math 19import os 20import platform 21import sys 22import textwrap 23import threading 24from typing import ( 25 Any as _Any, 26 Callable as _Callable, 27 Dict as _Dict, 28 Optional as _Optional, 29 overload as _overload, 30 Set as _Set, 31 Tuple as _Tuple, 32 Type as _Type, 33 TYPE_CHECKING, 34 TypeVar as _TypeVar, 35 Union as _Union, 36) 37from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard 38 39 40if TYPE_CHECKING: 41 from .types import IntLikeType 42 43 44# multipy/deploy is setting this import before importing torch, this is the most 45# reliable way we have to detect if we're running within deploy. 46# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 47def _running_with_deploy() -> builtins.bool: 48 return sys.modules.get("torch._meta_registrations", None) is object 49 50 51from torch._utils import ( 52 _functionalize_sync as _sync, 53 _import_dotted_name, 54 classproperty, 55) 56from torch._utils_internal import ( 57 get_file_path, 58 prepare_multiprocessing_environment, 59 USE_GLOBAL_DEPS, 60 USE_RTLD_GLOBAL_WITH_LIBTORCH, 61) 62 63 64# TODO(torch_deploy) figure out how to freeze version.py in fbcode build 65if _running_with_deploy(): 66 __version__ = "torch-deploy-1.8" 67else: 68 from torch.torch_version import __version__ as __version__ 69 70__all__ = [ 71 "BoolStorage", 72 "BoolTensor", 73 "ByteStorage", 74 "ByteTensor", 75 "CharStorage", 76 "CharTensor", 77 "DoubleStorage", 78 "DoubleTensor", 79 "FloatStorage", 80 "FloatTensor", 81 "GradScaler", 82 "IntStorage", 83 "IntTensor", 84 "LongStorage", 85 "LongTensor", 86 "ShortStorage", 87 "ShortTensor", 88 "SymBool", 89 "SymFloat", 90 "SymInt", 91 "Tensor", 92 "TypedStorage", 93 "UntypedStorage", 94 "are_deterministic_algorithms_enabled", 95 "autocast", 96 "chunk", 97 "compile", 98 "cond", 99 "enable_grad", 100 "export", 101 "get_default_device", 102 "get_deterministic_debug_mode", 103 "get_device_module", 104 "get_float32_matmul_precision", 105 "get_rng_state", 106 "inference_mode", 107 "initial_seed", 108 "is_deterministic_algorithms_warn_only_enabled", 109 "is_storage", 110 "is_tensor", 111 "is_warn_always_enabled", 112 "load", 113 "lobpcg", 114 "manual_seed", 115 "matmul", 116 "no_grad", 117 "rand", 118 "randn", 119 "save", 120 "seed", 121 "set_default_device", 122 "set_default_tensor_type", 123 "set_deterministic_debug_mode", 124 "set_float32_matmul_precision", 125 "set_printoptions", 126 "set_rng_state", 127 "set_warn_always", 128 "split", 129 "stack", 130 "sym_float", 131 "sym_int", 132 "sym_ite", 133 "sym_max", 134 "sym_min", 135 "sym_not", 136 "typename", 137 "unravel_index", 138 "use_deterministic_algorithms", 139 "vmap", 140] 141 142# Please keep this list sorted 143assert __all__ == sorted(__all__) 144 145################################################################################ 146# Load the extension module 
147################################################################################ 148 149if sys.platform == "win32": 150 151 def _load_dll_libraries() -> None: 152 import sysconfig 153 154 from torch.version import cuda as cuda_version 155 156 pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files") 157 py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin") 158 th_dll_path = os.path.join(os.path.dirname(__file__), "lib") 159 usebase_path = os.path.join( 160 sysconfig.get_config_var("userbase"), "Library", "bin" 161 ) 162 163 # When users create a virtualenv that inherits the base environment, 164 # we will need to add the corresponding library directory into 165 # DLL search directories. Otherwise, it will rely on `PATH` which 166 # is dependent on user settings. 167 if sys.exec_prefix != sys.base_exec_prefix: 168 base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin") 169 else: 170 base_py_dll_path = "" 171 172 dll_paths = [ 173 p 174 for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) 175 if os.path.exists(p) 176 ] 177 178 if not builtins.any( 179 os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths 180 ): 181 nvtoolsext_dll_path = os.path.join( 182 os.getenv( 183 "NVTOOLSEXT_PATH", 184 os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"), 185 ), 186 "bin", 187 "x64", 188 ) 189 else: 190 nvtoolsext_dll_path = "" 191 192 if cuda_version and builtins.all( 193 not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths 194 ): 195 cuda_version_1 = cuda_version.replace(".", "_") 196 cuda_path_var = "CUDA_PATH_V" + cuda_version_1 197 default_path = os.path.join( 198 pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}" 199 ) 200 cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin") 201 else: 202 cuda_path = "" 203 204 dll_paths.extend( 205 p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p) 206 ) 207 208 kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) 209 with_load_library_flags = hasattr(kernel32, "AddDllDirectory") 210 prev_error_mode = kernel32.SetErrorMode(0x0001) 211 212 kernel32.LoadLibraryW.restype = ctypes.c_void_p 213 if with_load_library_flags: 214 kernel32.LoadLibraryExW.restype = ctypes.c_void_p 215 216 for dll_path in dll_paths: 217 os.add_dll_directory(dll_path) 218 219 try: 220 ctypes.CDLL("vcruntime140.dll") 221 ctypes.CDLL("msvcp140.dll") 222 ctypes.CDLL("vcruntime140_1.dll") 223 except OSError: 224 print( 225 textwrap.dedent( 226 """ 227 Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure. 228 It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe 229 """ 230 ).strip() 231 ) 232 233 dlls = glob.glob(os.path.join(th_dll_path, "*.dll")) 234 path_patched = False 235 for dll in dlls: 236 is_loaded = False 237 if with_load_library_flags: 238 res = kernel32.LoadLibraryExW(dll, None, 0x00001100) 239 last_error = ctypes.get_last_error() 240 if res is None and last_error != 126: 241 err = ctypes.WinError(last_error) 242 err.strerror += ( 243 f' Error loading "{dll}" or one of its dependencies.' 244 ) 245 raise err 246 elif res is not None: 247 is_loaded = True 248 if not is_loaded: 249 if not path_patched: 250 os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) 251 path_patched = True 252 res = kernel32.LoadLibraryW(dll) 253 if res is None: 254 err = ctypes.WinError(ctypes.get_last_error()) 255 err.strerror += ( 256 f' Error loading "{dll}" or one of its dependencies.' 
257 ) 258 raise err 259 260 kernel32.SetErrorMode(prev_error_mode) 261 262 _load_dll_libraries() 263 del _load_dll_libraries 264 265 266def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: 267 """Preloads cuda deps if they could not be found otherwise.""" 268 # Should only be called on Linux if default path resolution have failed 269 assert platform.system() == "Linux", "Should only be called on Linux" 270 271 lib_path = None 272 for path in sys.path: 273 nvidia_path = os.path.join(path, "nvidia") 274 if not os.path.exists(nvidia_path): 275 continue 276 candidate_lib_paths = glob.glob( 277 os.path.join(nvidia_path, lib_folder, "lib", lib_name) 278 ) 279 if candidate_lib_paths and not lib_path: 280 lib_path = candidate_lib_paths[0] 281 if lib_path: 282 break 283 if not lib_path: 284 raise ValueError(f"{lib_name} not found in the system path {sys.path}") 285 ctypes.CDLL(lib_path) 286 287 288# See Note [Global dependencies] 289def _load_global_deps() -> None: 290 if _running_with_deploy() or platform.system() == "Windows": 291 return 292 293 # Determine the file extension based on the platform 294 lib_ext = ".dylib" if platform.system() == "Darwin" else ".so" 295 lib_name = f"libtorch_global_deps{lib_ext}" 296 here = os.path.abspath(__file__) 297 global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) 298 299 try: 300 ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) 301 except OSError as err: 302 # Can only happen for wheel with cuda libs as PYPI deps 303 # As PyTorch is not purelib, but nvidia-*-cu12 is 304 cuda_libs: _Dict[str, str] = { 305 "cublas": "libcublas.so.*[0-9]", 306 "cudnn": "libcudnn.so.*[0-9]", 307 "cuda_nvrtc": "libnvrtc.so.*[0-9]", 308 "cuda_runtime": "libcudart.so.*[0-9]", 309 "cuda_cupti": "libcupti.so.*[0-9]", 310 "cufft": "libcufft.so.*[0-9]", 311 "curand": "libcurand.so.*[0-9]", 312 "nvjitlink": "libnvJitLink.so.*[0-9]", 313 "cusparse": "libcusparse.so.*[0-9]", 314 "cusolver": "libcusolver.so.*[0-9]", 315 "nccl": "libnccl.so.*[0-9]", 316 "nvtx": "libnvToolsExt.so.*[0-9]", 317 } 318 is_cuda_lib_err = [ 319 lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] 320 ] 321 if not is_cuda_lib_err: 322 raise err 323 for lib_folder, lib_name in cuda_libs.items(): 324 _preload_cuda_deps(lib_folder, lib_name) 325 ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) 326 327 328if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( 329 _running_with_deploy() or platform.system() != "Windows" 330): 331 # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a 332 # few circumstances: 333 # 334 # 1. You're in a build environment (e.g., fbcode) where 335 # libtorch_global_deps is not available, but you still need 336 # to get mkl to link in with RTLD_GLOBAL or it will just 337 # not work. 338 # 339 # 2. You're trying to run PyTorch under UBSAN and you need 340 # to ensure that only one copy of libtorch is loaded, so 341 # vptr checks work properly 342 # 343 # If you're using this setting, you must verify that all the libraries 344 # you load consistently use the same libstdc++, or you may have 345 # mysterious segfaults. 346 # 347 old_flags = sys.getdlopenflags() 348 sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) 349 350 from torch._C import * # noqa: F403 351 352 sys.setdlopenflags(old_flags) 353 del old_flags 354 355else: 356 # Easy way. 
You want this most of the time, because it will prevent 357 # C++ symbols from libtorch clobbering C++ symbols from other 358 # libraries, leading to mysterious segfaults. 359 # 360 # If building in an environment where libtorch_global_deps isn't available 361 # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will 362 # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False 363 # 364 # See Note [Global dependencies] 365 if USE_GLOBAL_DEPS: 366 _load_global_deps() 367 from torch._C import * # noqa: F403 368 369 370class SymInt: 371 """ 372 Like an int (including magic methods), but redirects all operations on the 373 wrapped node. This is used in particular to symbolically record operations 374 in the symbolic shape workflow. 375 """ 376 377 def __init__(self, node): 378 # This field MUST be named node; C++ binding code assumes that this 379 # class has a field named node that stores SymNode 380 self.node = node 381 382 def __bool__(self): 383 return builtins.bool(self != 0) 384 385 def __int__(self): 386 return self.node.int_() 387 388 def __index__(self): 389 return self.node.int_() 390 391 # Magic methods installed by torch.fx.experimental.sym_node 392 393 def __round__(self, ndigits=None): 394 return self 395 396 def __truediv__(self, other): 397 if isinstance(other, (builtins.float, SymFloat)): 398 return sym_float(self).__float_truediv__(other) 399 if not isinstance(other, (builtins.int, SymInt)): 400 return NotImplemented 401 return self.__int_truediv__(other) 402 403 def __rtruediv__(self, other): 404 if isinstance(other, (builtins.float, SymFloat)): 405 return sym_float(self).__rfloat_truediv__(other) 406 if not isinstance(other, (builtins.int, SymInt)): 407 return NotImplemented 408 return self.__rint_truediv__(other) 409 410 def __floordiv__(self, other): 411 if isinstance(other, (builtins.float, SymFloat)): 412 return sym_float(math.floor(sym_float(self) / other)) 413 if not isinstance(other, (builtins.int, SymInt)): 414 return NotImplemented 415 return self.__int_floordiv__(other) 416 417 def __rfloordiv__(self, other): 418 if isinstance(other, (builtins.float, SymFloat)): 419 return sym_float(math.floor(other / sym_float(self))) 420 if not isinstance(other, (builtins.int, SymInt)): 421 return NotImplemented 422 return self.__rint_floordiv__(other) 423 424 # nb: complex is impossible to handle correctly lol, with 425 # negative base and integral float need to diverge semantics and 426 # just always return complex. Neener neener pretend this problem 427 # doesn't exist 428 def __pow__(self, other): 429 if isinstance(other, (builtins.float, SymFloat)): 430 return sym_float(self).__pow__(other) 431 if not isinstance(other, (builtins.int, SymInt)): 432 return NotImplemented 433 # Guards! This guard is necessary because we need to know it to 434 # determine the output type of this operation 435 if other >= 0: 436 return self.__pow_by_natural__(other) 437 else: 438 # Mercifully, when the exponent is negative, Python just promotes 439 # to doubles and does a float pow: 440 # 441 # if (Py_SIZE(b) < 0 && c == NULL) { 442 # /* if exponent is negative and there's no modulus: 443 # return a float. This works because we know 444 # that this calls float_pow() which converts its 445 # arguments to double. 
*/ 446 # Py_DECREF(a); 447 # Py_DECREF(b); 448 # return PyFloat_Type.tp_as_number->nb_power(v, w, x); 449 # } 450 return sym_float(self).__pow__(sym_float(other)) 451 452 def __rpow__(self, other): 453 if isinstance(other, (builtins.float, SymFloat)): 454 return sym_float(self).__rpow__(other) 455 if not isinstance(other, (builtins.int, SymInt)): 456 return NotImplemented 457 if self >= 0: # self is exponent 458 return self.__rpow_by_natural__(other) 459 else: 460 return sym_float(self).__rpow__(sym_float(other)) 461 462 def __eq__(self, other: object) -> builtins.bool: 463 raise TypeError("type stub not overridden") 464 465 def __lt__(self, other) -> builtins.bool: 466 raise TypeError("type stub not overridden") 467 468 def __gt__(self, other) -> builtins.bool: 469 raise TypeError("type stub not overridden") 470 471 def __le__(self, other) -> builtins.bool: 472 raise TypeError("type stub not overridden") 473 474 def __ge__(self, other) -> builtins.bool: 475 raise TypeError("type stub not overridden") 476 477 def __add__(self, other) -> "SymInt": 478 raise TypeError("type stub not overridden") 479 480 def __mod__(self, other: "IntLikeType") -> "SymInt": 481 raise TypeError("type stub not overridden") 482 483 def __mul__(self, other) -> "SymInt": 484 raise TypeError("type stub not overridden") 485 486 def __pow_by_natural__(self, other) -> "SymInt": 487 raise TypeError("type stub not overridden") 488 489 def __rpow_by_natural__(self, other) -> "SymInt": 490 raise TypeError("type stub not overridden") 491 492 def __int_truediv__(self, other) -> "SymFloat": 493 raise TypeError("type stub not overridden") 494 495 def __rint_truediv__(self, other) -> "SymFloat": 496 raise TypeError("type stub not overridden") 497 498 def __int_floordiv__(self, other) -> "SymFloat": 499 raise TypeError("type stub not overridden") 500 501 def __rint_floordiv__(self, other) -> "SymFloat": 502 raise TypeError("type stub not overridden") 503 504 def __sym_max__(self, other): 505 raise TypeError("type stub not overridden") 506 507 def __sym_min__(self, other): 508 raise TypeError("type stub not overridden") 509 510 def __sym_float__(self): 511 raise TypeError("type stub not overridden") 512 513 def __neg__(self): 514 raise TypeError("type stub not overridden") 515 516 def __sub__(self, other: "IntLikeType") -> "SymInt": 517 raise TypeError("type stub not overridden") 518 519 def __repr__(self): 520 return self.node._graph_repr() 521 522 def _sympy_(self): 523 return self.node.expr 524 525 def __hash__(self) -> builtins.int: 526 if self.node.is_nested_int(): 527 return hash(self.node.nested_int()) 528 else: 529 # We could support constant SymInts as well, but not doing it for now 530 raise TypeError("unhashable type: non-nested SymInt") 531 # TODO: Force specialization 532 # This can't be done because the TypeError here is load bearing 533 # for einops 534 # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237 535 # return hash(builtins.int(self)) 536 537 def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]: 538 """Represent this int as an exact integer ratio""" 539 return self, 1 540 541 def bit_length(self) -> builtins.int: 542 # TODO: A more relaxed guard is possible here, where you guard to 543 # allow all integer quantities which would result in the same bit 544 # length. We can also just make a dedicated Sympy function for 545 # computing this quantity and represent it symbolically. 
546 return builtins.int(self).bit_length() 547 548 def conjugate(self) -> "SymInt": 549 return self 550 551 552class SymFloat: 553 """ 554 Like an float (including magic methods), but redirects all operations on the 555 wrapped node. This is used in particular to symbolically record operations 556 in the symbolic shape workflow. 557 """ 558 559 def __init__(self, node): 560 # This field MUST be named node; C++ binding code assumes that this 561 # class has a field named node that stores SymNode 562 self.node = node 563 564 def __truediv__(self, other): 565 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 566 return NotImplemented 567 return self.__float_truediv__(sym_float(other)) 568 569 def __rtruediv__(self, other): 570 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 571 return NotImplemented 572 return self.__rfloat_truediv__(sym_float(other)) 573 574 def __floordiv__(self, other): 575 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 576 return NotImplemented 577 return sym_float(math.floor(self / sym_float(other))) 578 579 def __rfloordiv__(self, other): 580 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 581 return NotImplemented 582 return sym_float(math.floor(sym_float(other) / self)) 583 584 def __bool__(self): 585 return self.node.bool_() 586 587 def __float__(self): 588 return self.node.guard_float("", 0) 589 590 # Symbolic power does NOT work with negative base, this is to avoid 591 # potential complex outputs 592 def __pow__(self, other): 593 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 594 return NotImplemented 595 torch._check(self >= 0) 596 return self.__float_pow__(other) 597 598 def __rpow__(self, other): 599 if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): 600 return NotImplemented 601 torch._check(other >= 0) 602 return self.__rfloat_pow__(other) 603 604 # Magic methods installed by torch.fx.experimental.sym_node 605 606 def __eq__(self, other: object) -> builtins.bool: 607 raise TypeError("type stub not overridden") 608 609 def __lt__(self, other) -> builtins.bool: 610 raise TypeError("type stub not overridden") 611 612 def __gt__(self, other) -> builtins.bool: 613 raise TypeError("type stub not overridden") 614 615 def __le__(self, other) -> builtins.bool: 616 raise TypeError("type stub not overridden") 617 618 def __ge__(self, other) -> builtins.bool: 619 raise TypeError("type stub not overridden") 620 621 def __float_pow__(self, other) -> "SymFloat": 622 raise TypeError("type stub not overridden") 623 624 def __rfloat_pow__(self, other) -> "SymFloat": 625 raise TypeError("type stub not overridden") 626 627 def __float_truediv__(self, other) -> "SymFloat": 628 raise TypeError("type stub not overridden") 629 630 def __rfloat_truediv__(self, other) -> "SymFloat": 631 raise TypeError("type stub not overridden") 632 633 def __trunc__(self): 634 raise TypeError("type stub not overridden") 635 636 def __sym_max__(self, other): 637 raise TypeError("type stub not overridden") 638 639 def __sym_min__(self, other): 640 raise TypeError("type stub not overridden") 641 642 def __sym_int__(self): 643 raise TypeError("type stub not overridden") 644 645 def is_integer(self): 646 """Return True if the float is an integer.""" 647 raise TypeError("type stub not overridden") 648 649 def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]: 650 """Represent this float as an exact integer ratio""" 651 return 
builtins.float(self).as_integer_ratio() 652 653 def __repr__(self): 654 return self.node._graph_repr() 655 656 def _sympy_(self): 657 return self.node.expr 658 659 def __hash__(self): 660 return hash(builtins.float(self)) 661 662 663class SymBool: 664 """ 665 Like an bool (including magic methods), but redirects all operations on the 666 wrapped node. This is used in particular to symbolically record operations 667 in the symbolic shape workflow. 668 669 Unlike regular bools, regular boolean operators will force extra guards instead 670 of symbolically evaluate. Use the bitwise operators instead to handle this. 671 """ 672 673 def __init__(self, node): 674 # This field MUST be named node; C++ binding code assumes that this 675 # class has a field named node that stores SymNode 676 self.node = node 677 678 def __bool__(self): 679 return self.node.bool_() 680 681 def __int__(self): 682 return builtins.int(self.node.bool_()) 683 684 # Magic methods installed by torch.fx.experimental.sym_node 685 def __and__(self, other) -> "SymBool": 686 raise TypeError("type stub not overridden") 687 688 def __or__(self, other) -> "SymBool": 689 raise TypeError("type stub not overridden") 690 691 # We very carefully define __sym_not__, and not a number of other 692 # plausible alternatives: 693 # 694 # - We do not override __not__ because this is not a real magic 695 # method; you cannot override the meaning of the not builtin in 696 # Python. We use the name 'sym_not' to clarify that in user code you 697 # cannot use the builtin not or operator.not_ or operator.__not__ and 698 # hit this magic method; you must use our custom sym_not operator. 699 # 700 # - We do not override the __invert__ method because SymBool is 701 # meant to be usable in situations where bool is expected. However, 702 # bitwise negation ~a does the wrong thing with booleans (because 703 # bool is a subclass of int, so ~1 = -2 which is not falseish.) 704 # This would be a giant footgun, so we get around it by defining 705 # our own operator. Note that bitwise and/or do the right thing, 706 # so we reuse the conventional operators there for readability. 707 # 708 def __sym_not__(self) -> "SymBool": 709 raise TypeError("type stub not overridden") 710 711 def __sym_ite__(self, then_val, else_val): 712 raise TypeError("type stub not overridden") 713 714 def __eq__(self, other) -> builtins.bool: 715 raise TypeError("type stub not overridden") 716 717 def __repr__(self): 718 return self.node._graph_repr() 719 720 def _sympy_(self): 721 return self.node.expr 722 723 def __hash__(self): 724 if self.node.is_constant(): 725 return hash(self.node.bool_()) 726 else: 727 # Force specialization 728 return hash(builtins.bool(self)) 729 730 731def sym_not(a): 732 r"""SymInt-aware utility for logical negation. 733 734 Args: 735 a (SymBool or bool): Object to negate 736 """ 737 import sympy 738 739 if overrides.has_torch_function_unary(a): 740 return overrides.handle_torch_function(sym_not, (a,), a) 741 if hasattr(a, "__sym_not__"): 742 return a.__sym_not__() 743 if isinstance(a, sympy.Basic): 744 return ~a # type: ignore[operator] 745 return not a 746 747 748def sym_float(a): 749 r"""SymInt-aware utility for float casting. 
750 751 Args: 752 a (SymInt, SymFloat, or object): Object to cast 753 """ 754 if overrides.has_torch_function_unary(a): 755 return overrides.handle_torch_function(sym_float, (a,), a) 756 if isinstance(a, SymFloat): 757 return a 758 elif hasattr(a, "__sym_float__"): 759 return a.__sym_float__() 760 return builtins.float(a) # type: ignore[operator] 761 762 763def sym_int(a): 764 r"""SymInt-aware utility for int casting. 765 766 Args: 767 a (SymInt, SymFloat, or object): Object to cast 768 """ 769 if overrides.has_torch_function_unary(a): 770 return overrides.handle_torch_function(sym_int, (a,), a) 771 if isinstance(a, SymInt): 772 return a 773 elif isinstance(a, SymFloat): 774 return math.trunc(a) 775 return builtins.int(a) # type: ignore[operator] 776 777 778def sym_max(a, b): 779 """ 780 SymInt-aware utility for max which avoids branching on a < b. 781 Unlike builtins.max(), this only works for int/float, and it always 782 promotes to float if any argument is float (unlike builtins.max, which 783 will faithfully preserve the type of the input argument). 784 """ 785 if overrides.has_torch_function((a, b)): 786 return overrides.handle_torch_function(sym_max, (a, b), a, b) 787 if isinstance(a, (SymInt, SymFloat)): 788 return a.__sym_max__(b) 789 elif isinstance(b, (SymInt, SymFloat)): 790 # Due to promotion semantics, this is operator is commutative: 791 # max(1, 1.0) === max(1.0, 1) === 1.0 792 return b.__sym_max__(a) 793 # TODO: Probably can make bool work too, just lazy 794 795 all_types, float_types = __all_and_float_types() 796 797 assert isinstance(a, all_types), type(a) 798 assert isinstance(b, all_types), type(b) 799 if isinstance(a, float_types) or isinstance(b, float_types): 800 return builtins.float(builtins.max(a, b)) 801 else: 802 return builtins.max(a, b) 803 804 805def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]: 806 try: 807 import numpy as np 808 809 all_types: _Tuple[_Type, ...] = ( 810 np.integer, 811 np.floating, 812 builtins.int, 813 builtins.float, 814 ) 815 float_types: _Tuple[_Type, ...] 
= (np.floating, builtins.float) 816 except ModuleNotFoundError: 817 all_types = (builtins.int, builtins.float) 818 float_types = (builtins.float,) 819 820 return all_types, float_types 821 822 823def sym_min(a, b): 824 """SymInt-aware utility for min().""" 825 if overrides.has_torch_function((a, b)): 826 return overrides.handle_torch_function(sym_min, (a, b), a, b) 827 if isinstance(a, (SymInt, SymFloat)): 828 return a.__sym_min__(b) 829 elif isinstance(b, (SymInt, SymFloat)): 830 return b.__sym_min__(a) 831 832 all_types, float_types = __all_and_float_types() 833 834 assert isinstance(a, all_types), type(a) 835 assert isinstance(b, all_types), type(b) 836 if isinstance(a, float_types) or isinstance(b, float_types): 837 return builtins.float(builtins.min(a, b)) 838 else: 839 return builtins.min(a, b) 840 841 842# Drop in replacement for math.sqrt, math.sin, math.cos etc 843def _get_sym_math_fn(name): 844 def fn(a): 845 if overrides.has_torch_function_unary(a): 846 return overrides.handle_torch_function(fn, (a,), a) 847 if hasattr(a, f"__sym_{name}__"): 848 return getattr(a, f"__sym_{name}__")() 849 return getattr(math, name)(a) 850 851 return fn 852 853 854__fn, __name, __sym_name = None, "", "" 855for __name in ( 856 "sqrt", 857 "cos", 858 "cosh", 859 "sin", 860 "sinh", 861 "tan", 862 "tanh", 863 "asin", 864 "acos", 865 "atan", 866): 867 __sym_name = f"_sym_{__name}" 868 __fn = _get_sym_math_fn(__name) 869 __fn.__qualname__ = __fn.__name__ = __sym_name 870 globals()[__sym_name] = __fn 871 872del __fn, __name, __sym_name, _get_sym_math_fn 873 874# Adding temporary shortcut 875sym_sqrt = globals()["_sym_sqrt"] 876__all__.append("sym_sqrt") 877 878 879def sym_ite(b, t, f): 880 if overrides.has_torch_function((b, t, f)): 881 return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f) 882 assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) 883 if isinstance(b, SymBool): 884 return b.__sym_ite__(t, f) 885 return t if b else f 886 887 888# Check to see if we can load C extensions, and if not provide some guidance 889# on what the problem might be. 890try: 891 # _initExtension is chosen (arbitrarily) as a sentinel. 892 from torch._C import _initExtension 893except ImportError: 894 import torch._C as _C_for_compiled_check 895 896 # The __file__ check only works for Python 3.7 and above. 897 if _C_for_compiled_check.__file__ is None: 898 raise ImportError( 899 textwrap.dedent( 900 """ 901 Failed to load PyTorch C extensions: 902 It appears that PyTorch has loaded the `torch/_C` folder 903 of the PyTorch repository rather than the C extensions which 904 are expected in the `torch._C` namespace. This can occur when 905 using the `install` workflow. e.g. 906 $ python setup.py install && python -c "import torch" 907 908 This error can generally be solved using the `develop` workflow 909 $ python setup.py develop && python -c "import torch" # This should succeed 910 or by running Python from a different directory. 911 """ 912 ).strip() 913 ) from None 914 raise # If __file__ is not None the cause is unknown, so just re-raise. 
915 916# The torch._C submodule is already loaded via `from torch._C import *` above 917# Make an explicit reference to the _C submodule to appease linters 918from torch import _C as _C 919 920 921__name, __obj = "", None 922for __name in dir(_C): 923 if __name[0] != "_" and not __name.endswith("Base"): 924 __all__.append(__name) 925 __obj = getattr(_C, __name) 926 if callable(__obj) or inspect.isclass(__obj): 927 if __obj.__module__ != __name__: # "torch" 928 # TODO: fix their module from C++ side 929 if __name not in { 930 "DisableTorchFunctionSubclass", 931 "DisableTorchFunction", 932 "Generator", 933 }: 934 __obj.__module__ = __name__ # "torch" 935 elif __name == "TensorBase": 936 # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. 937 delattr(sys.modules[__name__], __name) 938 939del __name, __obj 940 941if not TYPE_CHECKING: 942 # issue 38137 and python issue 43367. Submodules of a C extension are 943 # non-standard, and attributes of those submodules cannot be pickled since 944 # pickle expect to be able to import them as "from _C.sub import attr" 945 # which fails with "_C is not a package 946 def _import_extension_to_sys_modules(module, memo=None): 947 if memo is None: 948 memo = set() 949 if module in memo: 950 return 951 memo.add(module) 952 module_name = module.__name__ 953 for name in dir(module): 954 member = getattr(module, name) 955 member_name = getattr(member, "__name__", "") 956 if inspect.ismodule(member) and member_name.startswith(module_name): 957 sys.modules.setdefault(member_name, member) 958 # Recurse for submodules (e.g., `_C._dynamo.eval_frame`) 959 _import_extension_to_sys_modules(member, memo) 960 961 _import_extension_to_sys_modules(_C) 962 del _import_extension_to_sys_modules 963 964################################################################################ 965# Define basic utilities 966################################################################################ 967 968 969def typename(obj: _Any, /) -> str: 970 """ 971 String representation of the type of an object. 972 973 This function returns a fully qualified string representation of an object's type. 974 Args: 975 obj (object): The object whose type to represent 976 Returns: 977 str: the type of the object `o` 978 Example: 979 >>> x = torch.tensor([1, 2, 3]) 980 >>> torch.typename(x) 981 'torch.LongTensor' 982 >>> torch.typename(torch.nn.Parameter) 983 'torch.nn.parameter.Parameter' 984 """ 985 if isinstance(obj, torch.Tensor): 986 return obj.type() 987 988 module = getattr(obj, "__module__", "") or "" 989 qualname = "" 990 991 if hasattr(obj, "__qualname__"): 992 qualname = obj.__qualname__ 993 elif hasattr(obj, "__name__"): 994 qualname = obj.__name__ 995 else: 996 module = obj.__class__.__module__ or "" 997 qualname = obj.__class__.__qualname__ 998 999 if module in {"", "builtins"}: 1000 return qualname 1001 return f"{module}.{qualname}" 1002 1003 1004def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: 1005 r"""Returns True if `obj` is a PyTorch tensor. 1006 1007 Note that this function is simply doing ``isinstance(obj, Tensor)``. 1008 Using that ``isinstance`` check is better for typechecking with mypy, 1009 and more explicit - so it's recommended to use that instead of 1010 ``is_tensor``. 
1011 1012 Args: 1013 obj (object): Object to test 1014 Example:: 1015 1016 >>> x = torch.tensor([1, 2, 3]) 1017 >>> torch.is_tensor(x) 1018 True 1019 1020 """ 1021 return isinstance(obj, torch.Tensor) 1022 1023 1024def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]: 1025 r"""Returns True if `obj` is a PyTorch storage object. 1026 1027 Args: 1028 obj (Object): Object to test 1029 """ 1030 return type(obj) in _storage_classes 1031 1032 1033_GLOBAL_DEVICE_CONTEXT = threading.local() 1034 1035 1036def get_default_device() -> "torch.device": 1037 r"""Gets the default ``torch.Tensor`` to be allocated on ``device``""" 1038 global _GLOBAL_DEVICE_CONTEXT 1039 1040 if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): 1041 device = _GLOBAL_DEVICE_CONTEXT.device_context.device 1042 if device.index is not None: 1043 return device 1044 else: 1045 # TODO: Call like get_device_index() method corresponding to 1046 # each device type 1047 return torch.tensor([]).device 1048 else: 1049 return torch.device("cpu") 1050 1051 1052def set_default_device( 1053 device: _Optional[_Union["torch.device", str, builtins.int]], 1054) -> None: 1055 """Sets the default ``torch.Tensor`` to be allocated on ``device``. This 1056 does not affect factory function calls which are called with an explicit 1057 ``device`` argument. Factory calls will be performed as if they 1058 were passed ``device`` as an argument. 1059 1060 To only temporarily change the default device instead of setting it 1061 globally, use ``with torch.device(device):`` instead. 1062 1063 The default device is initially ``cpu``. If you set the default tensor 1064 device to another device (e.g., ``cuda``) without a device index, tensors 1065 will be allocated on whatever the current device for the device type, 1066 even after :func:`torch.cuda.set_device` is called. 1067 1068 .. warning:: 1069 1070 This function imposes a slight performance cost on every Python 1071 call to the torch API (not just factory functions). If this 1072 is causing problems for you, please comment on 1073 https://github.com/pytorch/pytorch/issues/92701 1074 1075 .. 
note:: 1076 1077 This doesn't affect functions that create tensors that share the same memory as the input, like: 1078 :func:`torch.from_numpy` and :func:`torch.frombuffer` 1079 1080 Args: 1081 device (device or string): the device to set as default 1082 1083 Example:: 1084 1085 >>> # xdoctest: +SKIP("requires cuda, changes global state") 1086 >>> torch.get_default_device() 1087 device(type='cpu') 1088 >>> torch.set_default_device('cuda') # current device is 0 1089 >>> torch.get_default_device() 1090 device(type='cuda', index=0) 1091 >>> torch.set_default_device('cuda') 1092 >>> torch.cuda.set_device('cuda:1') # current device is 1 1093 >>> torch.get_default_device() 1094 device(type='cuda', index=1) 1095 >>> torch.set_default_device('cuda:1') 1096 >>> torch.get_default_device() 1097 device(type='cuda', index=1) 1098 1099 """ 1100 global _GLOBAL_DEVICE_CONTEXT 1101 if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): 1102 device_context = _GLOBAL_DEVICE_CONTEXT.device_context 1103 if device_context is not None: 1104 device_context.__exit__(None, None, None) 1105 1106 if device is None: 1107 device_context = None 1108 else: 1109 from torch.utils._device import DeviceContext 1110 1111 device_context = DeviceContext(device) 1112 device_context.__enter__() 1113 _GLOBAL_DEVICE_CONTEXT.device_context = device_context 1114 1115 1116def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None: 1117 r""" 1118 .. warning:: 1119 1120 This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and 1121 :func:`torch.set_default_device()` as alternatives. 1122 1123 Sets the default ``torch.Tensor`` type to floating point tensor type 1124 ``t``. This type will also be used as default floating point type for 1125 type inference in :func:`torch.tensor`. 1126 1127 The default floating point tensor type is initially ``torch.FloatTensor``. 1128 1129 Args: 1130 t (type or string): the floating point tensor type or its name 1131 1132 Example:: 1133 1134 >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") 1135 >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32 1136 torch.float32 1137 >>> torch.set_default_tensor_type(torch.DoubleTensor) 1138 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor 1139 torch.float64 1140 1141 """ 1142 if isinstance(t, str): 1143 t = _import_dotted_name(t) 1144 _C._set_default_tensor_type(t) 1145 1146 1147def set_default_dtype(d: "torch.dtype", /) -> None: 1148 r""" 1149 1150 Sets the default floating point dtype to :attr:`d`. Supports floating point dtype 1151 as inputs. Other dtypes will cause torch to raise an exception. 1152 1153 When PyTorch is initialized its default floating point dtype is torch.float32, 1154 and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like 1155 type inference. The default floating point dtype is used to: 1156 1157 1. Implicitly determine the default complex dtype. When the default floating type is float16, 1158 the default complex dtype is complex32. For float32, the default complex dtype is complex64. 1159 For float64, it is complex128. For bfloat16, an exception will be raised because 1160 there is no corresponding complex type for bfloat16. 1161 2. Infer the dtype for tensors constructed using Python floats or complex Python 1162 numbers. See examples below. 1163 3. Determine the result of type promotion between bool and integer tensors and 1164 Python floats and complex Python numbers. 
1165 1166 Args: 1167 d (:class:`torch.dtype`): the floating point dtype to make the default. 1168 1169 Example: 1170 >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") 1171 >>> # initial default for floating point is torch.float32 1172 >>> # Python floats are interpreted as float32 1173 >>> torch.tensor([1.2, 3]).dtype 1174 torch.float32 1175 >>> # initial default for floating point is torch.complex64 1176 >>> # Complex Python numbers are interpreted as complex64 1177 >>> torch.tensor([1.2, 3j]).dtype 1178 torch.complex64 1179 1180 >>> torch.set_default_dtype(torch.float64) 1181 >>> # Python floats are now interpreted as float64 1182 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor 1183 torch.float64 1184 >>> # Complex Python numbers are now interpreted as complex128 1185 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor 1186 torch.complex128 1187 1188 >>> torch.set_default_dtype(torch.float16) 1189 >>> # Python floats are now interpreted as float16 1190 >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor 1191 torch.float16 1192 >>> # Complex Python numbers are now interpreted as complex128 1193 >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor 1194 torch.complex32 1195 1196 """ 1197 _C._set_default_dtype(d) 1198 1199 1200def use_deterministic_algorithms( 1201 mode: builtins.bool, 1202 *, 1203 warn_only: builtins.bool = False, 1204) -> None: 1205 r"""Sets whether PyTorch operations must use "deterministic" 1206 algorithms. That is, algorithms which, given the same input, and when 1207 run on the same software and hardware, always produce the same output. 1208 When enabled, operations will use deterministic algorithms when available, 1209 and if only nondeterministic algorithms are available they will throw a 1210 :class:`RuntimeError` when called. 1211 1212 .. note:: This setting alone is not always enough to make an application 1213 reproducible. Refer to :ref:`reproducibility` for more information. 1214 1215 .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative 1216 interface for this feature. 
1217 1218 The following normally-nondeterministic operations will act 1219 deterministically when ``mode=True``: 1220 1221 * :class:`torch.nn.Conv1d` when called on CUDA tensor 1222 * :class:`torch.nn.Conv2d` when called on CUDA tensor 1223 * :class:`torch.nn.Conv3d` when called on CUDA tensor 1224 * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor 1225 * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor 1226 * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor 1227 * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor 1228 * :func:`torch.bmm` when called on sparse-dense CUDA tensors 1229 * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor 1230 and the index is a list of tensors 1231 * :func:`torch.Tensor.index_put` with ``accumulate=False`` 1232 * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU 1233 tensor 1234 * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU 1235 tensor 1236 * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor 1237 * :func:`torch.gather` when called on a CUDA tensor that requires grad 1238 * :func:`torch.index_add` when called on CUDA tensor 1239 * :func:`torch.index_select` when attempting to differentiate a CUDA tensor 1240 * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor 1241 * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor 1242 * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor 1243 * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor 1244 1245 The following normally-nondeterministic operations will throw a 1246 :class:`RuntimeError` when ``mode=True``: 1247 1248 * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor 1249 * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor 1250 * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor 1251 * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor 1252 * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor 1253 * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor 1254 * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor 1255 * :class:`torch.nn.MaxUnpool1d` 1256 * :class:`torch.nn.MaxUnpool2d` 1257 * :class:`torch.nn.MaxUnpool3d` 1258 * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor 1259 and one of the following modes is used: 1260 1261 - ``linear`` 1262 - ``bilinear`` 1263 - ``bicubic`` 1264 - ``trilinear`` 1265 1266 * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor 1267 * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor 1268 * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor 1269 * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor 1270 * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor 1271 * :class:`torch.nn.NLLLoss` when called on a CUDA tensor 1272 * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor 1273 * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when 1274 ``mode='max'`` 1275 * :func:`torch.Tensor.put_` when ``accumulate=False`` 1276 * :func:`torch.Tensor.put_` when 
``accumulate=True`` and called on a CUDA tensor 1277 * :func:`torch.histc` when called on a CUDA tensor 1278 * :func:`torch.bincount` when called on a CUDA tensor and ``weights`` 1279 tensor is given 1280 * :func:`torch.kthvalue` with called on a CUDA tensor 1281 * :func:`torch.median` with indices output when called on a CUDA tensor 1282 * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor 1283 * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex 1284 * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor 1285 * :func:`torch.Tensor.resize_` when called with a quantized tensor 1286 1287 In addition, several operations fill uninitialized memory when this setting 1288 is turned on and when 1289 :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on. 1290 See the documentation for that attribute for more information. 1291 1292 A handful of CUDA operations are nondeterministic if the CUDA version is 1293 10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8`` 1294 or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more 1295 details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_ 1296 If one of these environment variable configurations is not set, a :class:`RuntimeError` 1297 will be raised from these operations when called with CUDA tensors: 1298 1299 * :func:`torch.mm` 1300 * :func:`torch.mv` 1301 * :func:`torch.bmm` 1302 1303 Note that deterministic operations tend to have worse performance than 1304 nondeterministic operations. 1305 1306 .. note:: 1307 1308 This flag does not detect or prevent nondeterministic behavior caused 1309 by calling an inplace operation on a tensor with an internal memory 1310 overlap or by giving such a tensor as the :attr:`out` argument for an 1311 operation. In these cases, multiple writes of different data may target 1312 a single memory location, and the order of writes is not guaranteed. 1313 1314 Args: 1315 mode (:class:`bool`): If True, makes potentially nondeterministic 1316 operations switch to a deterministic algorithm or throw a runtime 1317 error. If False, allows nondeterministic operations. 1318 1319 Keyword args: 1320 warn_only (:class:`bool`, optional): If True, operations that do not 1321 have a deterministic implementation will throw a warning instead of 1322 an error. Default: ``False`` 1323 1324 Example:: 1325 1326 >>> # xdoctest: +SKIP 1327 >>> torch.use_deterministic_algorithms(True) 1328 1329 # Forward mode nondeterministic error 1330 >>> torch.randn(10, device='cuda').kthvalue(1) 1331 ... 1332 RuntimeError: kthvalue CUDA does not have a deterministic implementation... 1333 1334 # Backward mode nondeterministic error 1335 >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward() 1336 ... 1337 RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation... 1338 """ 1339 _C._set_deterministic_algorithms(mode, warn_only=warn_only) 1340 1341 1342def are_deterministic_algorithms_enabled() -> builtins.bool: 1343 r"""Returns True if the global deterministic flag is turned on. Refer to 1344 :func:`torch.use_deterministic_algorithms` documentation for more details. 1345 """ 1346 return _C._get_deterministic_algorithms() 1347 1348 1349def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: 1350 r"""Returns True if the global deterministic flag is set to warn only. 
1351 Refer to :func:`torch.use_deterministic_algorithms` documentation for more 1352 details. 1353 """ 1354 return _C._get_deterministic_algorithms_warn_only() 1355 1356 1357def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: 1358 r"""Sets the debug mode for deterministic operations. 1359 1360 .. note:: This is an alternative interface for 1361 :func:`torch.use_deterministic_algorithms`. Refer to that function's 1362 documentation for details about affected operations. 1363 1364 Args: 1365 debug_mode(str or int): If "default" or 0, don't error or warn on 1366 nondeterministic operations. If "warn" or 1, warn on 1367 nondeterministic operations. If "error" or 2, error on 1368 nondeterministic operations. 1369 """ 1370 1371 # NOTE: builtins.int is used here because int in this scope resolves 1372 # to torch.int 1373 if not isinstance(debug_mode, (builtins.int, str)): 1374 raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}") 1375 1376 if isinstance(debug_mode, str): 1377 if debug_mode == "default": 1378 debug_mode = 0 1379 elif debug_mode == "warn": 1380 debug_mode = 1 1381 elif debug_mode == "error": 1382 debug_mode = 2 1383 else: 1384 raise RuntimeError( 1385 "invalid value of debug_mode, expected one of `default`, " 1386 f"`warn`, `error`, but got {debug_mode}" 1387 ) 1388 1389 if debug_mode == 0: 1390 _C._set_deterministic_algorithms(False) 1391 elif debug_mode == 1: 1392 _C._set_deterministic_algorithms(True, warn_only=True) 1393 elif debug_mode == 2: 1394 _C._set_deterministic_algorithms(True) 1395 else: 1396 raise RuntimeError( 1397 "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}" 1398 ) 1399 1400 1401def get_deterministic_debug_mode() -> builtins.int: 1402 r"""Returns the current value of the debug mode for deterministic 1403 operations. Refer to :func:`torch.set_deterministic_debug_mode` 1404 documentation for more details. 1405 """ 1406 1407 if _C._get_deterministic_algorithms(): 1408 if _C._get_deterministic_algorithms_warn_only(): 1409 return 1 1410 else: 1411 return 2 1412 else: 1413 return 0 1414 1415 1416def get_float32_matmul_precision() -> str: 1417 r"""Returns the current value of float32 matrix multiplication precision. Refer to 1418 :func:`torch.set_float32_matmul_precision` documentation for more details. 1419 """ 1420 return _C._get_float32_matmul_precision() 1421 1422 1423def set_float32_matmul_precision(precision: str) -> None: 1424 r"""Sets the internal precision of float32 matrix multiplications. 1425 1426 Running float32 matrix multiplications in lower precision may significantly increase 1427 performance, and in some programs the loss of precision has a negligible impact. 1428 1429 Supports three settings: 1430 1431 * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa 1432 bits with 23 bits explicitly stored) for internal computations. 1433 * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10 1434 mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers 1435 (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication 1436 algorithms are available. Otherwise float32 matrix multiplications are computed 1437 as if the precision is "highest". See below for more information on the bfloat16 1438 approach. 
1439 * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa 1440 bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm 1441 using that datatype internally is available. Otherwise float32 1442 matrix multiplications are computed as if the precision is "high". 1443 1444 When using "high" precision, float32 multiplications may use a bfloat16-based algorithm 1445 that is more complicated than simply truncating to some smaller number mantissa bits 1446 (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete 1447 description of this algorithm. To briefly explain here, the first step is to realize 1448 that we can perfectly encode a single float32 number as the sum of three bfloat16 1449 numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the 1450 same number of exponent bits). This means that the product of two float32 numbers can 1451 be exactly given by the sum of nine products of bfloat16 numbers. We can then trade 1452 accuracy for speed by dropping some of these products. The "high" precision algorithm 1453 specifically keeps only the three most significant products, which conveniently excludes 1454 all of the products involving the last 8 mantissa bits of either input. This means that 1455 we can represent our inputs as the sum of two bfloat16 numbers rather than three. 1456 Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than 1457 float32 ones, it's faster to do three multiplications and 2 additions with bfloat16 1458 precision than it is to do a single multiplication with float32 precision. 1459 1460 .. [Henry2019] http://arxiv.org/abs/1904.06376 1461 1462 .. note:: 1463 1464 This does not change the output dtype of float32 matrix multiplications, 1465 it controls how the internal computation of the matrix multiplication is performed. 1466 1467 .. note:: 1468 1469 This does not change the precision of convolution operations. Other flags, 1470 like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution 1471 operations. 1472 1473 .. note:: 1474 1475 This flag currently only affects one native device type: CUDA. 1476 If "high" or "medium" are set then the TensorFloat32 datatype will be used 1477 when computing float32 matrix multiplications, equivalent to setting 1478 `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default) 1479 is set then the float32 datatype is used for internal computations, equivalent 1480 to setting `torch.backends.cuda.matmul.allow_tf32 = False`. 1481 1482 Args: 1483 precision(str): can be set to "highest" (default), "high", or "medium" (see above). 1484 1485 """ 1486 _C._set_float32_matmul_precision(precision) 1487 1488 1489def set_warn_always(b: builtins.bool, /) -> None: 1490 r"""When this flag is False (default) then some PyTorch warnings may only 1491 appear once per process. This helps avoid excessive warning information. 1492 Setting it to True causes these warnings to always appear, which may be 1493 helpful when debugging. 1494 1495 Args: 1496 b (:class:`bool`): If True, force warnings to always be emitted 1497 If False, set to the default behaviour 1498 """ 1499 _C._set_warnAlways(b) 1500 1501 1502def is_warn_always_enabled() -> builtins.bool: 1503 r"""Returns True if the global warn_always flag is turned on. Refer to 1504 :func:`torch.set_warn_always` documentation for more details. 
1505 """ 1506 return _C._get_warnAlways() 1507 1508 1509################################################################################ 1510# Define error checking functions 1511################################################################################ 1512 1513# These error checking functions must be kept consistent with their C++ 1514# equivalents. Their C++ equivalents are mentioned where applicable. 1515 1516 1517def _check_with( 1518 error_type, 1519 cond: _Union[builtins.bool, SymBool], 1520 message: _Callable[[], str], 1521): # noqa: F811 1522 if not isinstance(cond, (builtins.bool, SymBool)): 1523 raise TypeError(f"cond must be a bool, but got {type(cond)}") 1524 1525 from torch.fx.experimental.symbolic_shapes import expect_true 1526 1527 if expect_true(cond): 1528 return 1529 1530 # error_type must be a subclass of Exception and not subclass of Warning 1531 assert issubclass(error_type, Exception) and not issubclass(error_type, Warning) 1532 1533 if message is None: 1534 message_evaluated = ( 1535 "Expected cond to be True, but got False. (Could this error " 1536 "message be improved? If so, please report an enhancement request " 1537 "to PyTorch.)" 1538 ) 1539 1540 else: 1541 if not callable(message): 1542 raise TypeError("message must be a callable") 1543 1544 message_evaluated = str(message()) 1545 1546 raise error_type(message_evaluated) 1547 1548 1549def _check(cond, message=None): # noqa: F811 1550 r"""Throws error containing an optional message if the specified condition 1551 is False. 1552 1553 Error type: ``RuntimeError`` 1554 1555 C++ equivalent: ``TORCH_CHECK`` 1556 1557 Args: 1558 cond (:class:`bool`): If False, throw error 1559 1560 message (Callable, optional): Callable that returns either a string or 1561 an object that has a ``__str__()`` method to be used as the error 1562 message. Default: ``None`` 1563 """ 1564 _check_with(RuntimeError, cond, message) 1565 1566 1567def _check_is_size(i, message=None): 1568 """Checks that a given integer is a valid size (i.e., is non-negative). 1569 You should use this over _check(i >= 0) because we can use the semantic 1570 information (that i is a size) to make some further inferences in case 1571 i is an unbacked SymInt. 1572 1573 NB: Do NOT use this in contexts where a -1 size would be valid (indicating 1574 to infer the size from context, or if you should wrap-around or truncate). 1575 Only use this if the only valid value is an honest to goodness size. 1576 """ 1577 # This is responsible for the expect_true 1578 _check(i >= 0, message) 1579 from torch.fx.experimental.symbolic_shapes import _advise_is_size 1580 1581 _advise_is_size(i) 1582 1583 1584def _check_index(cond, message=None): # noqa: F811 1585 r"""Throws error containing an optional message if the specified condition 1586 is False. 1587 1588 Error type: ``IndexError`` 1589 1590 C++ equivalent: ``TORCH_CHECK_INDEX`` 1591 1592 Args: 1593 cond (:class:`bool`): If False, throw error 1594 1595 message (Callable, optional): Callable that returns either a string or 1596 an object that has a ``__str__()`` method to be used as the error 1597 message. Default: ``None`` 1598 """ 1599 _check_with(IndexError, cond, message) 1600 1601 1602def _check_value(cond, message=None): # noqa: F811 1603 r"""Throws error containing an optional message if the specified condition 1604 is False. 
1605 1606 Error type: ``ValueError`` 1607 1608 C++ equivalent: ``TORCH_CHECK_VALUE`` 1609 1610 Args: 1611 cond (:class:`bool`): If False, throw error 1612 1613 message (Callable, optional): Callable that returns either a string or 1614 an object that has a ``__str__()`` method to be used as the error 1615 message. Default: ``None`` 1616 """ 1617 _check_with(ValueError, cond, message) 1618 1619 1620def _check_type(cond, message=None): # noqa: F811 1621 r"""Throws error containing an optional message if the specified condition 1622 is False. 1623 1624 Error type: ``TypeError`` 1625 1626 C++ equivalent: ``TORCH_CHECK_TYPE`` 1627 1628 Args: 1629 cond (:class:`bool`): If False, throw error 1630 1631 message (Callable, optional): Callable that returns either a string or 1632 an object that has a ``__str__()`` method to be used as the error 1633 message. Default: ``None`` 1634 """ 1635 _check_with(TypeError, cond, message) 1636 1637 1638def _check_not_implemented(cond, message=None): # noqa: F811 1639 r"""Throws error containing an optional message if the specified condition 1640 is False. 1641 1642 Error type: ``NotImplementedError`` 1643 1644 C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED`` 1645 1646 Args: 1647 cond (:class:`bool`): If False, throw error 1648 1649 message (Callable, optional): Callable that returns either a string or 1650 an object that has a ``__str__()`` method to be used as the error 1651 message. Default: ``None`` 1652 """ 1653 _check_with(NotImplementedError, cond, message) 1654 1655 1656def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 1657 if not is_tensor(cond): 1658 raise TypeError(f"cond must be a tensor, but got {type(cond)}") 1659 1660 if not cond.dtype == torch.bool: 1661 raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}") 1662 1663 _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type] 1664 1665 1666# C++ equivalent: `TORCH_CHECK_TENSOR_ALL` 1667def _check_tensor_all(cond, message=None): # noqa: F811 1668 r"""Throws error containing an optional message if the specified condition 1669 is False. 1670 1671 Error type: ``RuntimeError`` 1672 1673 C++ equivalent: ``TORCH_CHECK_TENSOR_ALL`` 1674 1675 Args: 1676 cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any 1677 element is ``False``, throw error 1678 1679 message (Callable, optional): Callable that returns either a string or 1680 an object that has a ``__str__()`` method to be used as the error 1681 message. 
Default: ``None`` 1682 """ 1683 _check_tensor_all_with(RuntimeError, cond, message) 1684 1685 1686################################################################################ 1687# Define numeric constants 1688################################################################################ 1689 1690# For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and 1691# NumPy consistency (https://numpy.org/devdocs/reference/constants.html) 1692from math import e, inf, nan, pi 1693 1694 1695newaxis: None = None 1696 1697__all__.extend(["e", "pi", "nan", "inf", "newaxis"]) 1698 1699################################################################################ 1700# Define Storage and Tensor classes 1701################################################################################ 1702 1703from torch._tensor import Tensor # usort: skip 1704 1705# needs to be after torch.Tensor is defined to avoid circular dependencies 1706from torch import storage as storage # usort: skip 1707from torch.storage import ( 1708 _LegacyStorage, 1709 _StorageBase, 1710 _warn_typed_storage_removal, 1711 TypedStorage, 1712 UntypedStorage, 1713) 1714 1715 1716# NOTE: New <type>Storage classes should never be added. When adding a new 1717# dtype, use torch.storage.TypedStorage directly. 1718class ByteStorage(_LegacyStorage): 1719 @classproperty 1720 def dtype(self): 1721 _warn_typed_storage_removal(stacklevel=3) 1722 return self._dtype 1723 1724 @classproperty 1725 def _dtype(self): 1726 return torch.uint8 1727 1728 1729class DoubleStorage(_LegacyStorage): 1730 @classproperty 1731 def dtype(self): 1732 _warn_typed_storage_removal(stacklevel=3) 1733 return self._dtype 1734 1735 @classproperty 1736 def _dtype(self): 1737 return torch.double 1738 1739 1740class FloatStorage(_LegacyStorage): 1741 @classproperty 1742 def dtype(self): 1743 _warn_typed_storage_removal(stacklevel=3) 1744 return self._dtype 1745 1746 @classproperty 1747 def _dtype(self): 1748 return torch.float 1749 1750 1751class HalfStorage(_LegacyStorage): 1752 @classproperty 1753 def dtype(self): 1754 _warn_typed_storage_removal(stacklevel=3) 1755 return self._dtype 1756 1757 @classproperty 1758 def _dtype(self): 1759 return torch.half 1760 1761 1762class LongStorage(_LegacyStorage): 1763 @classproperty 1764 def dtype(self): 1765 _warn_typed_storage_removal(stacklevel=3) 1766 return self._dtype 1767 1768 @classproperty 1769 def _dtype(self): 1770 return torch.long 1771 1772 1773class IntStorage(_LegacyStorage): 1774 @classproperty 1775 def dtype(self): 1776 _warn_typed_storage_removal(stacklevel=3) 1777 return self._dtype 1778 1779 @classproperty 1780 def _dtype(self): 1781 return torch.int 1782 1783 1784class ShortStorage(_LegacyStorage): 1785 @classproperty 1786 def dtype(self): 1787 _warn_typed_storage_removal(stacklevel=3) 1788 return self._dtype 1789 1790 @classproperty 1791 def _dtype(self): 1792 return torch.short 1793 1794 1795class CharStorage(_LegacyStorage): 1796 @classproperty 1797 def dtype(self): 1798 _warn_typed_storage_removal(stacklevel=3) 1799 return self._dtype 1800 1801 @classproperty 1802 def _dtype(self): 1803 return torch.int8 1804 1805 1806class BoolStorage(_LegacyStorage): 1807 @classproperty 1808 def dtype(self): 1809 _warn_typed_storage_removal(stacklevel=3) 1810 return self._dtype 1811 1812 @classproperty 1813 def _dtype(self): 1814 return torch.bool 1815 1816 1817class BFloat16Storage(_LegacyStorage): 1818 @classproperty 1819 def dtype(self): 1820 
_warn_typed_storage_removal(stacklevel=3) 1821 return self._dtype 1822 1823 @classproperty 1824 def _dtype(self): 1825 return torch.bfloat16 1826 1827 1828class ComplexDoubleStorage(_LegacyStorage): 1829 @classproperty 1830 def dtype(self): 1831 _warn_typed_storage_removal(stacklevel=3) 1832 return self._dtype 1833 1834 @classproperty 1835 def _dtype(self): 1836 return torch.cdouble 1837 1838 1839class ComplexFloatStorage(_LegacyStorage): 1840 @classproperty 1841 def dtype(self): 1842 _warn_typed_storage_removal(stacklevel=3) 1843 return self._dtype 1844 1845 @classproperty 1846 def _dtype(self): 1847 return torch.cfloat 1848 1849 1850class QUInt8Storage(_LegacyStorage): 1851 @classproperty 1852 def dtype(self): 1853 _warn_typed_storage_removal(stacklevel=3) 1854 return self._dtype 1855 1856 @classproperty 1857 def _dtype(self): 1858 return torch.quint8 1859 1860 1861class QInt8Storage(_LegacyStorage): 1862 @classproperty 1863 def dtype(self): 1864 _warn_typed_storage_removal(stacklevel=3) 1865 return self._dtype 1866 1867 @classproperty 1868 def _dtype(self): 1869 return torch.qint8 1870 1871 1872class QInt32Storage(_LegacyStorage): 1873 @classproperty 1874 def dtype(self): 1875 _warn_typed_storage_removal(stacklevel=3) 1876 return self._dtype 1877 1878 @classproperty 1879 def _dtype(self): 1880 return torch.qint32 1881 1882 1883class QUInt4x2Storage(_LegacyStorage): 1884 @classproperty 1885 def dtype(self): 1886 _warn_typed_storage_removal(stacklevel=3) 1887 return self._dtype 1888 1889 @classproperty 1890 def _dtype(self): 1891 return torch.quint4x2 1892 1893 1894class QUInt2x4Storage(_LegacyStorage): 1895 @classproperty 1896 def dtype(self): 1897 _warn_typed_storage_removal(stacklevel=3) 1898 return self._dtype 1899 1900 @classproperty 1901 def _dtype(self): 1902 return torch.quint2x4 1903 1904 1905_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = { 1906 UntypedStorage, 1907 DoubleStorage, 1908 FloatStorage, 1909 LongStorage, 1910 IntStorage, 1911 ShortStorage, 1912 CharStorage, 1913 ByteStorage, 1914 HalfStorage, 1915 BoolStorage, 1916 QUInt8Storage, 1917 QInt8Storage, 1918 QInt32Storage, 1919 BFloat16Storage, 1920 ComplexFloatStorage, 1921 ComplexDoubleStorage, 1922 QUInt4x2Storage, 1923 QUInt2x4Storage, 1924 TypedStorage, 1925} 1926 1927# The _tensor_classes set is initialized by the call to initialize_python_bindings. 
1928_tensor_classes: _Set[_Type["torch.Tensor"]] = set() 1929 1930# If you edit these imports, please update torch/__init__.py.in as well 1931from torch import amp as amp, random as random, serialization as serialization 1932from torch._tensor_str import set_printoptions 1933from torch.amp import autocast, GradScaler 1934from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state 1935from torch.serialization import load, save 1936 1937 1938################################################################################ 1939# Initialize extension 1940################################################################################ 1941 1942 1943# Shared memory manager needs to know the exact location of manager executable 1944def _manager_path(): 1945 if _running_with_deploy() or platform.system() == "Windows": 1946 return b"" 1947 path = get_file_path("torch", "bin", "torch_shm_manager") 1948 prepare_multiprocessing_environment(get_file_path("torch")) 1949 if not os.path.exists(path): 1950 raise RuntimeError("Unable to find torch_shm_manager at " + path) 1951 return path.encode("utf-8") 1952 1953 1954_C._initExtension(_manager_path()) 1955 1956del _manager_path 1957 1958# Appease the type checker: it can't deal with direct setting of globals(). 1959# Note that we will see "too many" functions when reexporting this way; there 1960# is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions 1961# so that this import is good enough 1962if TYPE_CHECKING: 1963 # Some type signatures pulled in from _VariableFunctions here clash with 1964 # signatures already imported. For now these clashes are ignored; see 1965 # PR #43339 for details. 1966 from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 1967 1968 # Fixup segment_reduce visibility 1969 _segment_reduce = segment_reduce 1970 del segment_reduce # noqa: F821 1971 1972# Ops not to be exposed in `torch` namespace, 1973# mostly helper ops. 
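# (The loop below copies each public binding from ``_C._VariableFunctions``
# into this module's namespace: dunder names and the ``PRIVATE_OPS`` helpers
# are skipped entirely, while anything bound under a leading underscore, such
# as the renamed ``segment_reduce``, is kept out of ``__all__``.)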
PRIVATE_OPS = ("unique_dim",)

__name, __obj = "", None
for __name in dir(_C._VariableFunctions):
    if __name.startswith("__") or __name in PRIVATE_OPS:
        continue
    __obj = getattr(_C._VariableFunctions, __name)
    __obj.__module__ = __name__  # "torch"
    # Hide some APIs that should not be public
    if __name == "segment_reduce":
        # TODO: Once the undocumented FC window is passed, remove the line below
        globals()[__name] = __obj
        __name = "_" + __name
    globals()[__name] = __obj
    if not __name.startswith("_"):
        __all__.append(__name)

del __name, __obj

################################################################################
# Add torch.dtype instances to the public API
################################################################################

import torch


__all__.extend(
    name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype)
)

################################################################################
# Import TorchDynamo's lazy APIs to avoid circular dependencies
################################################################################

# needs to be before from torch.functional import * to avoid circular dependencies
from torch._compile import _disable_dynamo  # usort: skip

################################################################################
# Import interface functions defined in Python
################################################################################

# needs to be after the above ATen bindings so we can overwrite from Python side
from torch import _VF as _VF, functional as functional  # usort: skip
from torch.functional import *  # usort: skip # noqa: F403

################################################################################
# Remove unnecessary members
################################################################################

del _StorageBase
del _LegacyStorage

################################################################################
# Define _assert
################################################################################


# needs to be before the submodule imports to avoid circular dependencies
def _assert(condition, message):
    r"""A wrapper around Python's assert which is symbolically traceable."""
    if type(condition) is not torch.Tensor and overrides.has_torch_function(
        (condition,)
    ):
        return overrides.handle_torch_function(
            _assert, (condition,), condition, message
        )
    assert condition, message


################################################################################
# Import most common subpackages
################################################################################

# Use the redundant form so that type checkers know that these are a part of
# the public API. The "regular" import lines are there solely for the runtime
# side effect of adding to the imported module's members for other users.
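# (For example, ``from torch.autograd import no_grad as no_grad`` below is the
# re-export form recognized by type checkers that implement PEP 484's
# implicit-reexport rules, e.g. mypy with ``--no-implicit-reexport``; a bare
# ``from torch.autograd import no_grad`` would not mark ``no_grad`` as public.)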
2050 2051# needs to be before import torch.nn as nn to avoid circular dependencies 2052from torch.autograd import ( # usort: skip 2053 enable_grad as enable_grad, 2054 inference_mode as inference_mode, 2055 no_grad as no_grad, 2056 set_grad_enabled as set_grad_enabled, 2057) 2058 2059from torch import ( 2060 __config__ as __config__, 2061 __future__ as __future__, 2062 _awaits as _awaits, 2063 autograd as autograd, 2064 backends as backends, 2065 cpu as cpu, 2066 cuda as cuda, 2067 distributed as distributed, 2068 distributions as distributions, 2069 fft as fft, 2070 futures as futures, 2071 hub as hub, 2072 jit as jit, 2073 linalg as linalg, 2074 mps as mps, 2075 mtia as mtia, 2076 multiprocessing as multiprocessing, 2077 nested as nested, 2078 nn as nn, 2079 optim as optim, 2080 overrides as overrides, 2081 profiler as profiler, 2082 sparse as sparse, 2083 special as special, 2084 testing as testing, 2085 types as types, 2086 utils as utils, 2087 xpu as xpu, 2088) 2089from torch.signal import windows as windows 2090 2091 2092# Quantized, sparse, AO, etc. should be last to get imported, as nothing 2093# is expected to depend on them. 2094from torch import ao as ao # usort: skip 2095 2096# nn.quant* depends on ao -- so should be after those. 2097import torch.nn.intrinsic 2098import torch.nn.qat 2099import torch.nn.quantizable 2100import torch.nn.quantized 2101 2102 2103_C._init_names(list(_storage_classes)) 2104 2105# attach docstrings to torch and tensor functions 2106from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs 2107 2108 2109del _torch_docs, _tensor_docs, _storage_docs, _size_docs 2110 2111 2112def compiled_with_cxx11_abi() -> builtins.bool: 2113 r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1""" 2114 return _C._GLIBCXX_USE_CXX11_ABI 2115 2116 2117from torch import _library as _library, _ops as _ops 2118 2119 2120# Import the ops and classes "namespace" 2121from torch._ops import ops as ops # usort: skip 2122from torch._classes import classes as classes # usort: skip 2123 2124sys.modules.setdefault(f"{__name__}.ops", ops) 2125sys.modules.setdefault(f"{__name__}.classes", classes) 2126 2127# quantization depends on torch.fx and torch.ops 2128# Import quantization 2129from torch import quantization as quantization # usort: skip 2130 2131# Import the quasi random sampler 2132from torch import quasirandom as quasirandom # usort: skip 2133 2134# If you are seeing this, it means that this call site was not checked if 2135# the memory format could be preserved, and it was switched to old default 2136# behaviour of contiguous 2137legacy_contiguous_format = contiguous_format # defined by _C._initExtension() 2138 2139# Register fork handler to initialize OpenMP in child processes (see gh-28389) 2140from torch.multiprocessing._atfork import register_after_fork 2141 2142 2143register_after_fork(torch.get_num_threads) 2144del register_after_fork 2145 2146# Import tools that require fully imported torch (for applying 2147# torch.jit.script as a decorator, for instance): 2148from torch._lobpcg import lobpcg as lobpcg 2149 2150 2151# These were previously defined in native_functions.yaml and appeared on the 2152# `torch` namespace, but we moved them to c10 dispatch to facilitate custom 2153# class usage. We add these lines here to preserve backward compatibility. 2154quantized_lstm = ops.aten.quantized_lstm 2155quantized_gru = ops.aten.quantized_gru 2156 2157# Import experimental masked operations support. 
See 2158# [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more 2159# information. 2160from torch import masked as masked 2161 2162# Import removed ops with error message about removal 2163from torch._linalg_utils import ( # type: ignore[misc] 2164 _symeig as symeig, 2165 eig, 2166 lstsq, 2167 matrix_rank, 2168 solve, 2169) 2170from torch.utils.dlpack import from_dlpack, to_dlpack 2171 2172 2173class _TorchCompileInductorWrapper: 2174 compiler_name = "inductor" 2175 2176 def __init__(self, mode, options, dynamic): 2177 self.config: _Dict[str, _Any] = {} 2178 self.dynamic = dynamic 2179 self.apply_mode(mode) 2180 self.apply_options(options) 2181 2182 if self.config.get("triton.cudagraphs", False): 2183 os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" 2184 # FIXME: CUDA Graph does not work well with CUPTI teardown. 2185 # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) 2186 # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) 2187 # Workaround: turn off CUPTI teardown when using CUDA Graphs. 2188 os.environ["TEARDOWN_CUPTI"] = "0" 2189 2190 def __eq__(self, other): 2191 return ( 2192 isinstance(other, _TorchCompileInductorWrapper) 2193 and self.config == other.config 2194 and self.dynamic == other.dynamic 2195 ) 2196 2197 def apply_mode(self, mode: _Optional[str]): 2198 if mode is None or mode == "default": 2199 pass 2200 elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}: 2201 from torch._inductor import list_mode_options 2202 2203 self.apply_options(list_mode_options(mode, self.dynamic)) 2204 else: 2205 raise RuntimeError( 2206 f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs" 2207 ) 2208 2209 def apply_options(self, options: _Optional[_Dict[str, _Any]]): 2210 if not options: 2211 return 2212 2213 from torch._inductor import config 2214 2215 current_config: _Dict[str, _Any] = config.shallow_copy_dict() 2216 2217 for key, val in options.items(): 2218 attr_name = key.replace("-", "_") 2219 if attr_name not in current_config: 2220 raise RuntimeError( 2221 f"Unexpected optimization option {key}, known options are {list(current_config.keys())}" 2222 ) 2223 if type(val) is not type(current_config[attr_name]): 2224 val_type_str = type(val).__name__ 2225 expected_type_str = type(current_config[attr_name]).__name__ 2226 raise RuntimeError( 2227 f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" 2228 ) 2229 self.config[attr_name] = val 2230 2231 def __call__(self, model_, inputs_): 2232 from torch._inductor.compile_fx import compile_fx 2233 2234 return compile_fx(model_, inputs_, config_patches=self.config) 2235 2236 def get_compiler_config(self): 2237 from torch._inductor.compile_fx import get_patched_config_dict 2238 2239 return get_patched_config_dict(config_patches=self.config) 2240 2241 def reset(self): 2242 from torch._inductor import config 2243 2244 if "triton.cudagraphs" in self.config or config.triton.cudagraphs: 2245 if self.config.get("triton.cudagraphs", True): 2246 from torch._inductor.cudagraph_trees import reset_cudagraph_trees 2247 2248 reset_cudagraph_trees() 2249 2250 2251class _TorchCompileWrapper: 2252 def __init__(self, backend, mode, options, dynamic): 2253 from torch._dynamo.backends.registry import lookup_backend 2254 2255 if isinstance(backend, str): 2256 self.compiler_name = backend 2257 elif hasattr(backend, "__name__"): 2258 self.compiler_name = backend.__name__ 2259 else: 2260 self.compiler_name = str(backend) 2261 
        self.dynamic = dynamic
        self.compiler_fn = lookup_backend(backend)
        self.kwargs = {}
        # only pass the args if they are non-empty
        if mode and mode != "default":
            self.kwargs["mode"] = mode
        if options:
            self.kwargs["options"] = options

    def __eq__(self, other):
        return (
            isinstance(other, _TorchCompileWrapper)
            and self.compiler_fn == other.compiler_fn
            and self.kwargs == other.kwargs
            and self.dynamic == other.dynamic
        )

    def __call__(self, model_, inputs_):
        return self.compiler_fn(model_, inputs_, **self.kwargs)

    def reset(self):
        if hasattr(self.compiler_fn, "reset"):
            self.compiler_fn.reset()


_InputT = _ParamSpec("_InputT")
_RetT = _TypeVar("_RetT")


@_overload
def compile(
    model: _Callable[_InputT, _RetT],
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[_InputT, _RetT]: ...


@_overload
def compile(
    model: None = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ...


def compile(
    model: _Optional[_Callable] = None,
    *,
    fullgraph: builtins.bool = False,
    dynamic: _Optional[builtins.bool] = None,
    backend: _Union[str, _Callable] = "inductor",
    mode: _Union[str, None] = None,
    options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
    disable: builtins.bool = False,
) -> _Union[
    _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]],
    _Callable[_InputT, _RetT],
]:
    """
    Optimizes the given model/function using TorchDynamo and the specified backend.
    If you are compiling a :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
    to compile the module in place without changing its structure.

    Concretely, for every frame executed within the compiled region, we will attempt
    to compile it and cache the compiled result on the code object for future
    use. A single frame may be compiled multiple times if previous compiled
    results are not applicable for subsequent calls (this is called a "guard
    failure"); you can use TORCH_LOGS=guards to debug these situations.
    Multiple compiled results can be associated with a frame up to
    ``torch._dynamo.config.cache_size_limit``, which defaults to 8, at which
    point we will fall back to eager. Note that compile caches are per
    *code object*, not frame; if you dynamically create multiple copies of a
    function, they will all share the same code cache.

    Args:
        model (Callable): Module/function to optimize
        fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
            in the function that it will optimize. If True, then we require that the entire function be
            capturable into a single graph. If this is not possible (that is, if there are graph breaks),
            then this will raise an error.
        dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
            to generate a kernel that is as dynamic as possible to avoid recompilations when
            sizes change. This may not always work as some operations/optimizations will
            force specialization; use TORCH_LOGS=dynamic to debug overspecialization.
            When this is False, we will NEVER generate dynamic kernels; we will always specialize.
            By default (None), we automatically detect if dynamism has occurred and compile a more
            dynamic kernel upon recompile.
        backend (str or Callable): backend to be used

            - "inductor" is the default backend, which is a good balance between performance and overhead

            - Non-experimental in-tree backends can be seen with `torch._dynamo.list_backends()`

            - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)`

            - To register an out-of-tree custom backend:
              https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends
        mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"

            - "default" is the default mode, which is a good balance between performance and overhead

            - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs,
              useful for small batches. Reduction of overhead can come at the cost of more memory
              usage, as we will cache the workspace memory required for the invocation so that we
              do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed
              to work; today, we only reduce overhead for CUDA-only graphs which do not mutate inputs.
              There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints
              to debug.

            - "max-autotune" is a mode that leverages Triton- or template-based matrix multiplications
              on supported devices and Triton-based convolutions on GPU.
              It enables CUDA graphs by default on GPU.

            - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs

            - To see the exact configs that each mode sets, you can call `torch._inductor.list_mode_options()`

        options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are

            - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set

            - `max_autotune` which will profile to pick the best matmul configuration

            - `fallback_random` which is useful when debugging accuracy issues

            - `shape_padding` which pads matrix shapes to better align loads on GPUs, especially for tensor cores

            - `triton.cudagraphs` which will reduce the overhead of python with CUDA graphs

            - `trace.enabled` which is the most useful debugging flag to turn on

            - `trace.graph_diagram` which will show you a picture of your graph after fusion

            - For inductor, you can see the full list of configs that it supports by calling `torch._inductor.list_options()`

        disable (bool): Turn torch.compile() into a no-op for testing

    Example::

        @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
        def foo(x):
            return torch.sin(x) + torch.cos(x)

    """
    _C._log_api_usage_once("torch.compile")
    if sys.version_info >= (3, 13):
        raise RuntimeError("Dynamo is not supported on Python 3.13+")

    # Decorator mode
    if model is None:

        def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
            if model is None:
                raise RuntimeError("Model can't be None")
            return compile(
                model,
                fullgraph=fullgraph,
                dynamic=dynamic,
                backend=backend,
                mode=mode,
                options=options,
                disable=disable,
            )

        return fn

    if mode is not None and options is not None:
        raise RuntimeError(
            "Either mode or options can be specified, but both can't be specified at the same time."
        )
    if mode is None and options is None:
        mode = "default"
    if backend == "inductor":
        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
    else:
        backend = _TorchCompileWrapper(backend, mode, options, dynamic)

    return torch._dynamo.optimize(
        backend=backend,
        nopython=fullgraph,
        dynamic=dynamic,
        disable=disable,
    )(model)  # type: ignore[return-value]


def _register_device_module(device_type, module):
    r"""Register an external runtime module of the specific :attr:`device_type`
    supported by torch.

    After the :attr:`module` is registered correctly, the user can refer to
    the external runtime module as part of torch with the attribute torch.xxx.
    """
    # Make sure the device_type represents a supported device type for torch.
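    # (torch.device raises a RuntimeError for an unrecognized device string, so
    # an invalid ``device_type`` fails here before anything is registered.)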
2463 device_type = torch.device(device_type).type 2464 m = sys.modules[__name__] 2465 if hasattr(m, device_type): 2466 raise RuntimeError( 2467 f"The runtime module of '{device_type}' has already " 2468 f"been registered with '{getattr(m, device_type)}'" 2469 ) 2470 setattr(m, device_type, module) 2471 torch_module_name = ".".join([__name__, device_type]) 2472 sys.modules[torch_module_name] = module 2473 2474 2475from torch import ( 2476 export as export, 2477 func as func, 2478 library as library, 2479 return_types as return_types, 2480) 2481from torch._higher_order_ops import cond as cond, while_loop as while_loop 2482from torch.func import vmap as vmap 2483 2484 2485if not TYPE_CHECKING: 2486 from torch import _meta_registrations 2487 2488# Enable CUDA Sanitizer 2489if "TORCH_CUDA_SANITIZER" in os.environ: 2490 import torch.cuda._sanitizer as csan 2491 2492 csan.enable_cuda_sanitizer() 2493 2494# Populate magic methods on SymInt and SymFloat 2495import torch.fx.experimental.sym_node 2496 2497 2498# Register MPS specific decomps 2499torch.backends.mps._init() 2500 2501if not _running_with_deploy(): 2502 from torch import compiler as compiler 2503 2504 class _TritonLibrary: 2505 lib = torch.library.Library("triton", "DEF") 2506 ops_table: _Dict[_Tuple[str, str], _Callable] = {} 2507 2508 @classmethod 2509 def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): 2510 if (op_key, dispatch_key) not in cls.ops_table: 2511 cls.lib.define(full_schema) 2512 cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) 2513 cls.ops_table[(op_key, dispatch_key)] = op_impl 2514 2515 return cls.ops_table[(op_key, dispatch_key)] 2516 2517 2518# Deprecated attributes 2519_deprecated_attrs = { 2520 "has_mps": torch.backends.mps.is_built, 2521 "has_cuda": torch.backends.cuda.is_built, 2522 "has_cudnn": torch.backends.cudnn.is_available, 2523 "has_mkldnn": torch.backends.mkldnn.is_available, 2524} 2525 2526if TYPE_CHECKING: 2527 # Import the following modules during type checking to enable code intelligence features, 2528 # such as auto-completion in tools like pylance, even when these modules are not explicitly 2529 # imported in user code. 2530 from torch import ( 2531 _dynamo as _dynamo, 2532 _inductor as _inductor, 2533 _subclasses as _subclasses, 2534 onnx as onnx, 2535 ) 2536 2537else: 2538 _lazy_modules = { 2539 "_dynamo", 2540 "_inductor", 2541 "_export", 2542 # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit 2543 "onnx", 2544 } 2545 2546 def __getattr__(name): 2547 # Deprecated attrs 2548 replacement = _deprecated_attrs.get(name) 2549 if replacement is not None: 2550 import warnings 2551 2552 warnings.warn( 2553 f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", 2554 stacklevel=2, 2555 ) 2556 return replacement() 2557 2558 # Lazy modules 2559 if name in _lazy_modules: 2560 return importlib.import_module(f".{name}", __name__) 2561 2562 raise AttributeError(f"module '{__name__}' has no attribute '{name}'") 2563 2564 2565def get_device_module(device: _Optional[_Union[torch.device, str]] = None): 2566 """ 2567 Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...). 2568 If no device is given, return the module for the current accelerator or CPU if none is present. 
2569 """ 2570 if isinstance(device, torch.device): 2571 device_module_name = device.type 2572 elif isinstance(device, str): 2573 device_module_name = torch.device(device).type 2574 elif device is None: 2575 # Using default accelerator type. If no accelerator is available, it automatically returns CPU device. 2576 device_module_name = torch._C._get_accelerator().type 2577 else: 2578 raise RuntimeError( 2579 f"Invalid value of device '{device}', expect torch.device, str, or None" 2580 ) 2581 device_module = getattr(torch, device_module_name, None) 2582 if device_module is None: 2583 raise RuntimeError( 2584 f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'." 2585 ) 2586 return device_module 2587 2588 2589def _constrain_as_size( 2590 symbol, 2591 min: _Optional[builtins.int] = None, 2592 max: _Optional[builtins.int] = None, 2593): 2594 """ 2595 This indicates that a given int is size-like, and can be used in any context where a size is expected. 2596 You will typically use this when reading out integers from Tensors, e.g., max.item() or lengths.tolist() 2597 which then need to be used as tensor constructors. Providing these assertions to PyTorch can help resolve 2598 GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts. 2599 2600 This function has unusual semantics in some circumstances in framework 2601 code, we will treat this int as >= 2 (when we do a size-oblivious guard). 2602 This makes it easier to use the unbacked int in size contexts, 2603 as we will often attempt to guard on a size being zero/one 2604 (e.g., when computing the contiguity of a tensor, or testing if 2605 broadcasting can occur), which will not work on unbacked SymInts. 2606 However, if we conservatively assume that the size is not zero/one, we will 2607 end up with a graph that will still work even if the size is zero/one. 2608 2609 For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit 2610 ``` 2611 """ 2612 torch.sym_constrain_range_for_size(symbol, min=min, max=max) 2613 2614 2615from torch import _logging 2616 2617 2618_logging._init_logs() 2619 2620 2621def _import_device_backends(): 2622 """ 2623 Leverage the Python plugin mechanism to load out-of-the-tree device extensions. 2624 See this RFC: https://github.com/pytorch/pytorch/issues/122468 2625 """ 2626 from importlib.metadata import entry_points 2627 2628 group_name = "torch.backends" 2629 if sys.version_info < (3, 10): 2630 backend_extensions = entry_points().get(group_name, ()) 2631 else: 2632 backend_extensions = entry_points(group=group_name) 2633 2634 for backend_extension in backend_extensions: 2635 try: 2636 # Load the extension 2637 entrypoint = backend_extension.load() 2638 # Call the entrypoint 2639 entrypoint() 2640 except Exception as err: 2641 raise RuntimeError( 2642 f"Failed to load the backend extension: {backend_extension.name}. " 2643 f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0." 2644 ) from err 2645 2646 2647def _is_device_backend_autoload_enabled() -> builtins.bool: 2648 """ 2649 Whether autoloading out-of-the-tree device extensions is enabled. 2650 The switch depends on the value of the environment variable 2651 `TORCH_DEVICE_BACKEND_AUTOLOAD`. 2652 2653 Returns: 2654 bool: Whether to enable autoloading the extensions. Enabled by default. 
2655 2656 Examples: 2657 >>> torch._is_device_backend_autoload_enabled() 2658 True 2659 """ 2660 # enabled by default 2661 return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1" 2662 2663 2664if _is_device_backend_autoload_enabled(): 2665 _import_device_backends() 2666
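

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the public API, never called at import time):
# the bfloat16 decomposition described in the ``set_float32_matmul_precision``
# docstring. A float32 value can be peeled into successively less significant
# bfloat16 chunks, and the "high" setting keeps only the three most significant
# of the resulting cross products. The helper below is hypothetical and exists
# purely as a worked example; exact error bounds depend on the inputs.
def _demo_bf16_matmul_decomposition():
    a = torch.tensor(3.1415927, dtype=torch.float32)
    b = torch.tensor(2.7182818, dtype=torch.float32)

    def _split(v):
        # Peel off successively less significant bfloat16 chunks of ``v``.
        hi = v.to(torch.bfloat16).to(torch.float32)
        mid = (v - hi).to(torch.bfloat16).to(torch.float32)
        lo = (v - hi - mid).to(torch.bfloat16).to(torch.float32)
        return hi, mid, lo

    a_hi, a_mid, _ = _split(a)
    b_hi, b_mid, _ = _split(b)

    # "high" keeps the three most significant of the nine cross products:
    # hi*hi, hi*mid and mid*hi; everything involving the low chunks is dropped.
    approx = a_hi * b_hi + a_hi * b_mid + a_mid * b_hi
    exact = a * b
    # The approximation is close to, but not bitwise equal to, the float32 product.
    return (approx - exact).abs().item()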