xref: /aosp_15_r20/external/pytorch/torch/quasirandom.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Optional
3
4import torch
5
6
class SobolEngine:
    r"""
    The :class:`torch.quasirandom.SobolEngine` is an engine for generating
    (scrambled) Sobol sequences. Sobol sequences are an example of low
    discrepancy quasi-random sequences.

    This implementation of an engine for Sobol sequences is capable of
    sampling sequences up to a maximum dimension of 21201. It uses direction
    numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
    search criterion D(6) up to the dimension 21201. This is the recommended
    choice by the authors.

    References:
      - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
        Journal of Complexity, 14(4):466-489, December 1998.

      - I. M. Sobol. The distribution of points in a cube and the accurate
        evaluation of integrals.
        Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.

    Args:
        dimension (Int): The dimensionality of the sequence to be drawn.
                         Must lie in ``[1, 21201]``.
        scramble (bool, optional): Setting this to ``True`` will produce
                                   scrambled Sobol sequences. Scrambling is
                                   capable of producing better Sobol
                                   sequences. Default: ``False``.
        seed (Int, optional): This is the seed for the scrambling. If
                              specified, a dedicated random number generator
                              seeded with this value is used; otherwise the
                              global default generator is used.
                              Default: ``None``

    Raises:
        ValueError: if ``dimension`` is outside ``[1, 21201]``.

    Examples::

        >>> # xdoctest: +SKIP("unseeded random state")
        >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
        >>> soboleng.draw(3)
        tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
                [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
    """

    # Points are tracked as integers and divided by 2**MAXBIT to obtain
    # floating-point samples; MAXBIT is also the width of the second axis
    # of ``sobolstate``.
    MAXBIT = 30
    # Largest ``dimension`` accepted by ``__init__``.
    MAXDIM = 21201

    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )

        self.seed = seed
        self.scramble = scramble
        self.dimension = dimension

        cpu = torch.device("cpu")

        # (dimension, MAXBIT) integer table filled by the native initializer;
        # presumably the direction numbers referenced in the class docstring.
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

        if not self.scramble:
            self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
        else:
            # Sets ``self.shift`` and passes ``self.sobolstate`` through the
            # native scrambling op.
            self._scramble()

        # ``quasi`` is the running integer state; it starts at ``shift`` and
        # ``reset()`` restores it from there.
        self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
        # The first point is cached so that ``draw`` can serve it directly;
        # the native kernel is only asked for the points after it.
        self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
        self.num_generated = 0

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(n, dimension)`.

        Args:
            n (Int, optional): The length of sequence of points to draw.
                               ``n=0`` returns an empty ``(0, dimension)``
                               tensor and leaves the engine state unchanged.
                               Default: 1
            out (Tensor, optional): The output tensor. If given, it is resized
                                    to the result's shape, filled with the
                                    result and returned.
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``, meaning
                                                    ``torch.get_default_dtype()``.
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        if self.num_generated == 0:
            if n in (0, 1):
                # Serve at most the cached first point. Slicing with [:n]
                # covers n == 0 (which would otherwise ask the kernel for a
                # negative count), and copy=True guarantees the caller never
                # receives an alias of ``_first_point``: ``Tensor.to`` returns
                # the same tensor when the dtype already matches, so in-place
                # edits by the caller would corrupt every draw after reset().
                result = self._first_point[:n].to(dtype=dtype, copy=True)
            else:
                # The first point is cached, so the kernel only generates the
                # remaining n - 1 points.
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                result = torch.cat((self._first_point.to(dtype), result), dim=-2)
        else:
            # The cached first point was counted in num_generated without a
            # kernel step, hence the num_generated - 1 offset.
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )

        self.num_generated += n

        if out is not None:
            out.resize_as_(result).copy_(result)
            return out

        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(2**m, dimension)`.

        Args:
            m (Int): The (base2) exponent of the number of points to draw.
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``

        Raises:
            ValueError: if the total number of points generated so far plus
                ``2**m`` is not a power of two.
        """
        n = 2**m
        total_n = self.num_generated + n
        # Python's `&` binds tighter than `==`, so this is (x & (x-1)) == 0,
        # i.e. "total_n is a power of two" for positive total_n.
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
                "'SobolEngine.draw()' instead."
            )
        return self.draw(n=n, out=out, dtype=dtype)

    def reset(self):
        r"""
        Function to reset the ``SobolEngine`` to base state.

        The running state is restored from the shift vector chosen at
        construction (no re-scrambling happens). Returns ``self``.
        """
        self.quasi.copy_(self.shift)
        self.num_generated = 0
        return self

    def fast_forward(self, n):
        r"""
        Function to fast-forward the state of the ``SobolEngine`` by
        :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
        without using the samples.

        Args:
            n (Int): The number of steps to fast-forward by.

        Returns:
            The engine itself, to allow chaining.
        """
        if self.num_generated == 0:
            # Mirrors ``draw``: the cached first point needs no kernel step.
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self

    def _scramble(self):
        # Build a private generator only when a seed was given; otherwise the
        # global default generator is used (generator=None).
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        cpu = torch.device("cpu")

        # Shift vector: MAXBIT random bits per dimension, packed into one
        # integer per dimension via a dot product with powers of two.
        # NOTE: the draw order (shift first, then matrices) determines the
        # output for a given seed; keep it stable.
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Random 0/1 lower triangular matrices, one per dimension.
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()

        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
217        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
218