1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport contextlib 3*da0073e9SAndroid Build Coastguard Workerimport warnings 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Generator 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import default_generator 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Workerdef set_rng_state(new_state: torch.Tensor) -> None: 11*da0073e9SAndroid Build Coastguard Worker r"""Sets the random number generator state. 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker .. note:: This function only works for CPU. For CUDA, please use 14*da0073e9SAndroid Build Coastguard Worker :func:`torch.manual_seed`, which works for both CPU and CUDA. 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker Args: 17*da0073e9SAndroid Build Coastguard Worker new_state (torch.ByteTensor): The desired state 18*da0073e9SAndroid Build Coastguard Worker """ 19*da0073e9SAndroid Build Coastguard Worker default_generator.set_state(new_state) 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerdef get_rng_state() -> torch.Tensor: 23*da0073e9SAndroid Build Coastguard Worker r"""Returns the random number generator state as a `torch.ByteTensor`. 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Worker .. note:: The returned state is for the default generator on CPU only. 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker See also: :func:`torch.random.fork_rng`. 28*da0073e9SAndroid Build Coastguard Worker """ 29*da0073e9SAndroid Build Coastguard Worker return default_generator.get_state() 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerdef manual_seed(seed) -> torch._C.Generator: 33*da0073e9SAndroid Build Coastguard Worker r"""Sets the seed for generating random numbers on all devices. Returns a 34*da0073e9SAndroid Build Coastguard Worker `torch.Generator` object. 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker Args: 37*da0073e9SAndroid Build Coastguard Worker seed (int): The desired seed. Value must be within the inclusive range 38*da0073e9SAndroid Build Coastguard Worker `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError 39*da0073e9SAndroid Build Coastguard Worker is raised. Negative inputs are remapped to positive values with the formula 40*da0073e9SAndroid Build Coastguard Worker `0xffff_ffff_ffff_ffff + seed`. 41*da0073e9SAndroid Build Coastguard Worker """ 42*da0073e9SAndroid Build Coastguard Worker seed = int(seed) 43*da0073e9SAndroid Build Coastguard Worker import torch.cuda 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker if not torch.cuda._is_in_bad_fork(): 46*da0073e9SAndroid Build Coastguard Worker torch.cuda.manual_seed_all(seed) 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker import torch.mps 49*da0073e9SAndroid Build Coastguard Worker 50*da0073e9SAndroid Build Coastguard Worker if not torch.mps._is_in_bad_fork(): 51*da0073e9SAndroid Build Coastguard Worker torch.mps.manual_seed(seed) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker import torch.xpu 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker if not torch.xpu._is_in_bad_fork(): 56*da0073e9SAndroid Build Coastguard Worker torch.xpu.manual_seed_all(seed) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker _seed_custom_device(seed) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker return default_generator.manual_seed(seed) 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Workerdef seed() -> int: 64*da0073e9SAndroid Build Coastguard Worker r"""Sets the seed for generating random numbers to a non-deterministic 65*da0073e9SAndroid Build Coastguard Worker random number on all devices. Returns a 64 bit number used to seed the RNG. 66*da0073e9SAndroid Build Coastguard Worker """ 67*da0073e9SAndroid Build Coastguard Worker seed = default_generator.seed() 68*da0073e9SAndroid Build Coastguard Worker import torch.cuda 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker if not torch.cuda._is_in_bad_fork(): 71*da0073e9SAndroid Build Coastguard Worker torch.cuda.manual_seed_all(seed) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker import torch.mps 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker if not torch.mps._is_in_bad_fork(): 76*da0073e9SAndroid Build Coastguard Worker torch.mps.manual_seed(seed) 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker import torch.xpu 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker if not torch.xpu._is_in_bad_fork(): 81*da0073e9SAndroid Build Coastguard Worker torch.xpu.manual_seed_all(seed) 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker _seed_custom_device(seed) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker return seed 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Workerdef _seed_custom_device(seed) -> None: 89*da0073e9SAndroid Build Coastguard Worker r"""Sets the seed to generate random numbers for custom device. 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker Args: 92*da0073e9SAndroid Build Coastguard Worker seed (int): The desired seed. 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker See [Note: support the custom device with privateuse1] 95*da0073e9SAndroid Build Coastguard Worker """ 96*da0073e9SAndroid Build Coastguard Worker seed = int(seed) 97*da0073e9SAndroid Build Coastguard Worker custom_backend_name = torch._C._get_privateuse1_backend_name() 98*da0073e9SAndroid Build Coastguard Worker if hasattr(torch, custom_backend_name): 99*da0073e9SAndroid Build Coastguard Worker custom_device_mod = getattr(torch, custom_backend_name) 100*da0073e9SAndroid Build Coastguard Worker _bad_fork_name = "_is_in_bad_fork" 101*da0073e9SAndroid Build Coastguard Worker _seed_all_name = "manual_seed_all" 102*da0073e9SAndroid Build Coastguard Worker if hasattr(custom_device_mod, _bad_fork_name) and hasattr( 103*da0073e9SAndroid Build Coastguard Worker custom_device_mod, _seed_all_name 104*da0073e9SAndroid Build Coastguard Worker ): 105*da0073e9SAndroid Build Coastguard Worker if not getattr(custom_device_mod, _bad_fork_name)(): 106*da0073e9SAndroid Build Coastguard Worker getattr(custom_device_mod, _seed_all_name)(seed) 107*da0073e9SAndroid Build Coastguard Worker else: 108*da0073e9SAndroid Build Coastguard Worker message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's " 109*da0073e9SAndroid Build Coastguard Worker message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module." 110*da0073e9SAndroid Build Coastguard Worker warnings.warn(message, UserWarning, stacklevel=3) 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Workerdef initial_seed() -> int: 114*da0073e9SAndroid Build Coastguard Worker r"""Returns the initial seed for generating random numbers as a 115*da0073e9SAndroid Build Coastguard Worker Python `long`. 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker .. note:: The returned seed is for the default generator on CPU only. 118*da0073e9SAndroid Build Coastguard Worker """ 119*da0073e9SAndroid Build Coastguard Worker return default_generator.initial_seed() 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker_fork_rng_warned_already = False 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 126*da0073e9SAndroid Build Coastguard Workerdef fork_rng( 127*da0073e9SAndroid Build Coastguard Worker devices=None, 128*da0073e9SAndroid Build Coastguard Worker enabled=True, 129*da0073e9SAndroid Build Coastguard Worker _caller="fork_rng", 130*da0073e9SAndroid Build Coastguard Worker _devices_kw="devices", 131*da0073e9SAndroid Build Coastguard Worker device_type="cuda", 132*da0073e9SAndroid Build Coastguard Worker) -> Generator: 133*da0073e9SAndroid Build Coastguard Worker """ 134*da0073e9SAndroid Build Coastguard Worker Forks the RNG, so that when you return, the RNG is reset 135*da0073e9SAndroid Build Coastguard Worker to the state that it was previously in. 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker Args: 138*da0073e9SAndroid Build Coastguard Worker devices (iterable of Device IDs): devices for which to fork 139*da0073e9SAndroid Build Coastguard Worker the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates 140*da0073e9SAndroid Build Coastguard Worker on all devices, but will emit a warning if your machine has a lot 141*da0073e9SAndroid Build Coastguard Worker of devices, since this function will run very slowly in that case. 142*da0073e9SAndroid Build Coastguard Worker If you explicitly specify devices, this warning will be suppressed 143*da0073e9SAndroid Build Coastguard Worker enabled (bool): if ``False``, the RNG is not forked. This is a convenience 144*da0073e9SAndroid Build Coastguard Worker argument for easily disabling the context manager without having 145*da0073e9SAndroid Build Coastguard Worker to delete it and unindent your Python code under it. 146*da0073e9SAndroid Build Coastguard Worker device_type (str): device type str, default is `cuda`. As for custom device, 147*da0073e9SAndroid Build Coastguard Worker see details in [Note: support the custom device with privateuse1] 148*da0073e9SAndroid Build Coastguard Worker """ 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker device_type = torch.device(device_type).type 151*da0073e9SAndroid Build Coastguard Worker device_mod = getattr(torch, device_type, None) 152*da0073e9SAndroid Build Coastguard Worker if device_mod is None: 153*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 154*da0073e9SAndroid Build Coastguard Worker f"torch has no module of `{device_type}`, you should register " 155*da0073e9SAndroid Build Coastguard Worker + "a module by `torch._register_device_module`." 156*da0073e9SAndroid Build Coastguard Worker ) 157*da0073e9SAndroid Build Coastguard Worker global _fork_rng_warned_already 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker # Internal arguments: 160*da0073e9SAndroid Build Coastguard Worker # _caller: the function which called fork_rng, which the user used 161*da0073e9SAndroid Build Coastguard Worker # _devices_kw: the devices keyword of _caller 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker if not enabled: 164*da0073e9SAndroid Build Coastguard Worker yield 165*da0073e9SAndroid Build Coastguard Worker return 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker if devices is None: 168*da0073e9SAndroid Build Coastguard Worker num_devices = device_mod.device_count() 169*da0073e9SAndroid Build Coastguard Worker if num_devices > 1 and not _fork_rng_warned_already: 170*da0073e9SAndroid Build Coastguard Worker message = ( 171*da0073e9SAndroid Build Coastguard Worker f"{device_type.upper()} reports that you have {num_devices} available devices, and " 172*da0073e9SAndroid Build Coastguard Worker f"you have used {_caller} without explicitly specifying which devices are being used. " 173*da0073e9SAndroid Build Coastguard Worker f"For safety, we initialize *every* {device_type.upper()} device by default, which can " 174*da0073e9SAndroid Build Coastguard Worker f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only" 175*da0073e9SAndroid Build Coastguard Worker f" making use of a few {device_type.upper()} devices, set the environment variable " 176*da0073e9SAndroid Build Coastguard Worker f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} " 177*da0073e9SAndroid Build Coastguard Worker "with the set of devices you are actually using. For example, if you are using CPU only, " 178*da0073e9SAndroid Build Coastguard Worker "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, " 179*da0073e9SAndroid Build Coastguard Worker f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices " 180*da0073e9SAndroid Build Coastguard Worker f"and suppress this warning, set the '{_devices_kw}' keyword argument to " 181*da0073e9SAndroid Build Coastguard Worker f"`range(torch.{device_type}.device_count())`." 182*da0073e9SAndroid Build Coastguard Worker ) 183*da0073e9SAndroid Build Coastguard Worker warnings.warn(message) 184*da0073e9SAndroid Build Coastguard Worker _fork_rng_warned_already = True 185*da0073e9SAndroid Build Coastguard Worker devices = list(range(num_devices)) 186*da0073e9SAndroid Build Coastguard Worker else: 187*da0073e9SAndroid Build Coastguard Worker # Protect against user passing us a generator; we need to traverse this 188*da0073e9SAndroid Build Coastguard Worker # multiple times but a generator will be exhausted upon first traversal 189*da0073e9SAndroid Build Coastguard Worker devices = list(devices) 190*da0073e9SAndroid Build Coastguard Worker 191*da0073e9SAndroid Build Coastguard Worker cpu_rng_state = torch.get_rng_state() 192*da0073e9SAndroid Build Coastguard Worker device_rng_states = [device_mod.get_rng_state(device) for device in devices] 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker try: 195*da0073e9SAndroid Build Coastguard Worker yield 196*da0073e9SAndroid Build Coastguard Worker finally: 197*da0073e9SAndroid Build Coastguard Worker torch.set_rng_state(cpu_rng_state) 198*da0073e9SAndroid Build Coastguard Worker for device, device_rng_state in zip(devices, device_rng_states): 199*da0073e9SAndroid Build Coastguard Worker device_mod.set_rng_state(device_rng_state, device) 200