# mypy: allow-untyped-defs
from typing import Optional

import torch


class SobolEngine:
    r"""
    The :class:`torch.quasirandom.SobolEngine` is an engine for generating
    (scrambled) Sobol sequences. Sobol sequences are an example of low
    discrepancy quasi-random sequences.

    This implementation of an engine for Sobol sequences is capable of
    sampling sequences up to a maximum dimension of 21201. It uses direction
    numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
    search criterion D(6) up to the dimension 21201. This is the recommended
    choice by the authors.

    References:
      - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
        Journal of Complexity, 14(4):466-489, December 1998.

      - I. M. Sobol. The distribution of points in a cube and the accurate
        evaluation of integrals.
        Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.

    Args:
        dimension (int): The dimensionality of the sequence to be drawn.
            Must lie in ``[1, 21201]``.
        scramble (bool, optional): Setting this to ``True`` will produce
            scrambled Sobol sequences. Scrambling is
            capable of producing better Sobol
            sequences. Default: ``False``.
        seed (int, optional): This is the seed for the scrambling. The seed
            of the random number generator is set to this,
            if specified. Otherwise, it uses a random seed.
            Default: ``None``

    Examples::

        >>> # xdoctest: +SKIP("unseeded random state")
        >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
        >>> soboleng.draw(3)
        tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
                [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
    """

    MAXBIT = 30
    MAXDIM = 21201

    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )

        self.seed = seed
        self.scramble = scramble
        self.dimension = dimension

        cpu = torch.device("cpu")

        # (dimension, MAXBIT) integer table, filled in place by the native
        # initializer and consumed by the draw / fast-forward kernels.
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

        if not self.scramble:
            self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
        else:
            # Sets ``self.shift`` and scrambles ``self.sobolstate`` in place.
            self._scramble()

        # Integer state of the most recently emitted point; the sequence
        # starts at the shift.
        self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
        # The first point is ``shift / 2**MAXBIT``. A scrambled shift can use
        # all MAXBIT (30) bits, more than float32's 24-bit significand, so it
        # is computed exactly in float64 and rounded once to the requested
        # dtype in ``draw``. For float32 output this matches the previous
        # values; for float64 output the first point is no longer truncated.
        self._first_point = (
            self.quasi.to(torch.float64) / 2**self.MAXBIT
        ).reshape(1, -1)
        self.num_generated = 0

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(n, dimension)`.

        Args:
            n (int, optional): The length of sequence of points to draw.
                Default: 1
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                returned tensor. If ``None``, :func:`torch.get_default_dtype`
                is used. Default: ``None``
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        if self.num_generated == 0:
            # The native kernel does not emit the very first point, so it is
            # prepended here. ``copy=True`` guarantees the caller never gets
            # (and so can never mutate) the cached ``_first_point`` tensor,
            # which ``Tensor.to`` would otherwise return when dtypes match.
            first_point = self._first_point.to(dtype, copy=True)
            if n == 0:
                # Empty (0, dimension) result; avoids requesting ``n - 1 == -1``
                # points from the kernel.
                result = first_point[:0]
            elif n == 1:
                result = first_point
            else:
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                result = torch.cat((first_point, result), dim=-2)
        else:
            # The first point was emitted without the kernel, so the kernel's
            # counter lags ``num_generated`` by one.
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )

        self.num_generated += n

        if out is not None:
            out.resize_as_(result).copy_(result)
            return out

        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(2**m, dimension)`.

        Args:
            m (int): The (base2) exponent of the number of points to draw.
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                returned tensor. If ``None``, :func:`torch.get_default_dtype`
                is used. Default: ``None``

        Raises:
            ValueError: if the total number of points generated after this
                call would not be a power of 2.
        """
        n = 2**m
        total_n = self.num_generated + n
        # ``x & (x - 1) == 0`` is the usual power-of-two test; in Python ``&``
        # binds tighter than ``==``, so no extra parentheses are needed.
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
                "'SobolEngine.draw()' instead."
            )
        return self.draw(n=n, out=out, dtype=dtype)

    def reset(self):
        r"""
        Function to reset the ``SobolEngine`` to base state, so that the next
        draw starts again from the first point. Returns ``self``.
        """
        self.quasi.copy_(self.shift)
        self.num_generated = 0
        return self

    def fast_forward(self, n):
        r"""
        Function to fast-forward the state of the ``SobolEngine`` by
        :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
        without using the samples. Returns ``self``.

        Args:
            n (int): The number of steps to fast-forward by.
        """
        # Same off-by-one bookkeeping as ``draw``: on a fresh engine the first
        # point is ``self.quasi`` itself, so only ``n - 1`` kernel steps remain.
        if self.num_generated == 0:
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self

    def _scramble(self):
        # Seeded generator when a seed was given; otherwise the global RNG.
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        cpu = torch.device("cpu")

        # Generate shift vector: MAXBIT random bits per dimension, packed into
        # one integer per dimension (value in [0, 2**MAXBIT)).
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Generate lower triangular matrices (stacked across dimensions)
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()

        # Applies the random matrices to the direction-number table in place.
        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"