# mypy: allow-untyped-defs
import functools
from typing import Optional

import torch
from torch._C import _len_torch_function_stack
from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator

CURRENT_DEVICE: Optional[torch.device] = None


@functools.lru_cache(1)
def _device_constructors():
    return {
        # standard ones
        torch.empty,
        torch.empty_permuted,
        torch.empty_strided,
        torch.empty_quantized,
        torch.ones,
        torch.arange,
        torch.bartlett_window,
        torch.blackman_window,
        torch.eye,
        torch.fft.fftfreq,
        torch.fft.rfftfreq,
        torch.full,
        torch.fill,
        torch.hamming_window,
        torch.hann_window,
        torch.kaiser_window,
        torch.linspace,
        torch.logspace,
        torch.nested.nested_tensor,
        # This function doesn't actually take a device argument
        # torch.normal,
        torch.rand,
        torch.randn,
        torch.randint,
        torch.randperm,
        torch.range,
        torch.sparse_coo_tensor,
        torch.sparse_compressed_tensor,
        torch.sparse_csr_tensor,
        torch.sparse_csc_tensor,
        torch.sparse_bsr_tensor,
        torch.sparse_bsc_tensor,
        torch.tril_indices,
        torch.triu_indices,
        torch.vander,
        torch.zeros,
        # weird ones
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
        torch.asarray,
    }


# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
    def __init__(self, device):
        self.device = torch.device(device)

    def __enter__(self):
        global CURRENT_DEVICE
        self.old_device = CURRENT_DEVICE
        CURRENT_DEVICE = self.device
        # We need to put the device mode at the bottom of the mode stack:
        # if the default device were pushed while already inside another
        # function mode's context, exiting that mode would pop the device
        # mode off the stack incorrectly.
        cur_stack = []
        for _ in range(_len_torch_function_stack()):
            cur_stack.append(_pop_mode())

        _push_mode(self)

        for mode in reversed(cur_stack):
            _push_mode(mode)

    def __exit__(self, exc_type, exc_val, exc_tb):
        global CURRENT_DEVICE
        CURRENT_DEVICE = self.old_device
        cur_stack = []
        # Invariant: there should only ever be one DeviceContext on the
        # stack at any time, and it sits at the bottom. Pop all modes
        # above it, then assert that the bottom mode is a DeviceContext;
        # if it isn't, someone else has popped it!
        for _ in range(_len_torch_function_stack() - 1):
            mode = _pop_mode()
            assert not isinstance(mode, DeviceContext)
            cur_stack.append(mode)

        if _len_torch_function_stack() > 0:
            mode = _pop_mode()
            assert isinstance(mode, DeviceContext)

        for mode in reversed(cur_stack):
            _push_mode(mode)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in _device_constructors() and kwargs.get('device') is None:
            kwargs['device'] = self.device
        return func(*args, **kwargs)


# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    return context_decorator(lambda: device, func)


def set_device(device):
    """
    Set the default device inside the wrapped function by decorating it with this function.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.
    """
    return lambda func: device_decorator(torch.device(device), func)
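
# ------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). It shows
# the two entry points above: ``torch.device(...)`` used as a context
# manager, which pushes a DeviceContext onto the torch function mode stack,
# and ``set_device`` used as a decorator. The "meta" device is assumed here
# purely because it is always available; any real device string behaves the
# same way.
#
#   with torch.device("meta"):
#       x = torch.randn(3)  # DeviceContext fills in device="meta"
#       assert x.device.type == "meta"
#
#   @set_device("meta")
#   def make_ones():
#       return torch.ones(2, 2)  # also allocated on the meta device
#
#   assert make_ones().device.type == "meta"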