# xref: /aosp_15_r20/external/pytorch/torch/xpu/random.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
# Stdlib typing helpers, then torch, then package-local lazy-init machinery.
from typing import Iterable, List, Union

import torch
from torch import Tensor

from . import _lazy_call, _lazy_init, current_device, device_count
9
def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    # Normalize the argument to a torch.device instance.
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)
    # A bare "xpu" device carries no index; fall back to the current device.
    index = current_device() if device.index is None else device.index
    return torch.xpu.default_generators[index].get_state()
30
31
def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices."""
    # One state tensor per visible XPU device, in device-index order.
    return [get_rng_state(index) for index in range(device_count())]
38
39
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    Args:
        new_state (torch.ByteTensor): The desired state
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # Copy eagerly so the caller mutating ``new_state`` afterwards cannot
    # affect the deferred callback below.
    with torch._C._DisableFuncTorch():
        state_copy = new_state.clone(memory_format=torch.contiguous_format)

    # Normalize the argument to a torch.device instance.
    if isinstance(device, int):
        device = torch.device("xpu", device)
    elif isinstance(device, str):
        device = torch.device(device)

    def _apply() -> None:
        # Resolve the index lazily: a bare "xpu" device means "current device".
        index = device.index if device.index is not None else current_device()
        torch.xpu.default_generators[index].set_state(state_copy)

    # Defer until XPU is initialized (or run immediately if it already is).
    _lazy_call(_apply)
65
66
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Set the random number generator state of all devices.

    Args:
        new_states (Iterable of torch.ByteTensor): The desired state for each device.
    """
    # The i-th state is applied to device i.
    for device_index, state in enumerate(new_states):
        set_rng_state(state, device_index)
75
76
def manual_seed(seed: int) -> None:
    r"""Set the seed for generating random numbers for the current GPU.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        If you are working with a multi-GPU model, this function is insufficient
        to get determinism.  To seed all GPUs, use :func:`manual_seed_all`.
    """
    seed = int(seed)

    def _seed_current() -> None:
        # Seed only the default generator of the currently selected device.
        torch.xpu.default_generators[current_device()].manual_seed(seed)

    # Deferred until XPU initialization; seed=True records it for replay.
    _lazy_call(_seed_current, seed=True)
97
98
def manual_seed_all(seed: int) -> None:
    r"""Set the seed for generating random numbers on all GPUs.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    seed = int(seed)

    def _seed_every_device() -> None:
        # Apply the same seed to every device's default generator.
        for device_index in range(device_count()):
            torch.xpu.default_generators[device_index].manual_seed(seed)

    # Deferred until XPU initialization; seed_all=True records it for replay.
    _lazy_call(_seed_every_device, seed_all=True)
115
116
def seed() -> None:
    r"""Set the seed for generating random numbers to a random number for the current GPU.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.

    .. warning::
        If you are working with a multi-GPU model, this function will only initialize
        the seed on one GPU.  To initialize all GPUs, use :func:`seed_all`.
    """

    def _reseed_current() -> None:
        # Let the current device's default generator pick a fresh random seed.
        torch.xpu.default_generators[current_device()].seed()

    _lazy_call(_reseed_current)
133
134
def seed_all() -> None:
    r"""Set the seed for generating random numbers to a random number on all GPUs.

    It's safe to call this function if XPU is not available; in that case, it is silently ignored.
    """

    def _reseed_all_devices() -> None:
        # Draw one fresh random seed on the first device, then replicate it to
        # every remaining device so all default generators agree.
        shared_seed = None
        for device_index in range(device_count()):
            generator = torch.xpu.default_generators[device_index]
            if shared_seed is None:
                generator.seed()
                shared_seed = generator.initial_seed()
            else:
                generator.manual_seed(shared_seed)

    _lazy_call(_reseed_all_devices)
154
155
def initial_seed() -> int:
    r"""Return the current random seed of the current GPU.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    # Read the seed off the current device's default generator.
    return torch.xpu.default_generators[current_device()].initial_seed()
166
167
# Public API of this module (what ``from torch.xpu.random import *`` exports).
__all__ = [
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]
179