xref: /aosp_15_r20/external/abseil-cpp/absl/random/internal/fast_uniform_bits.h (revision 9356374a3709195abf420251b3e825997ff56c0f)
1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
16 #define ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
17 
18 #include <cstddef>
19 #include <cstdint>
20 #include <limits>
21 #include <type_traits>
22 
23 #include "absl/base/config.h"
24 #include "absl/meta/type_traits.h"
25 #include "absl/random/internal/traits.h"
26 
27 namespace absl {
28 ABSL_NAMESPACE_BEGIN
29 namespace random_internal {
// Returns true if `n` is zero or an exact power of two. FastUniformBits uses
// this to decide whether a URBG's range size corresponds to a whole number of
// output bits.
template <typename UIntType>
constexpr bool IsPowerOfTwoOrZero(UIntType n) {
  // A nonzero value is a power of two exactly when clearing its lowest set bit
  // (`n & (n - 1)`) leaves nothing; only nonzero values with a second set bit
  // are rejected.
  return !(n != 0 && (n & (n - 1)) != 0);
}
36 
// Returns how many distinct values the URBG can produce, i.e. `max - min + 1`.
// When the URBG covers every value representable by URBG::result_type, that
// count does not fit in result_type, so 0 is returned in its place.
template <typename URBG>
constexpr typename URBG::result_type RangeSize() {
  using result_type = typename URBG::result_type;
  static_assert((URBG::max)() != (URBG::min)(), "URBG range cannot be 0.");
  // Single-return form keeps this usable as a C++11 constexpr function.
  return ((URBG::min)() != std::numeric_limits<result_type>::lowest() ||
          (URBG::max)() != (std::numeric_limits<result_type>::max)())
             ? ((URBG::max)() - (URBG::min)() + result_type{1})
             : result_type{0};
}
48 }
49 
// Returns floor(log2(n)), i.e. the position of the highest set bit of `n`.
// Both 0 and 1 yield 0.
template <typename UIntType>
constexpr UIntType IntegerLog2(UIntType n) {
  // Each halving step contributes one to the result until n drops to 1 or 0.
  return (n > 1) ? IntegerLog2(n >> 1) + 1 : 0;
}
55 
56 // Returns the number of bits of randomness returned through
57 // `PowerOfTwoVariate(urbg)`.
58 template <typename URBG>
NumBits()59 constexpr size_t NumBits() {
60   return static_cast<size_t>(
61       RangeSize<URBG>() == 0
62           ? std::numeric_limits<typename URBG::result_type>::digits
63           : IntegerLog2(RangeSize<URBG>()));
64 }
65 
// Returns a UIntType mask with exactly the low `n` bits set, e.g.
// MaskFromShift<uint8_t>(3) == 0x07.
//
// A shift of 0, or any shift at least as wide as UIntType, yields a mask with
// every bit set. Clamping wide shifts avoids an out-of-range `<<`, which would
// be undefined behavior (or a constexpr evaluation error).
template <typename UIntType>
constexpr UIntType MaskFromShift(size_t n) {
  return (n == 0 ||
          n >= static_cast<size_t>(std::numeric_limits<UIntType>::digits))
             ? static_cast<UIntType>(~UIntType{0})
             // n is in (0, digits) here, so the shift is always in range.
             : static_cast<UIntType>((UIntType{1} << n) - UIntType{1});
}
74 
// Dispatch tag for FastUniformBits::Generate: selects the direct
// shift-and-accumulate loop, used when the URBG range size is zero or a power
// of two.
struct SimplifiedLoopTag {};

// Dispatch tag for FastUniformBits::Generate: selects the rejection-sampling
// loop, used for every other URBG range size.
struct RejectionLoopTag {};
79 
80 // FastUniformBits implements a fast path to acquire uniform independent bits
81 // from a type which conforms to the [rand.req.urbg] concept.
82 // Parameterized by:
83 //  `UIntType`: the result (output) type
84 //
85 // The std::independent_bits_engine [rand.adapt.ibits] adaptor can be
86 // instantiated from an existing generator through a copy or a move. It does
87 // not, however, facilitate the production of pseudorandom bits from an un-owned
88 // generator that will outlive the std::independent_bits_engine instance.
89 template <typename UIntType = uint64_t>
90 class FastUniformBits {
91  public:
92   using result_type = UIntType;
93 
result_type(min)94   static constexpr result_type(min)() { return 0; }
result_type(max)95   static constexpr result_type(max)() {
96     return (std::numeric_limits<result_type>::max)();
97   }
98 
99   template <typename URBG>
100   result_type operator()(URBG& g);  // NOLINT(runtime/references)
101 
102  private:
103   static_assert(IsUnsigned<UIntType>::value,
104                 "Class-template FastUniformBits<> must be parameterized using "
105                 "an unsigned type.");
106 
107   // Generate() generates a random value, dispatched on whether
108   // the underlying URBG must use rejection sampling to generate a value,
109   // or whether a simplified loop will suffice.
110   template <typename URBG>
111   result_type Generate(URBG& g,  // NOLINT(runtime/references)
112                        SimplifiedLoopTag);
113 
114   template <typename URBG>
115   result_type Generate(URBG& g,  // NOLINT(runtime/references)
116                        RejectionLoopTag);
117 };
118 
119 template <typename UIntType>
120 template <typename URBG>
121 typename FastUniformBits<UIntType>::result_type
operator()122 FastUniformBits<UIntType>::operator()(URBG& g) {  // NOLINT(runtime/references)
123   // kRangeMask is the mask used when sampling variates from the URBG when the
124   // width of the URBG range is not a power of 2.
125   // Y = (2 ^ kRange) - 1
126   static_assert((URBG::max)() > (URBG::min)(),
127                 "URBG::max and URBG::min may not be equal.");
128 
129   using tag = absl::conditional_t<IsPowerOfTwoOrZero(RangeSize<URBG>()),
130                                   SimplifiedLoopTag, RejectionLoopTag>;
131   return Generate(g, tag{});
132 }
133 
134 template <typename UIntType>
135 template <typename URBG>
136 typename FastUniformBits<UIntType>::result_type
Generate(URBG & g,SimplifiedLoopTag)137 FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
138                                     SimplifiedLoopTag) {
139   // The simplified version of FastUniformBits works only on URBGs that have
140   // a range that is a power of 2. In this case we simply loop and shift without
141   // attempting to balance the bits across calls.
142   static_assert(IsPowerOfTwoOrZero(RangeSize<URBG>()),
143                 "incorrect Generate tag for URBG instance");
144 
145   static constexpr size_t kResultBits =
146       std::numeric_limits<result_type>::digits;
147   static constexpr size_t kUrbgBits = NumBits<URBG>();
148   static constexpr size_t kIters =
149       (kResultBits / kUrbgBits) + (kResultBits % kUrbgBits != 0);
150   static constexpr size_t kShift = (kIters == 1) ? 0 : kUrbgBits;
151   static constexpr auto kMin = (URBG::min)();
152 
153   result_type r = static_cast<result_type>(g() - kMin);
154   for (size_t n = 1; n < kIters; ++n) {
155     r = static_cast<result_type>(r << kShift) +
156         static_cast<result_type>(g() - kMin);
157   }
158   return r;
159 }
160 
161 template <typename UIntType>
162 template <typename URBG>
163 typename FastUniformBits<UIntType>::result_type
Generate(URBG & g,RejectionLoopTag)164 FastUniformBits<UIntType>::Generate(URBG& g,  // NOLINT(runtime/references)
165                                     RejectionLoopTag) {
166   static_assert(!IsPowerOfTwoOrZero(RangeSize<URBG>()),
167                 "incorrect Generate tag for URBG instance");
168   using urbg_result_type = typename URBG::result_type;
169 
170   // See [rand.adapt.ibits] for more details on the constants calculated below.
171   //
172   // It is preferable to use roughly the same number of bits from each generator
173   // call, however this is only possible when the number of bits provided by the
174   // URBG is a divisor of the number of bits in `result_type`. In all other
175   // cases, the number of bits used cannot always be the same, but it can be
176   // guaranteed to be off by at most 1. Thus we run two loops, one with a
177   // smaller bit-width size (`kSmallWidth`) and one with a larger width size
178   // (satisfying `kLargeWidth == kSmallWidth + 1`). The loops are run
179   // `kSmallIters` and `kLargeIters` times respectively such
180   // that
181   //
182   //    `kResultBits == kSmallIters * kSmallBits
183   //                    + kLargeIters * kLargeBits`
184   //
185   // where `kResultBits` is the total number of bits in `result_type`.
186   //
187   static constexpr size_t kResultBits =
188       std::numeric_limits<result_type>::digits;                      // w
189   static constexpr urbg_result_type kUrbgRange = RangeSize<URBG>();  // R
190   static constexpr size_t kUrbgBits = NumBits<URBG>();               // m
191 
192   // compute the initial estimate of the bits used.
193   // [rand.adapt.ibits] 2 (c)
194   static constexpr size_t kA =  // ceil(w/m)
195       (kResultBits / kUrbgBits) + ((kResultBits % kUrbgBits) != 0);  // n'
196 
197   static constexpr size_t kABits = kResultBits / kA;  // w0'
198   static constexpr urbg_result_type kARejection =
199       ((kUrbgRange >> kABits) << kABits);  // y0'
200 
201   // refine the selection to reduce the rejection frequency.
202   static constexpr size_t kTotalIters =
203       ((kUrbgRange - kARejection) <= (kARejection / kA)) ? kA : (kA + 1);  // n
204 
205   // [rand.adapt.ibits] 2 (b)
206   static constexpr size_t kSmallIters =
207       kTotalIters - (kResultBits % kTotalIters);                   // n0
208   static constexpr size_t kSmallBits = kResultBits / kTotalIters;  // w0
209   static constexpr urbg_result_type kSmallRejection =
210       ((kUrbgRange >> kSmallBits) << kSmallBits);  // y0
211 
212   static constexpr size_t kLargeBits = kSmallBits + 1;  // w0+1
213   static constexpr urbg_result_type kLargeRejection =
214       ((kUrbgRange >> kLargeBits) << kLargeBits);  // y1
215 
216   //
217   // Because `kLargeBits == kSmallBits + 1`, it follows that
218   //
219   //     `kResultBits == kSmallIters * kSmallBits + kLargeIters`
220   //
221   // and therefore
222   //
223   //     `kLargeIters == kTotalWidth % kSmallWidth`
224   //
225   // Intuitively, each iteration with the large width accounts for one unit
226   // of the remainder when `kTotalWidth` is divided by `kSmallWidth`. As
227   // mentioned above, if the URBG width is a divisor of `kTotalWidth`, then
228   // there would be no need for any large iterations (i.e., one loop would
229   // suffice), and indeed, in this case, `kLargeIters` would be zero.
230   static_assert(kResultBits == kSmallIters * kSmallBits +
231                                    (kTotalIters - kSmallIters) * kLargeBits,
232                 "Error in looping constant calculations.");
233 
234   // The small shift is essentially small bits, but due to the potential
235   // of generating a smaller result_type from a larger urbg type, the actual
236   // shift might be 0.
237   static constexpr size_t kSmallShift = kSmallBits % kResultBits;
238   static constexpr auto kSmallMask =
239       MaskFromShift<urbg_result_type>(kSmallShift);
240   static constexpr size_t kLargeShift = kLargeBits % kResultBits;
241   static constexpr auto kLargeMask =
242       MaskFromShift<urbg_result_type>(kLargeShift);
243 
244   static constexpr auto kMin = (URBG::min)();
245 
246   result_type s = 0;
247   for (size_t n = 0; n < kSmallIters; ++n) {
248     urbg_result_type v;
249     do {
250       v = g() - kMin;
251     } while (v >= kSmallRejection);
252 
253     s = (s << kSmallShift) + static_cast<result_type>(v & kSmallMask);
254   }
255 
256   for (size_t n = kSmallIters; n < kTotalIters; ++n) {
257     urbg_result_type v;
258     do {
259       v = g() - kMin;
260     } while (v >= kLargeRejection);
261 
262     s = (s << kLargeShift) + static_cast<result_type>(v & kLargeMask);
263   }
264   return s;
265 }
266 
267 }  // namespace random_internal
268 ABSL_NAMESPACE_END
269 }  // namespace absl
270 
271 #endif  // ABSL_RANDOM_INTERNAL_FAST_UNIFORM_BITS_H_
272