# mypy: allow-untyped-defs
import torch
from torch.overrides import (
    handle_torch_function,
    has_torch_function_unary,
)
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import List, Optional, Union

__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]

# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# the renamed backend name for `privateuse1`, but that function causes an
# error with torch.jit.script, so we use the global variable
# `_privateuse1_backend_name` instead.
_privateuse1_backend_name = "privateuseone"

def rename_privateuse1_backend(backend_name: str) -> None:
    r"""
    Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.

    The steps are:

    (1) (In C++) implement kernels for various torch operations, and register them
        to the PrivateUse1 dispatch key.
    (2) (In Python) call torch.utils.rename_privateuse1_backend("foo")

    You can now use "foo" as an ordinary device string in Python.

    Note: this API can only be called once per process. Attempting to change
    the external backend after it's already been set will result in an error.

    Note(AMP): If you want to support AMP on your device, register a custom backend module
    with ``torch._register_device_module("foo", BackendModule)``.
    BackendModule needs to have the following API:

    (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
        Return the dtypes supported by AMP on your "foo" device; the device may support more than one dtype.

    Note(random): If you want to support setting the seed for your device, BackendModule needs to have the following APIs:

    (1) ``_is_in_bad_fork() -> bool``
        Return ``True`` if the process is currently in a bad fork state, otherwise ``False``.

    (2) ``manual_seed_all(seed: int) -> None``
        Set the seed for generating random numbers on all of your devices.

    (3) ``device_count() -> int``
        Return the number of "foo" devices available.

    (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
        Return a ByteTensor representing the random number state of the specified "foo" device.

    (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
        Set the random number generator state of the specified "foo" device.

    There are also some common functions:

    (1) ``is_available() -> bool``
        Return a bool indicating whether "foo" is currently available.

    (2) ``current_device() -> int``
        Return the index of the currently selected device.

    An illustrative ``BackendModule`` covering these hooks is sketched in the example below.

    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        # This will work, assuming that you've implemented the right C++ kernels
        # for torch.ones.
        >>> a = torch.ones(2, device="foo")
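        # A minimal, illustrative BackendModule providing the AMP and RNG hooks
        # listed above. The class name and supported dtype are assumptions for
        # this sketch, not a fixed API.
        >>> class FooBackendModule:
        ...     @staticmethod
        ...     def get_amp_supported_dtype():
        ...         return [torch.float16]
        ...     @staticmethod
        ...     def _is_in_bad_fork():
        ...         return False
        ...     @staticmethod
        ...     def manual_seed_all(seed):
        ...         pass
        ...     @staticmethod
        ...     def device_count():
        ...         return 1
        ...     @staticmethod
        ...     def is_available():
        ...         return True
        ...     @staticmethod
        ...     def current_device():
        ...         return 0
        >>> torch._register_device_module("foo", FooBackendModule)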

    """
    _rename_privateuse1_backend(backend_name)
    global _privateuse1_backend_name
    _privateuse1_backend_name = backend_name

def _check_register_once(module, attr):
    if hasattr(module, attr):
        raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")


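# How `_normalization_device` resolves the `device` argument, assuming the
# privateuse1 backend has been renamed to "foo" (illustrative values only):
#   None                   -> index of the current device (0 if unavailable)
#   "foo:1"                -> 1
#   torch.device("foo", 1) -> 1
#   torch.device("foo")    -> index of the current device
#   1                      -> 1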
def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
    def _get_current_device_index():
        _get_device_index = "current_device"
        if hasattr(torch, custom_backend_name) and \
                hasattr(getattr(torch, custom_backend_name), _get_device_index):
            return getattr(getattr(torch, custom_backend_name), _get_device_index)()
        else:
            # The default device index is 0.
            return 0

    if device is None:
        return _get_current_device_index()
    # If device is a str, the parameter was passed in string form, e.g. "foo:0".
    # Convert it to a torch.device object and then process it uniformly.
    elif isinstance(device, str):
        device = torch.device(device)

    # At this point, device can only be a torch.device or an int.
    if isinstance(device, torch.device):
        if device.type != custom_backend_name:
            raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
        elif device.index is None:
            device_idx = _get_current_device_index()
        else:
            device_idx = device.index
    # If device is an int, it is the device index itself.
    else:
        device_idx = device
    return device_idx


def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.Tensor) -> bool:
        if has_torch_function_unary(self):
            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
            return handle_torch_function(wrap_tensor_backend.__get__, (self,), self)  # type: ignore[attr-defined]
        return self.device.type == custom_backend_name

    _check_register_once(torch.Tensor, f'is_{custom_backend_name}')
    wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}'  # type: ignore[attr-defined]
    setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)

    def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
                       **kwargs) -> torch.Tensor:
        r"""Perform Tensor device conversion. Calls the underlying ``to`` operator implementation.

        .. note::
            If the ``self`` Tensor already
            has the correct :class:`torch.device`, then ``self`` is returned.
            Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.

        Args:
            device (int, optional): if specified, the tensor will be copied to that device
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
            **kwargs (dict): For compatibility, may contain the key ``memory_format`` argument.
        """
        if has_torch_function_unary(self):
            return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=non_blocking, **kwargs)
        device_idx = _normalization_device(custom_backend_name, device)
        return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)

    _check_register_once(torch.Tensor, custom_backend_name)
    wrap_tensor_to.__name__ = custom_backend_name
    setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)


def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    # Generating Module attributes and methods depends on Tensor methods,
    # so we need to check whether the Tensor methods have already been registered.
    if not hasattr(torch.Tensor, custom_backend_name):
        raise RuntimeError(
            f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module "
            f"because torch.Tensor doesn't have the method {custom_backend_name}(). "
            f"For this error, you can try setting for_tensor=True.")

    def wrap_module_to(self: torch.nn.modules.module.T,
                       device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
        r"""Move all model parameters and buffers to the custom device.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on the device while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be copied to that device
        """
        return self._apply(lambda t: getattr(t, custom_backend_name)(device))

    _check_register_once(torch.nn.Module, custom_backend_name)
    setattr(torch.nn.Module, custom_backend_name, wrap_module_to)

def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    # Generating PackedSequence attributes and methods depends on Tensor methods,
    # so we need to check whether the Tensor methods have already been registered.
    if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \
       not hasattr(torch.Tensor, custom_backend_name):
        raise RuntimeError(
            f"Can not automatically generate is_{custom_backend_name}() or "
            f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence "
            f"because torch.Tensor doesn't have the method is_{custom_backend_name}() "
            f"or {custom_backend_name}(). "
            f"For this error, you can try setting for_tensor=True.")

    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
        return self.data.device.type == custom_backend_name

    _check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}')
    setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend)

    def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence,
                       *args, **kwargs) -> torch.nn.utils.rnn.PackedSequence:
        r"""Move the PackedSequence data to the custom device.

        .. note::
            If ``self.data`` is already on the correct device, ``self`` is returned;
            otherwise, a copy on the custom device is returned.

        Args:
            *args, **kwargs: arguments forwarded to the underlying ``to`` conversion;
                if they do not resolve to the custom device, the custom device is used.
        """
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.device.type == custom_backend_name:
            return self.to(*args, **kwargs)
        kwargs.update({'device': custom_backend_name})
        return self.to(*args, **kwargs)

    _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
    setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)

def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
                                                      unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    # The attribute is registered on the _StorageBase class,
    # and UntypedStorage obtains it through inheritance.
    @property  # type: ignore[misc]
    def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
        r"""Return whether the storage is on the custom backend device."""
        return self.device.type == custom_backend_name

    _check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
    setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)

    def wrap_storage_to(self, device=None, non_blocking=False):
        r"""Return a copy of this object in custom device memory.

        If this object is already in device memory and on the correct device, then
        no copy is performed and the original object is returned.

        Args:
            device (int): The destination device id. Defaults to the current device.
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
        """
        # Checks related to the storage device and the storage type should go here,
        # but they depend on the backend extension, so they are omitted from the
        # automatically generated code for now.
        device_idx = _normalization_device(custom_backend_name, device)

        if getattr(self, f'is_{custom_backend_name}'):
            # The storage is already on the expected device.
            if self.get_device() == device_idx:
                return self
        # For sparse storage, the custom backend needs to extend the implementation itself.
        if self.is_sparse:
            raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
        # Create an untyped_storage on the target device and copy the data.
        untyped_storage = torch.UntypedStorage(
            self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
        )
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

    _check_register_once(torch.storage._StorageBase, custom_backend_name)
    setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)

    # Register the corresponding attribute for the TypedStorage class.
    # When the TypedStorage class is removed, this registration should be removed as well.

    @property  # type: ignore[misc]
    def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
        torch.storage._warn_typed_storage_removal()
        return self._untyped_storage.device.type == custom_backend_name

    _check_register_once(torch.TypedStorage, f'is_{custom_backend_name}')
    setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)

    def wrap_typed_storage_to(self: torch.storage.TypedStorage,
                              device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
        torch.storage._warn_typed_storage_removal()
        if unsupported_dtype and self.dtype in unsupported_dtype:
            raise RuntimeError(f"Cannot create {custom_backend_name} storage "
                               f"as {self.dtype} dtype is not supported by this backend")
        custom_backend_storage: torch.UntypedStorage = getattr(
            self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(custom_backend_storage)

    _check_register_once(torch.TypedStorage, custom_backend_name)
    setattr(torch.TypedStorage, custom_backend_name, wrap_typed_storage_to)


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
                                             for_packed_sequence: bool = True,
                                             for_storage: bool = False,
                                             unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    r"""
    Automatically generate attributes and methods for the custom backend after renaming the privateuse1 backend.

    In the default scenario, storage-related methods will not be generated automatically.

    Once you have implemented kernels for various torch operations and registered them to the PrivateUse1
    dispatch key, and renamed your backend with torch.utils.rename_privateuse1_backend("foo"), you can call
    this function to register backend-specific methods and attributes such as torch.Tensor.foo(),
    torch.Tensor.is_foo, torch.Storage.foo(), and torch.Storage.is_foo.

    Note: We recommend you use generic functions (checking that devices are equal, or to(device=...)).
    We provide these methods for convenience only; they will be "monkey patched" onto the objects
    and so will not be properly typed. For the generated Storage methods, if you need to support
    sparse data storage, you need to extend the implementation yourself.

    Args:
        for_tensor (bool): whether to register related methods for the torch.Tensor class.
        for_module (bool): whether to register related methods for the torch.nn.Module class.
        for_packed_sequence (bool): whether to register related methods for the torch.nn.utils.rnn.PackedSequence class.
        for_storage (bool): whether to register related methods for the torch.Storage class.
        unsupported_dtype (List[torch.dtype]): takes effect only when storage methods are generated,
            indicating dtypes that the storage does not support.

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        >>> torch.utils.generate_methods_for_privateuse1_backend()
        # This automatically generates backend-related attributes and methods.
        >>> a = torch.tensor(2).foo()
        >>> a.is_foo
        >>> hasattr(torch.nn.Module, 'foo')
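        >>> hasattr(torch.nn.utils.rnn.PackedSequence, 'is_foo')
        # Storage methods are opt-in: pass for_storage=True to
        # generate_methods_for_privateuse1_backend to also generate
        # torch.UntypedStorage.foo() and torch.UntypedStorage.is_foo.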
330    """
331    custom_backend_name = _get_privateuse1_backend_name()
332
333    if for_tensor:
334        _generate_tensor_methods_for_privateuse1_backend(custom_backend_name)
335
336    if for_module:
337        _generate_module_methods_for_privateuse1_backend(custom_backend_name)
338
339    if for_storage:
340        _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)
341
342    if for_packed_sequence:
343        _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)

def _get_custom_mod_func(func_name: str):
    r"""
    Return the function named `func_name` defined in the custom device module.

    The custom device module is registered with `torch.utils.rename_privateuse1_backend('foo')`
    and `torch._register_device_module('foo', BackendModule)`. If the custom device module or
    the function is not defined, a RuntimeError is raised.

    Args:
        func_name (str): the name of the callable function defined in the custom device module.

    Example::

        class DummyfooModule:
            @staticmethod
            def is_available():
                return True
            @staticmethod
            def func_name(*args, **kwargs):
                ...
        torch.utils.rename_privateuse1_backend("foo")
        torch._register_device_module("foo", DummyfooModule)
        foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
        if foo_is_available_func:
            foo_is_available = foo_is_available_func()
        func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
        if func_:
            result = func_(*args, **kwargs)

    Attention: This function is not meant to be used directly by users, which is why
    it is marked as private. It is a convenience function for backend implementers to
    more easily call the hooks in their backend extensions.
    """
    assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
    backend_name = _get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
    function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
    if custom_device_mod is None or function is None:
        message = f'Try to call torch.{backend_name}.{func_name}. The backend must register a custom backend '
        message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`. And "
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
    return function
383