# mypy: ignore-errors

"""Wrapper to mimic (parts of) np.random API surface.

NumPy makes strict guarantees about reproducibility etc.; here we give none.

Note: NumPy's default dtype is float64; here we use the default float dtype
from ``_dtypes_impl.default_dtypes()``.

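Example (illustrative sketch, not a reproducibility contract)::

    import torch._numpy.random as random

    random.seed(1234)                             # seeds torch's global RNG
    x = random.uniform(0.0, 10.0, size=(2, 3))    # a 2x3 wrapper ndarray
    i = random.randint(0, 5, size=4)              # four integers from [0, 5)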
"""
from __future__ import annotations

import functools
from math import sqrt
from typing import Optional

import torch

from . import _dtypes_impl, _util
from ._normalizations import array_or_scalar, ArrayLike, normalizer


__all__ = [
    "seed",
    "random_sample",
    "sample",
    "random",
    "rand",
    "randn",
    "normal",
    "choice",
    "randint",
    "shuffle",
    "uniform",
]


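# Returns True when TorchDynamo is configured to defer to NumPy's own random
# stream instead of torch's RNG (torch._dynamo.config.use_numpy_random_stream).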
def use_numpy_random():
    # local import to avoid ref cycles
    import torch._dynamo.config as config

    return config.use_numpy_random_stream


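# Decorator: when use_numpy_random() is set, forward the call to the function
# of the same name in numpy.random, unwrapping wrapper ndarrays to numpy arrays
# on the way in and re-wrapping numpy array results on the way out; otherwise
# run the torch-based implementation below.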
def deco_stream(func):
    @functools.wraps(func)
    def inner(*args, **kwds):
        if not use_numpy_random():
            return func(*args, **kwds)
        else:
            import numpy

            from ._ndarray import ndarray

            f = getattr(numpy.random, func.__name__)

            # numpy funcs accept numpy ndarrays, unwrap
            args = tuple(
                arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
            )
            kwds = {
                key: val.tensor.numpy() if isinstance(val, ndarray) else val
                for key, val in kwds.items()
            }

            value = f(*args, **kwds)

            # `value` can be either numpy.ndarray or python scalar (or None)
            if isinstance(value, numpy.ndarray):
                value = ndarray(torch.as_tensor(value))

            return value

    return inner


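# seed(None) leaves the generator state untouched; any other value seeds
# torch's global RNG via torch.random.manual_seed.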
@deco_stream
def seed(seed=None):
    if seed is not None:
        torch.random.manual_seed(seed)


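# random_sample: floats drawn uniformly from [0, 1) in the default float dtype;
# returns a scalar when size is None / ().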
@deco_stream
def random_sample(size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_()
    return array_or_scalar(values, return_scalar=size == ())


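# rand takes the shape as positional arguments, NumPy-style: rand(2, 3) is a
# 2x3 sample and rand() is a scalar.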
def rand(*size):
    if size == ():
        size = None
    return random_sample(size)


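# NumPy aliases: np.random.sample and np.random.random are the same function
# as np.random.random_sample.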
sample = random_sample
random = random_sample


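# uniform: floats drawn uniformly from [low, high) in the default float dtype.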
@deco_stream
def uniform(low=0.0, high=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_(low, high)
    return array_or_scalar(values, return_scalar=size == ())


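# randn: standard-normal samples; like rand, the shape is given positionally.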
@deco_stream
def randn(*size):
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.randn(size, dtype=dtype)
    return array_or_scalar(values, return_scalar=size == ())


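# normal: Gaussian samples with mean `loc` and standard deviation `scale`.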
@deco_stream
def normal(loc=0.0, scale=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).normal_(loc, scale)
    return array_or_scalar(values, return_scalar=size == ())


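# shuffle: permute `x` in place along its first axis. Only torch tensors and
# wrapper ndarrays are accepted; plain Python lists raise NotImplementedError.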
@deco_stream
def shuffle(x):
    # no @normalizer because we do not cast e.g. lists to tensors
    from ._ndarray import ndarray

    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, ndarray):
        tensor = x.tensor
    else:
        raise NotImplementedError("We do not random.shuffle lists in-place")

    perm = torch.randperm(tensor.shape[0])
    xp = tensor[perm]
    tensor.copy_(xp)


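# randint: with a single bound, randint(n) samples from [0, n); otherwise the
# range is [low, high). A scalar `size` is treated as a 1-d shape.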
@deco_stream
def randint(low, high=None, size=None):
    if size is None:
        size = ()
    if not isinstance(size, (tuple, list)):
        size = (size,)
    if high is None:
        low, high = 0, low
    values = torch.randint(low, high, size=size)
    return array_or_scalar(values, int, return_scalar=size == ())


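# choice: draw indices with torch.multinomial and gather them from `a`; `p`
# defaults to uniform probabilities over the first axis of `a`.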
@deco_stream
@normalizer
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
    # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch
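    # NB: a single-element `a` is treated as the number of items, so
    # choice(n) draws from arange(n).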
    if a.numel() == 1:
        a = torch.arange(a)

    # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises

    # number of draws
    if size is None:
        num_el = 1
    elif _util.is_sequence(size):
        num_el = 1
        for el in size:
            num_el *= el
    else:
        num_el = size

    # prepare the probabilities
    if p is None:
        p = torch.ones_like(a) / a.shape[0]

    # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973
    atol = sqrt(torch.finfo(p.dtype).eps)
    if abs(p.sum() - 1.0) > atol:
        raise ValueError("probabilities do not sum to 1.")

    # actually sample
    indices = torch.multinomial(p, num_el, replacement=replace)

    if _util.is_sequence(size):
        indices = indices.reshape(size)

    samples = a[indices]

    return samples