xref: /aosp_15_r20/external/pytorch/torch/xpu/random.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom typing import Iterable, List, Union
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom . import _lazy_call, _lazy_init, current_device, device_count
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
def get_rng_state(device: Union[int, str, torch.device] = "xpu") -> Tensor:
    r"""Return the random number generator state of the specified GPU as a ByteTensor.

    Args:
        device (torch.device, str or int, optional): The device to return the RNG state of.
            An ``int`` is interpreted as an XPU device index, and a ``str`` is parsed
            with :class:`torch.device`. If the resulting device has no index, the
            current XPU device is used.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    Returns:
        Tensor: The state returned by the selected device's default generator.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("xpu", device)
    # NOTE(review): device.type is not checked, so the index of a non-XPU
    # device (e.g. "cuda:1") would be used as an XPU index.
    # An index-less device such as torch.device("xpu") means "current device".
    idx = current_device() if device.index is None else device.index
    default_generator = torch.xpu.default_generators[idx]
    return default_generator.get_state()
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
def get_rng_state_all() -> List[Tensor]:
    r"""Return a list of ByteTensor representing the random number states of all devices.

    Returns:
        List[Tensor]: One RNG state per XPU device, ordered by device index
        (``0`` through ``device_count() - 1``).
    """
    return [get_rng_state(i) for i in range(device_count())]
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state of the specified GPU.

    ``new_state`` is cloned immediately, so later in-place changes to it by the
    caller do not affect the state that gets applied. Applying the state to the
    device's default generator is handed off to ``_lazy_call``.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device, str or int, optional): The device to set the RNG state.
            An ``int`` is interpreted as an XPU device index, and a ``str`` is parsed
            with :class:`torch.device`. If the resulting device has no index, the
            device that is current when the state is applied is used.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    # NOTE(review): presumably functorch is disabled so that the clone is an
    # ordinary tensor even when called inside a functorch transform — confirm.
    with torch._C._DisableFuncTorch():
        new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
    if isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("xpu", device)

    def cb():
        # The index is resolved inside the callback, so an index-less device
        # refers to whichever XPU device is current when the callback runs.
        idx = device.index
        if idx is None:
            idx = current_device()
        default_generator = torch.xpu.default_generators[idx]
        default_generator.set_state(new_state_copy)

    # Presumably runs cb immediately if XPU is already initialized and queues
    # it otherwise — see _lazy_call to confirm.
    _lazy_call(cb)
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker
def set_rng_state_all(new_states: Iterable[Tensor]) -> None:
    r"""Restore the random number generator state on every device.

    The ``k``-th state from ``new_states`` is applied to XPU device ``k``.

    Args:
        new_states (Iterable of torch.ByteTensor): One desired state per device,
            in device-index order.
    """
    for device_index, state in enumerate(new_states, start=0):
        set_rng_state(new_state=state, device=device_index)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker
def manual_seed(seed: int) -> None:
    r"""Seed the random number generator of the current XPU device.

    Calling this when XPU is not available is safe; the request is silently ignored.

    Args:
        seed (int): The desired seed.

    .. warning::
        Only the current device is seeded, which is not enough for determinism
        in a multi-GPU model. Use :func:`manual_seed_all` to seed every GPU.
    """
    # Convert eagerly so that a bad seed raises at call time, not in the callback.
    seed_value = int(seed)

    def _seed_current_device():
        torch.xpu.default_generators[current_device()].manual_seed(seed_value)

    _lazy_call(_seed_current_device, seed=True)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
def manual_seed_all(seed: int) -> None:
    r"""Seed the random number generator of every XPU device with the same value.

    Calling this when XPU is not available is safe; the request is silently ignored.

    Args:
        seed (int): The desired seed.
    """
    # Convert eagerly so that a bad seed raises at call time, not in the callback.
    seed_value = int(seed)

    def _seed_every_device():
        for device_index in range(device_count()):
            torch.xpu.default_generators[device_index].manual_seed(seed_value)

    _lazy_call(_seed_every_device, seed_all=True)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker
def seed() -> None:
    r"""Reseed the current XPU device's random number generator with a random number.

    Calling this when XPU is not available is safe; the request is silently ignored.

    .. warning::
        Only the current GPU is reseeded. In a multi-GPU model, use
        :func:`seed_all` to initialize every GPU.
    """

    def _reseed_current_device():
        generator = torch.xpu.default_generators[current_device()]
        generator.seed()

    _lazy_call(_reseed_current_device)
133*da0073e9SAndroid Build Coastguard Worker
134*da0073e9SAndroid Build Coastguard Worker
def seed_all() -> None:
    r"""Seed every XPU device with one shared, randomly chosen seed.

    Device 0's generator is reseeded with a random number. Every other device
    is then manually seeded with that same value.

    Calling this when XPU is not available is safe; the request is silently ignored.
    """

    def _reseed_all_devices():
        shared_seed = None  # filled in from the first device's generator
        for device_index in range(device_count()):
            generator = torch.xpu.default_generators[device_index]
            if shared_seed is not None:
                generator.manual_seed(shared_seed)
                continue
            generator.seed()
            shared_seed = generator.initial_seed()

    _lazy_call(_reseed_all_devices)
154*da0073e9SAndroid Build Coastguard Worker
155*da0073e9SAndroid Build Coastguard Worker
def initial_seed() -> int:
    r"""Return the initial seed of the current XPU device's default generator.

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    current_generator = torch.xpu.default_generators[current_device()]
    return current_generator.initial_seed()
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker
# Public API of this module.
__all__ = [
    # Saving and restoring generator state.
    "get_rng_state",
    "get_rng_state_all",
    "set_rng_state",
    "set_rng_state_all",
    # Seeding.
    "manual_seed",
    "manual_seed_all",
    "seed",
    "seed_all",
    "initial_seed",
]
179