# mypy: allow-untyped-defs
"""Random number generator state and seeding helpers for XPU devices."""
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


def _to_device(device: Union[int, str, torch.device]) -> torch.device:
    """Normalize a user-supplied device argument to a :class:`torch.device`.

    Strings are parsed by :class:`torch.device`, ints are interpreted as XPU
    device ordinals, and any other value is returned unchanged (it is expected
    to already be a :class:`torch.device`).
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("xpu", device)
    return device


def _default_generator(device: torch.device):
    """Return the default XPU generator for ``device``.

    A device without an explicit index (e.g. ``torch.device('xpu')``) resolves
    to the current XPU device at the time of the call.
    """
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.xpu.default_generators[idx]


def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Return the random number generator state of the specified XPU device as a ByteTensor.

    Args:
        device (torch.device, str or int, optional): The device to return the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    return _default_generator(_to_device(device)).get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices.

    The list is ordered by device index. Each entry is obtained through
    :func:`get_rng_state`, which eagerly initializes XPU.
    """
    return [get_rng_state(i) for i in range(device_count())]


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state of the specified XPU device.

    The update is scheduled through ``_lazy_call``, so it may be deferred until
    XPU is initialized.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device, str or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Take a contiguous private copy now (with functorch disabled) so the
    # possibly-deferred callback below is unaffected by later in-place changes
    # the caller makes to ``new_state``.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    target = _to_device(device)

    def cb():
        # Index resolution happens here, not eagerly, so an index-less device
        # refers to whichever XPU device is current when the callback runs.
        _default_generator(target).set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device,
            ordered by device index.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current XPU device.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-device model, this function is insufficient
        to get determinism. To seed all devices, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        torch.xpu.default_generators[current_device()].manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all XPU devices.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            torch.xpu.default_generators[i].manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current XPU device.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    .. warning::
        If you are working with a multi-device model, this function will only initialize
        the seed on one device. To initialize all devices, use :func:`seed_all`.
    """

    def cb():
        torch.xpu.default_generators[current_device()].seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all XPU devices.

    A single random seed is drawn on the first device and then applied to
    every remaining device, so all devices end up with the same seed.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.
    """

    def cb():
        random_seed = None
        for i in range(device_count()):
            default_generator = torch.xpu.default_generators[i]
            if random_seed is None:
                # First device: draw a fresh random seed and remember it.
                default_generator.seed()
                random_seed = default_generator.initial_seed()
            else:
                # Remaining devices: reuse the seed drawn on the first device.
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)


def initial_seed() -> int:
    r"""Return the current random seed of the current XPU device.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    return torch.xpu.default_generators[current_device()].initial_seed()


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]