1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_
16 #define ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_
17
18 #include <cassert>
19 #include <cstdint>
20 #include <istream>
21 #include <ostream>
22
23 #include "absl/base/config.h"
24 #include "absl/base/optimization.h"
25 #include "absl/random/internal/fast_uniform_bits.h"
26 #include "absl/random/internal/iostream_state_saver.h"
27
28 namespace absl {
29 ABSL_NAMESPACE_BEGIN
30
// absl::bernoulli_distribution is a drop-in replacement for
// std::bernoulli_distribution. It guarantees that (given a perfect
// UniformRandomBitGenerator) the acceptance probability is *exactly* equal to
// the given double.
//
// The implementation assumes that `double` uses the IEEE 754 format.
class bernoulli_distribution {
 public:
  using result_type = bool;

  // Parameter set of the distribution: the single success probability `p`.
  // Valid values lie in the closed interval [0, 1]; this is checked only by
  // `assert`, so it is not enforced in NDEBUG builds.
  class param_type {
   public:
    using distribution_type = bernoulli_distribution;

    explicit param_type(double p = 0.5) : prob_(p) {
      assert(p >= 0.0 && p <= 1.0);
    }

    // Success probability.
    double p() const { return prob_; }

    friend bool operator==(const param_type& lhs, const param_type& rhs) {
      return lhs.prob_ == rhs.prob_;
    }
    friend bool operator!=(const param_type& lhs, const param_type& rhs) {
      return !(lhs == rhs);
    }

   private:
    double prob_;
  };

  // A fair coin: p == 0.5.
  bernoulli_distribution() : param_(0.5) {}

  explicit bernoulli_distribution(double probability) : param_(probability) {}

  explicit bernoulli_distribution(param_type params) : param_(params) {}

  // The distribution keeps no state between draws, so there is nothing to
  // reset.
  void reset() {}

  // Performs one trial with the stored parameters.
  template <typename URBG>
  bool operator()(URBG& g) {  // NOLINT(runtime/references)
    return (*this)(g, param_);
  }

  // Performs one trial with `param`; the stored parameters are not consulted.
  template <typename URBG>
  bool operator()(URBG& g,  // NOLINT(runtime/references)
                  const param_type& param) {
    return Generate(param.p(), g);
  }

  param_type param() const { return param_; }
  void param(const param_type& param) { param_ = param; }

  double p() const { return param_.p(); }

  // Parenthesized names guard against function-like `min`/`max` macros.
  result_type(min)() const { return false; }
  result_type(max)() const { return true; }

  friend bool operator==(const bernoulli_distribution& d1,
                         const bernoulli_distribution& d2) {
    return d1.param_ == d2.param_;
  }

  friend bool operator!=(const bernoulli_distribution& d1,
                         const bernoulli_distribution& d2) {
    return !(d1 == d2);
  }

 private:
  // 2^32: scale factor mapping a probability onto the range of a uint32_t
  // draw. Kept in 64 bits because 2^32 itself does not fit in 32 bits.
  static constexpr uint64_t kP32 = uint64_t{1} << 32;

  // Returns true with probability `p`, drawing randomness from `g`.
  template <typename URBG>
  static bool Generate(double p, URBG& g);  // NOLINT(runtime/references)

  param_type param_;
};
108
109 template <typename CharT, typename Traits>
110 std::basic_ostream<CharT, Traits>& operator<<(
111 std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references)
112 const bernoulli_distribution& x) {
113 auto saver = random_internal::make_ostream_state_saver(os);
114 os.precision(random_internal::stream_precision_helper<double>::kPrecision);
115 os << x.p();
116 return os;
117 }
118
119 template <typename CharT, typename Traits>
120 std::basic_istream<CharT, Traits>& operator>>(
121 std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references)
122 bernoulli_distribution& x) { // NOLINT(runtime/references)
123 auto saver = random_internal::make_istream_state_saver(is);
124 auto p = random_internal::read_floating_point<double>(is);
125 if (!is.fail()) {
126 x.param(bernoulli_distribution::param_type(p));
127 }
128 return is;
129 }
130
// Returns true with probability exactly `p`, assuming `g` yields uniformly
// distributed bits. Each loop iteration draws one uint32_t `v` through
// FastUniformBits and compares it against c = floor(p * 2^32). The loop
// continues only when `v == c`; it then recurses on the residual probability.
//
// Expects `p` in [0, 1] (the bound that param_type asserts); the reasoning in
// the comments below relies on that range.
template <typename URBG>
bool bernoulli_distribution::Generate(double p,
                                      URBG& g) {  // NOLINT(runtime/references)
  random_internal::FastUniformBits<uint32_t> fast_u32;

  while (true) {
    // There are two aspects of the definition of `c` below that are worth
    // commenting on. First, because `p` is in the range [0, 1], `c` is in the
    // range [0, 2^32] which does not fit in a uint32_t and therefore requires
    // 64 bits.
    //
    // Second, `c` is constructed by first casting explicitly to a signed
    // integer and then casting explicitly to an unsigned integer of the same
    // size. This is done because the hardware conversion instructions produce
    // signed integers from double; if taken as a uint64_t the conversion would
    // be wrong for doubles greater than 2^63 (not relevant in this use-case).
    // If converted directly to an unsigned integer, the compiler would end up
    // emitting code to handle such large values that are not relevant due to
    // the known bounds on `c`. To avoid these extra instructions this
    // implementation converts first to the signed type and then to the
    // unsigned type (which is a no-op).
    const uint64_t c = static_cast<uint64_t>(static_cast<int64_t>(p * kP32));
    const uint32_t v = fast_u32(g);
    // FAST PATH: this path fails with probability 1/2^32. Note that simply
    // returning `v < c` would give acceptance probability q = c / 2^32, which
    // is within 1/2^32 of `p`. The slow path (taken only when `v == c`)
    // eliminates the remaining error.
    if (ABSL_PREDICT_TRUE(v != c)) return v < c;

    // It is guaranteed that `q` is strictly less than 1, because if `q` were
    // greater than or equal to 1, the same would be true for `p`. Certainly `p`
    // cannot be greater than 1, and if `p == 1`, then the fast path would
    // necessarily have been taken already.
    const double q = static_cast<double>(c) / kP32;

    // The probability of acceptance on the fast path is `q` and so the
    // probability of acceptance here should be `p - q`.
    //
    // Note that `q` is obtained from `p` via some shifts and conversions, the
    // upshot of which is that `q` is simply `p` with some of the
    // least-significant bits of its mantissa set to zero. This means that the
    // difference `p - q` will not have any rounding errors. To see why, pretend
    // that double has 10 bits of resolution and q is obtained from `p` in such
    // a way that the 4 least-significant bits of its mantissa are set to zero.
    // For example:
    //    p   = 1.1100111011 * 2^-1
    //    q   = 1.1100110000 * 2^-1
    //  p - q = 1.011        * 2^-8
    // The difference `p - q` has exactly the nonzero mantissa bits that were
    // "lost" in `q` producing a number which is certainly representable in a
    // double.
    const double left = p - q;

    // By construction, the probability of being on this slow path is 1/2^32.
    // Since
    //   P(accept and in slow path) = P(accept | in slow path) * P(slow path)
    // must equal `left`, the acceptance probability conditional on reaching
    // this path is `left * kP32`. Because c = floor(p * 2^32), `left` is in
    // [0, 1/2^32), so `here` is in [0, 1).
    const double here = left * kP32;

    // The simplest way to compute the result of this trial is to repeat the
    // whole algorithm with the new probability. This terminates because even
    // given arbitrarily unfriendly "random" bits, each iteration either
    // multiplies a tiny probability by 2^32 (if c == 0) or strips off some
    // number of nonzero mantissa bits. That process is bounded.
    // `here == 0` means `p` was exactly `q`: nothing remains to accept.
    if (here == 0) return false;
    p = here;
  }
}
198
199 ABSL_NAMESPACE_END
200 } // namespace absl
201
202 #endif // ABSL_RANDOM_BERNOULLI_DISTRIBUTION_H_
203