xref: /aosp_15_r20/external/pytorch/torch/random.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport contextlib
3*da0073e9SAndroid Build Coastguard Workerimport warnings
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Generator
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import default_generator
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Restores the default CPU random number generator from a saved state.

    .. note:: Only the CPU default generator is affected. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state, typically one
            previously obtained from :func:`torch.get_rng_state`.
    """
    cpu_generator = default_generator
    # The generator's own return value is intentionally discarded; this
    # function always returns None.
    cpu_generator.set_state(new_state)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
def get_rng_state() -> torch.Tensor:
    r"""Returns a snapshot of the random number generator state as a
    `torch.ByteTensor`.

    .. note:: Only the default CPU generator is captured; accelerator
        generators are not part of the returned state.

    See also: :func:`torch.random.fork_rng`.
    """
    snapshot: torch.Tensor = default_generator.get_state()
    return snapshot
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    The seed is applied to the CUDA, MPS and XPU backends (each one skipped
    when its ``_is_in_bad_fork()`` reports true), to the privateuse1 custom
    backend if one is registered, and finally to the default CPU generator.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0x1_0000_0000_0000_0000 + seed` (i.e. ``2**64 + seed``), so ``-1`` maps to
            ``0xffff_ffff_ffff_ffff``.

    Returns:
        torch.Generator: the default CPU generator, as returned by its
        ``manual_seed`` method.
    """
    seed = int(seed)
    # Function-local import: this binds ``torch`` as a local name for the
    # rest of the function, so ``torch`` must not be referenced above it.
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
def seed() -> int:
    r"""Reseeds the random number generators on all devices with a
    non-deterministic value and returns that 64 bit number.

    The fresh seed is drawn by the default CPU generator. It is then applied
    to the CUDA, MPS and XPU backends (each skipped when its
    ``_is_in_bad_fork()`` reports true) and passed to
    :func:`_seed_custom_device` for the privateuse1 custom backend.
    """
    fresh_seed = default_generator.seed()

    def _apply(backend, seeding_fn_name):
        # Backends reporting a bad fork are left untouched.
        if not backend._is_in_bad_fork():
            getattr(backend, seeding_fn_name)(fresh_seed)

    # The function-local imports below bind ``torch`` as a local name, so
    # nothing above this point may reference ``torch``.
    import torch.cuda

    _apply(torch.cuda, "manual_seed_all")

    import torch.mps

    _apply(torch.mps, "manual_seed")

    import torch.xpu

    _apply(torch.xpu, "manual_seed_all")

    _seed_custom_device(fresh_seed)

    return fresh_seed
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
88*da0073e9SAndroid Build Coastguard Workerdef _seed_custom_device(seed) -> None:
89*da0073e9SAndroid Build Coastguard Worker    r"""Sets the seed to generate random numbers for custom device.
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker    Args:
92*da0073e9SAndroid Build Coastguard Worker        seed (int): The desired seed.
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker    See [Note: support the custom device with privateuse1]
95*da0073e9SAndroid Build Coastguard Worker    """
96*da0073e9SAndroid Build Coastguard Worker    seed = int(seed)
97*da0073e9SAndroid Build Coastguard Worker    custom_backend_name = torch._C._get_privateuse1_backend_name()
98*da0073e9SAndroid Build Coastguard Worker    if hasattr(torch, custom_backend_name):
99*da0073e9SAndroid Build Coastguard Worker        custom_device_mod = getattr(torch, custom_backend_name)
100*da0073e9SAndroid Build Coastguard Worker        _bad_fork_name = "_is_in_bad_fork"
101*da0073e9SAndroid Build Coastguard Worker        _seed_all_name = "manual_seed_all"
102*da0073e9SAndroid Build Coastguard Worker        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
103*da0073e9SAndroid Build Coastguard Worker            custom_device_mod, _seed_all_name
104*da0073e9SAndroid Build Coastguard Worker        ):
105*da0073e9SAndroid Build Coastguard Worker            if not getattr(custom_device_mod, _bad_fork_name)():
106*da0073e9SAndroid Build Coastguard Worker                getattr(custom_device_mod, _seed_all_name)(seed)
107*da0073e9SAndroid Build Coastguard Worker        else:
108*da0073e9SAndroid Build Coastguard Worker            message = f"Set seed for `{custom_backend_name}` device does not take effect, please add API's "
109*da0073e9SAndroid Build Coastguard Worker            message += f"`{_bad_fork_name}` and `{_seed_all_name}` to `{custom_backend_name}` device module."
110*da0073e9SAndroid Build Coastguard Worker            warnings.warn(message, UserWarning, stacklevel=3)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker
def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python ``int``.

    .. note:: Only the default CPU generator is consulted.
    """
    cpu_generator = default_generator
    return cpu_generator.initial_seed()
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker
# Flipped to True the first time fork_rng() warns about implicitly forking
# every device, so that warning is emitted at most once.
_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state that it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed
        enabled (bool): if ``False``, the RNG is not forked.  This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent your Python code under it.
        device_type (str): device type str, default is `cuda`. As for custom device,
            see details in [Note: support the custom device with privateuse1]

    Raises:
        RuntimeError: if ``torch`` has no module for ``device_type``.
    """
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    # Normalize e.g. "cuda:0" to "cuda" before looking up the device module.
    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            + "a module by `torch._register_device_module`."
        )

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            dev = device_type.upper()
            message = (
                f"{dev} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {dev} device by default, which can "
                f"be quite slow if you have a lot of {dev}s. If you know that you are only"
                f" making use of a few {dev} devices, set the environment variable "
                f"{dev}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                f"set {dev}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {dev}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against user passing us a generator; we need to traverse this
        # multiple times but a generator will be exhausted upon first traversal
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        # Restore even when the body raises.
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)
200