1# mypy: ignore-errors 2 3"""Wrapper to mimic (parts of) np.random API surface. 4 5NumPy has strict guarantees on reproducibility etc; here we don't give any. 6 7Q: default dtype is float64 in numpy 8 9""" 10from __future__ import annotations 11 12import functools 13from math import sqrt 14from typing import Optional 15 16import torch 17 18from . import _dtypes_impl, _util 19from ._normalizations import array_or_scalar, ArrayLike, normalizer 20 21 22__all__ = [ 23 "seed", 24 "random_sample", 25 "sample", 26 "random", 27 "rand", 28 "randn", 29 "normal", 30 "choice", 31 "randint", 32 "shuffle", 33 "uniform", 34] 35 36 37def use_numpy_random(): 38 # local import to avoid ref cycles 39 import torch._dynamo.config as config 40 41 return config.use_numpy_random_stream 42 43 44def deco_stream(func): 45 @functools.wraps(func) 46 def inner(*args, **kwds): 47 if not use_numpy_random(): 48 return func(*args, **kwds) 49 else: 50 import numpy 51 52 from ._ndarray import ndarray 53 54 f = getattr(numpy.random, func.__name__) 55 56 # numpy funcs accept numpy ndarrays, unwrap 57 args = tuple( 58 arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args 59 ) 60 kwds = { 61 key: val.tensor.numpy() if isinstance(val, ndarray) else val 62 for key, val in kwds.items() 63 } 64 65 value = f(*args, **kwds) 66 67 # `value` can be either numpy.ndarray or python scalar (or None) 68 if isinstance(value, numpy.ndarray): 69 value = ndarray(torch.as_tensor(value)) 70 71 return value 72 73 return inner 74 75 76@deco_stream 77def seed(seed=None): 78 if seed is not None: 79 torch.random.manual_seed(seed) 80 81 82@deco_stream 83def random_sample(size=None): 84 if size is None: 85 size = () 86 dtype = _dtypes_impl.default_dtypes().float_dtype 87 values = torch.empty(size, dtype=dtype).uniform_() 88 return array_or_scalar(values, return_scalar=size == ()) 89 90 91def rand(*size): 92 if size == (): 93 size = None 94 return random_sample(size) 95 96 97sample = random_sample 98random = random_sample 99 100 101@deco_stream 102def uniform(low=0.0, high=1.0, size=None): 103 if size is None: 104 size = () 105 dtype = _dtypes_impl.default_dtypes().float_dtype 106 values = torch.empty(size, dtype=dtype).uniform_(low, high) 107 return array_or_scalar(values, return_scalar=size == ()) 108 109 110@deco_stream 111def randn(*size): 112 dtype = _dtypes_impl.default_dtypes().float_dtype 113 values = torch.randn(size, dtype=dtype) 114 return array_or_scalar(values, return_scalar=size == ()) 115 116 117@deco_stream 118def normal(loc=0.0, scale=1.0, size=None): 119 if size is None: 120 size = () 121 dtype = _dtypes_impl.default_dtypes().float_dtype 122 values = torch.empty(size, dtype=dtype).normal_(loc, scale) 123 return array_or_scalar(values, return_scalar=size == ()) 124 125 126@deco_stream 127def shuffle(x): 128 # no @normalizer because we do not cast e.g. lists to tensors 129 from ._ndarray import ndarray 130 131 if isinstance(x, torch.Tensor): 132 tensor = x 133 elif isinstance(x, ndarray): 134 tensor = x.tensor 135 else: 136 raise NotImplementedError("We do not random.shuffle lists in-place") 137 138 perm = torch.randperm(tensor.shape[0]) 139 xp = tensor[perm] 140 tensor.copy_(xp) 141 142 143@deco_stream 144def randint(low, high=None, size=None): 145 if size is None: 146 size = () 147 if not isinstance(size, (tuple, list)): 148 size = (size,) 149 if high is None: 150 low, high = 0, low 151 values = torch.randint(low, high, size=size) 152 return array_or_scalar(values, int, return_scalar=size == ()) 153 154 155@deco_stream 156@normalizer 157def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): 158 # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch 159 if a.numel() == 1: 160 a = torch.arange(a) 161 162 # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises 163 164 # number of draws 165 if size is None: 166 num_el = 1 167 elif _util.is_sequence(size): 168 num_el = 1 169 for el in size: 170 num_el *= el 171 else: 172 num_el = size 173 174 # prepare the probabilities 175 if p is None: 176 p = torch.ones_like(a) / a.shape[0] 177 178 # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973 179 atol = sqrt(torch.finfo(p.dtype).eps) 180 if abs(p.sum() - 1.0) > atol: 181 raise ValueError("probabilities do not sum to 1.") 182 183 # actually sample 184 indices = torch.multinomial(p, num_el, replacement=replace) 185 186 if _util.is_sequence(size): 187 indices = indices.reshape(size) 188 189 samples = a[indices] 190 191 return samples 192