# mypy: allow-untyped-defs
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count


def _to_device(device: Union[int, str, torch.device]) -> torch.device:
    r"""Normalize a user-supplied device argument to a ``torch.device``.

    A ``str`` is parsed with ``torch.device``; an ``int`` is treated as an XPU
    device ordinal. Any other value is returned unchanged.
    """
    if isinstance(device, str):
        return torch.device(device)
    if isinstance(device, int):
        return torch.device("xpu", device)
    return device


def _default_generator_for(device: torch.device) -> torch.Generator:
    r"""Return the XPU default generator selected by ``device``.

    A device without an explicit index resolves to the current XPU device.
    """
    idx = device.index
    if idx is None:
        idx = current_device()
    return torch.xpu.default_generators[idx]


def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Return the random number generator state of the specified XPU device as a ByteTensor.

    Args:
        device (torch.device, str or int, optional): The device to return the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    return _default_generator_for(_to_device(device)).get_state()


def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices."""
    return [get_rng_state(i) for i in range(device_count())]


def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state of the specified XPU device.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device, str or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Take a contiguous copy now, before the callback is handed to
    # ``_lazy_call``: the callback may run later, and it must not observe
    # in-place changes the caller makes to ``new_state`` in the meantime.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    device = _to_device(device)

    def cb():
        # The device index (or the current device, if none was given) is
        # resolved when the callback runs, not when it is registered.
        _default_generator_for(device).set_state(new_state_copy)

    _lazy_call(cb)


def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device,
            applied to device ``0``, ``1``, ... in iteration order.
    """
    for i, state in enumerate(new_states):
        set_rng_state(state, i)


def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current XPU device.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-device model, this function is insufficient
        to get determinism. To seed all devices, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def cb():
        idx = current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.manual_seed(seed)

    _lazy_call(cb, seed=True)


def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all XPU devices.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def cb():
        for i in range(device_count()):
            default_generator = torch.xpu.default_generators[i]
            default_generator.manual_seed(seed)

    _lazy_call(cb, seed_all=True)


def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current XPU device.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    .. warning::
        If you are working with a multi-device model, this function will only initialize
        the seed on one device. To initialize all devices, use :func:`seed_all`.
    """

    def cb():
        idx = current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.seed()

    _lazy_call(cb)


def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all XPU devices.

    Device ``0`` is seeded randomly, and every other device is then given that
    same seed, so all devices end up sharing one random seed.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.
    """

    def cb():
        random_seed = None
        for i in range(device_count()):
            default_generator = torch.xpu.default_generators[i]
            if random_seed is None:
                # First device: draw a fresh random seed and remember it.
                default_generator.seed()
                random_seed = default_generator.initial_seed()
            else:
                default_generator.manual_seed(random_seed)

    _lazy_call(cb)
def initial_seed() -> int:
    r"""Return the initial seed of the default generator on the current XPU device.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    # Look up the current device's default generator and report its seed.
    return torch.xpu.default_generators[current_device()].initial_seed()


__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]