# mypy: allow-untyped-defs
import torch
from torch.overrides import (
    handle_torch_function,
    has_torch_function_unary,
)
from torch._C import _rename_privateuse1_backend, _get_privateuse1_backend_name
from typing import List, Optional, Union

__all__ = ["rename_privateuse1_backend", "generate_methods_for_privateuse1_backend"]

# TODO: Should use `torch._C._get_privateuse1_backend_name()` to get
# the renamed-backend name for `privateuse1`, but the func will cause an
# error with torch.jit.script, so we use the global variable named
# `_privateuse1_backend_name`.
_privateuse1_backend_name = "privateuseone"

def rename_privateuse1_backend(backend_name: str) -> None:
    r"""
    Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.

    The steps are:

    (1) (In C++) implement kernels for various torch operations, and register them
        to the PrivateUse1 dispatch key.
    (2) (In python) call torch.utils.rename_privateuse1_backend("foo")

    You can now use "foo" as an ordinary device string in python.

    Note: this API can only be called once per process. Attempting to change
    the external backend after it has already been set will result in an error.

    Note(AMP): If you want to support AMP on your device, you need to register a custom backend module
    with ``torch._register_device_module("foo", BackendModule)``.
    BackendModule needs to have the following API's:

    (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
        Return the dtypes supported by AMP on your "foo" device; the device may support more than one dtype.

    Note(random): If you want to support setting the seed for your device, BackendModule needs to have the following API's:

    (1) ``_is_in_bad_fork() -> bool``
        Return ``True`` if the current process is in a bad fork state, else ``False``.

    (2) ``manual_seed_all(seed: int) -> None``
        Set the seed for generating random numbers on all of your devices.

    (3) ``device_count() -> int``
        Return the number of "foo" devices available.

    (4) ``get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor``
        Return a ByteTensor representing the random number state of the specified "foo" device.

    (5) ``set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None``
        Set the random number generator state of the specified "foo" device.

    And there are some common funcs:

    (1) ``is_available() -> bool``
        Return a bool indicating if "foo" is currently available.

    (2) ``current_device() -> int``
        Return the index of the currently selected device.

    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
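
    For illustration only, a minimal ``BackendModule`` providing these hooks might
    look like the sketch below; every method body is a placeholder that a real
    backend must replace with calls into its own runtime::

        >>> # xdoctest: +SKIP("illustrative sketch, requires a real backend")
        >>> class BackendModule:
        ...     @staticmethod
        ...     def is_available() -> bool:
        ...         return True
        ...     @staticmethod
        ...     def current_device() -> int:
        ...         return 0
        ...     @staticmethod
        ...     def device_count() -> int:
        ...         return 1
        ...     @staticmethod
        ...     def get_amp_supported_dtype():
        ...         # placeholder: list the dtypes AMP may cast to on this device
        ...         return [torch.float16]
        ...     @staticmethod
        ...     def _is_in_bad_fork() -> bool:
        ...         return False
        ...     @staticmethod
        ...     def manual_seed_all(seed: int) -> None:
        ...         pass  # placeholder: forward the seed to the device runtime
        >>> torch.utils.rename_privateuse1_backend("foo")
        >>> torch._register_device_module("foo", BackendModule)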

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        # This will work, assuming that you've implemented the right C++ kernels
        # to implement torch.ones.
        >>> a = torch.ones(2, device="foo")

    """
    _rename_privateuse1_backend(backend_name)
    global _privateuse1_backend_name
    _privateuse1_backend_name = backend_name

def _check_register_once(module, attr):
    if hasattr(module, attr):
        raise RuntimeError(f"The custom device module of {module} has already been registered with {attr}")


def _normalization_device(custom_backend_name: str, device: Optional[Union[int, str, torch.device]] = None) -> int:
    def _get_current_device_index():
        _get_device_index = "current_device"
        if hasattr(torch, custom_backend_name) and \
                hasattr(getattr(torch, custom_backend_name), _get_device_index):
            return getattr(getattr(torch, custom_backend_name), _get_device_index)()
        else:
            # The default device index is 0.
            return 0

    if device is None:
        return _get_current_device_index()
    # If `device` is a str, it is in the string form "foo" or "foo:0".
    # Convert it to a torch.device object so it can be handled uniformly below.
    elif isinstance(device, str):
        device = torch.device(device)

    # At this point, `device` can only be a torch.device or an int.
    if isinstance(device, torch.device):
        if device.type != custom_backend_name:
            raise RuntimeError(f"Invalid device, must be {custom_backend_name} device")
        elif device.index is None:
            device_idx = _get_current_device_index()
        else:
            device_idx = device.index
    # If `device` is an int, it is already the device index.
    else:
        device_idx = device
    return device_idx


def _generate_tensor_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.Tensor) -> bool:
        if has_torch_function_unary(self):
            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
            return handle_torch_function(wrap_tensor_backend.__get__, (self,), self)  # type: ignore[attr-defined]
        return self.device.type == custom_backend_name

    _check_register_once(torch.Tensor, f'is_{custom_backend_name}')
    wrap_tensor_backend.fget.__name__ = f'is_{custom_backend_name}'  # type: ignore[attr-defined]
    setattr(torch.Tensor, f'is_{custom_backend_name}', wrap_tensor_backend)

    def wrap_tensor_to(self: torch.Tensor, device: Optional[Union[int, torch.device]] = None, non_blocking=False,
                       **kwargs) -> torch.Tensor:
        r"""Perform Tensor device conversion. Calls the underlying ``to`` operator implementation.

        .. note::
            If the ``self`` Tensor already has the correct :class:`torch.device`, then ``self`` is returned.
            Otherwise, the returned tensor is a copy of ``self`` with the desired :class:`torch.device`.

        Args:
            device (int, optional): if specified, all parameters will be copied to that device
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
            **kwargs (dict): For compatibility; may contain the ``memory_format`` keyword argument.
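
        The example below is illustrative only and assumes the backend has been
        renamed to ``"foo"`` and the required kernels are registered::

            >>> # xdoctest: +SKIP("requires a custom backend")
            >>> x = torch.randn(2, 3)
            >>> y = x.foo()   # move to the current "foo" device
            >>> z = x.foo(1)  # move to device "foo:1"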
        """
        if has_torch_function_unary(self):
            return handle_torch_function(wrap_tensor_to, (self,), self, device=device, non_blocking=non_blocking, **kwargs)
        device_idx = _normalization_device(custom_backend_name, device)
        return self.to(device=torch.device(f'{custom_backend_name}:{device_idx}'), non_blocking=non_blocking, **kwargs)

    _check_register_once(torch.Tensor, custom_backend_name)
    wrap_tensor_to.__name__ = custom_backend_name
    setattr(torch.Tensor, custom_backend_name, wrap_tensor_to)


def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    # Generating Module attributes and methods depends on the Tensor methods,
    # so check that the Tensor methods have already been registered.
    if not hasattr(torch.Tensor, custom_backend_name):
        raise RuntimeError(
            f"Can not automatically generate {custom_backend_name}() method for torch.nn.Module "
            f"because torch.Tensor doesn't have the method {custom_backend_name}(). "
            f"For this error, you can try setting for_tensor=True.")

    def wrap_module_to(self: torch.nn.modules.module.T,
                       device: Optional[Union[int, torch.device]] = None) -> torch.nn.modules.module.T:
        r"""Move all model parameters and buffers to the custom device.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on the device while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be copied to that device
        """
        return self._apply(lambda t: getattr(t, custom_backend_name)(device))

    _check_register_once(torch.nn.Module, custom_backend_name)
    setattr(torch.nn.Module, custom_backend_name, wrap_module_to)

def _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name: str) -> None:
    # Generating PackedSequence attributes and methods depends on the Tensor methods,
    # so check that the Tensor methods have already been registered.
    if not hasattr(torch.Tensor, f'is_{custom_backend_name}') or \
            not hasattr(torch.Tensor, custom_backend_name):
        raise RuntimeError(
            f"Can not automatically generate is_{custom_backend_name}() or "
            f"{custom_backend_name}() method for torch.nn.utils.rnn.PackedSequence "
            f"because torch.Tensor doesn't have the method is_{custom_backend_name}() "
            f"or {custom_backend_name}(). "
            f"For this error, you can try setting for_tensor=True.")

    @property  # type: ignore[misc]
    def wrap_tensor_backend(self: torch.nn.utils.rnn.PackedSequence) -> bool:
        return self.data.device.type == custom_backend_name

    _check_register_once(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}')
    setattr(torch.nn.utils.rnn.PackedSequence, f'is_{custom_backend_name}', wrap_tensor_backend)

    def wrap_module_to(self: torch.nn.utils.rnn.PackedSequence,
                       *args, **kwargs) -> torch.nn.utils.rnn.PackedSequence:
        r"""Move the data of the PackedSequence to the custom device.

        .. note::
            Unlike the ``torch.nn.Module`` method above, this does not modify
            ``self`` in-place; ``self`` is returned if the data is already on
            the correct device, otherwise a copy on the custom device is returned.

        Args:
            device (int, optional): if specified, the data will be copied to that device
        """
        ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(*args, **kwargs)
        if ex.device.type == custom_backend_name:
            return self.to(*args, **kwargs)
        kwargs.update({'device': custom_backend_name})
        return self.to(*args, **kwargs)

    _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
    setattr(torch.nn.utils.rnn.PackedSequence, custom_backend_name, wrap_module_to)

def _generate_storage_methods_for_privateuse1_backend(custom_backend_name: str,
                                                      unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    # The attribute is registered on the _StorageBase class,
    # and UntypedStorage obtains it through inheritance.
    @property  # type: ignore[misc]
    def wrap_storage_backend(self: torch.storage._StorageBase) -> bool:
        r"""Return whether the storage is on the custom backend device."""
        return self.device.type == custom_backend_name

    _check_register_once(torch.storage._StorageBase, f'is_{custom_backend_name}')
    setattr(torch.storage._StorageBase, f'is_{custom_backend_name}', wrap_storage_backend)

    def wrap_storage_to(self, device=None, non_blocking=False):
        r"""Return a copy of this object in custom device memory.

        If this object is already in device memory and on the correct device, then
        no copy is performed and the original object is returned.

        Args:
            device (int): The destination device id. Defaults to the current device.
            non_blocking (bool): If ``True`` and the source is in pinned memory,
                the copy will be asynchronous with respect to the host. Otherwise,
                the argument has no effect.
        """
        # Checks on the source storage device and storage type would normally go here,
        # but they depend on the backend extension, so they are omitted from the
        # automatically generated method.
        device_idx = _normalization_device(custom_backend_name, device)

        if getattr(self, f'is_{custom_backend_name}'):
            # The storage is already on the expected device.
            if self.get_device() == device_idx:
                return self
        # For sparse storage, custom backends need to extend the implementation themselves.
        if self.is_sparse:
            raise RuntimeError(f"Can not support a sparse storage move to {custom_backend_name} backend")
        # Create an UntypedStorage on the target device and copy the data.
        untyped_storage = torch.UntypedStorage(
            self.size(), device=torch.device(f'{custom_backend_name}:{device_idx}')
        )
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage

    _check_register_once(torch.storage._StorageBase, custom_backend_name)
    setattr(torch.storage._StorageBase, custom_backend_name, wrap_storage_to)

    # Register the corresponding attributes for the TypedStorage class.
    # When the TypedStorage class is removed, the registration will be removed as well.
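    # Note: wrap_typed_storage_to below delegates to the UntypedStorage method
    # registered above, after rejecting any dtype listed in `unsupported_dtype`.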

    @property  # type: ignore[misc]
    def wrap_typed_storage_backend(self: torch.storage.TypedStorage) -> bool:
        torch.storage._warn_typed_storage_removal()
        return self._untyped_storage.device.type == custom_backend_name

    _check_register_once(torch.storage.TypedStorage, f'is_{custom_backend_name}')
    setattr(torch.storage.TypedStorage, f'is_{custom_backend_name}', wrap_typed_storage_backend)

    def wrap_typed_storage_to(self: torch.storage.TypedStorage,
                              device=None, non_blocking=False, **kwargs) -> torch.storage.TypedStorage:
        torch.storage._warn_typed_storage_removal()
        if unsupported_dtype and self.dtype in unsupported_dtype:
            raise RuntimeError(f"Cannot create {custom_backend_name} storage "
                               f"as {self.dtype} dtype is not supported by this backend")
        custom_backend_storage: torch.UntypedStorage = getattr(
            self._untyped_storage, custom_backend_name)(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(custom_backend_storage)

    _check_register_once(torch.storage.TypedStorage, custom_backend_name)
    setattr(torch.storage.TypedStorage, custom_backend_name, wrap_typed_storage_to)


def generate_methods_for_privateuse1_backend(for_tensor: bool = True, for_module: bool = True,
                                             for_packed_sequence: bool = True,
                                             for_storage: bool = False,
                                             unsupported_dtype: Optional[List[torch.dtype]] = None) -> None:
    r"""
    Automatically generate attributes and methods for the custom backend after renaming the privateuse1 backend.

    In the default scenario, storage-related methods will not be generated automatically.

    Once you have implemented kernels for various torch operations and registered them to the
    PrivateUse1 dispatch key, and renamed your backend with torch.utils.rename_privateuse1_backend("foo"),
    you can call this function to register backend-specific methods and attributes,
    such as torch.Tensor.foo(), torch.Tensor.is_foo, torch.Storage.foo(), and torch.Storage.is_foo.

    Note: We recommend you use generic functions (check devices are equal or to(device=)).
    We provide these methods for convenience only; they will be "monkey patched" onto the objects
    and so will not be properly typed. For the generated Storage methods, if you need to support
    sparse data storage, you need to extend the implementation yourself.

    Args:
        for_tensor (bool): whether to register related methods for the torch.Tensor class.
        for_module (bool): whether to register related methods for the torch.nn.Module class.
        for_packed_sequence (bool): whether to register related methods for the torch.nn.utils.rnn.PackedSequence class.
        for_storage (bool): whether to register related methods for the torch.Storage class.
        unsupported_dtype (List[torch.dtype]): takes effect only when storage methods are generated,
            indicating which torch.dtype values the storage does not support.

    Example::

        >>> # xdoctest: +SKIP("failing")
        >>> torch.utils.rename_privateuse1_backend("foo")
        >>> torch.utils.generate_methods_for_privateuse1_backend()
        # Then automatically generate backend-related attributes and methods.
        >>> a = torch.tensor(2).foo()
        >>> a.is_foo
        >>> hasattr(torch.nn.Module, 'foo')
    """
    custom_backend_name = _get_privateuse1_backend_name()

    if for_tensor:
        _generate_tensor_methods_for_privateuse1_backend(custom_backend_name)

    if for_module:
        _generate_module_methods_for_privateuse1_backend(custom_backend_name)

    if for_storage:
        _generate_storage_methods_for_privateuse1_backend(custom_backend_name, unsupported_dtype)

    if for_packed_sequence:
        _generate_packed_sequence_methods_for_privateuse1_backend(custom_backend_name)

def _get_custom_mod_func(func_name: str):
    r"""
    Return the function named `func_name` defined in the custom device module.

    The custom device module is registered with `torch.utils.rename_privateuse1_backend('foo')`
    and `torch._register_device_module('foo', BackendModule)`. If the custom device module or
    the function is not defined, a `RuntimeError` is raised.

    Args:
        func_name (str): the name of the callable defined in the custom device module.

    Example::

        class DummyfooModule:
            @staticmethod
            def is_available():
                return True
            @staticmethod
            def func_name(*args, **kwargs):
                ...

        torch.utils.rename_privateuse1_backend("foo")
        torch._register_device_module("foo", DummyfooModule)
        foo_is_available_func = torch.utils.backend_registration._get_custom_mod_func("is_available")
        if foo_is_available_func:
            foo_is_available = foo_is_available_func()
        func_ = torch.utils.backend_registration._get_custom_mod_func("func_name")
        if func_:
            result = func_(*args, **kwargs)

    Attention: This function is not meant to be used directly by users, which is why
    it is marked as private. It is a convenience function for backend implementers to
    more easily call the hooks into their backend extensions.
    """
    assert isinstance(func_name, str), f"func_name must be `str`, but got `{type(func_name)}`."
    backend_name = _get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, backend_name, None)  # type: ignore[arg-type]
    function = getattr(custom_device_mod, func_name, None)  # type: ignore[arg-type]
    if custom_device_mod is None or function is None:
        message = f"Tried to call torch.{backend_name}.{func_name}. The backend must register a custom backend "
        message += f"module with `torch._register_device_module('{backend_name}', BackendModule)`, and "
        message += f"BackendModule needs to have the following API's:\n `{func_name}(*args, **kwargs)`. \n"
        raise RuntimeError(message)
    return function