xref: /aosp_15_r20/external/pytorch/torch/utils/_device.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Optional
3import torch
4from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
5from torch.utils._contextlib import context_decorator
6from torch._C import _len_torch_function_stack
7import functools
8
# Device recorded by the most recently entered DeviceContext.__enter__;
# DeviceContext.__exit__ restores the value that was current before it.
# None until a DeviceContext has been entered.
CURRENT_DEVICE: Optional[torch.device] = None
10
11@functools.lru_cache(1)
12def _device_constructors():
13    return {
14        # standard ones
15        torch.empty,
16        torch.empty_permuted,
17        torch.empty_strided,
18        torch.empty_quantized,
19        torch.ones,
20        torch.arange,
21        torch.bartlett_window,
22        torch.blackman_window,
23        torch.eye,
24        torch.fft.fftfreq,
25        torch.fft.rfftfreq,
26        torch.full,
27        torch.fill,
28        torch.hamming_window,
29        torch.hann_window,
30        torch.kaiser_window,
31        torch.linspace,
32        torch.logspace,
33        torch.nested.nested_tensor,
34        # This function doesn't actually take a device argument
35        # torch.normal,
36        torch.ones,
37        torch.rand,
38        torch.randn,
39        torch.randint,
40        torch.randperm,
41        torch.range,
42        torch.sparse_coo_tensor,
43        torch.sparse_compressed_tensor,
44        torch.sparse_csr_tensor,
45        torch.sparse_csc_tensor,
46        torch.sparse_bsr_tensor,
47        torch.sparse_bsc_tensor,
48        torch.tril_indices,
49        torch.triu_indices,
50        torch.vander,
51        torch.zeros,
52        torch.asarray,
53        # weird ones
54        torch.tensor,
55        torch.as_tensor,
56        torch.scalar_tensor,
57        torch.asarray,
58    }
59
60# NB: This is directly called from C++ in torch/csrc/Device.cpp
61class DeviceContext(TorchFunctionMode):
62    def __init__(self, device):
63        self.device = torch.device(device)
64
65    def __enter__(self):
66        global CURRENT_DEVICE
67        self.old_device = CURRENT_DEVICE
68        CURRENT_DEVICE = self.device
69        # We need to put the device at the bottom of the stack
70        # If we set default device within a function mode context
71        # exiting that context mode will pop the device function mode off
72        # of the stack incorrectly
73        cur_stack = []
74        for _ in range(_len_torch_function_stack()):
75            cur_stack.append(_pop_mode())
76
77        _push_mode(self)
78
79        for mode in reversed(cur_stack):
80            _push_mode(mode)
81
82
83    def __exit__(self, exc_type, exc_val, exc_tb):
84        global CURRENT_DEVICE
85        CURRENT_DEVICE = self.old_device
86        cur_stack = []
87        # Invariant: there should only be one DeviceContext on the stack at any time
88        # (At the bottom), pop all mdoes until we hit the bottom, assert it's a DeviceContext
89        # or else someone else has popped it!
90        for _ in range(_len_torch_function_stack() - 1):
91            mode = _pop_mode()
92            assert not isinstance(mode, DeviceContext)
93            cur_stack.append(mode)
94
95        if _len_torch_function_stack() > 0:
96            mode = _pop_mode()
97            assert isinstance(mode, DeviceContext)
98
99        for mode in reversed(cur_stack):
100            _push_mode(mode)
101
102    def __torch_function__(self, func, types, args=(), kwargs=None):
103        kwargs = kwargs or {}
104        if func in _device_constructors() and kwargs.get('device') is None:
105            kwargs['device'] = self.device
106        return func(*args, **kwargs)
107
# NB: This is directly called from C++ in torch/csrc/Device.cpp
def device_decorator(device, func):
    """Wrap ``func`` with ``context_decorator`` using ``device`` as the context.

    The context factory handed to ``context_decorator`` always yields the same
    ``device`` object. A zero-argument callable is passed rather than the
    device itself, so the factory is what ``context_decorator`` invokes.
    """
    def _make_ctx():
        return device

    return context_decorator(_make_ctx, func)
111
def set_device(device):
    """
    Set the default device inside of the wrapped function by decorating it with this function.

    ``device`` is converted with ``torch.device`` each time the returned
    decorator is applied to a function.

    If you would like to use this as a context manager, use device as a
    context manager directly, e.g., ``with torch.device(device)``.
    """
    def decorate(func):
        return device_decorator(torch.device(device), func)

    return decorate
120