1*9356374aSAndroid Build Coastguard Worker // Copyright 2017 The Abseil Authors.
2*9356374aSAndroid Build Coastguard Worker //
3*9356374aSAndroid Build Coastguard Worker // Licensed under the Apache License, Version 2.0 (the "License");
4*9356374aSAndroid Build Coastguard Worker // you may not use this file except in compliance with the License.
5*9356374aSAndroid Build Coastguard Worker // You may obtain a copy of the License at
6*9356374aSAndroid Build Coastguard Worker //
7*9356374aSAndroid Build Coastguard Worker // https://www.apache.org/licenses/LICENSE-2.0
8*9356374aSAndroid Build Coastguard Worker //
9*9356374aSAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
10*9356374aSAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS,
11*9356374aSAndroid Build Coastguard Worker // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*9356374aSAndroid Build Coastguard Worker // See the License for the specific language governing permissions and
13*9356374aSAndroid Build Coastguard Worker // limitations under the License.
14*9356374aSAndroid Build Coastguard Worker
15*9356374aSAndroid Build Coastguard Worker #ifndef ABSL_RANDOM_BETA_DISTRIBUTION_H_
16*9356374aSAndroid Build Coastguard Worker #define ABSL_RANDOM_BETA_DISTRIBUTION_H_
17*9356374aSAndroid Build Coastguard Worker
#include <cassert>
#include <cmath>
#include <cstdint>
#include <istream>
#include <limits>
#include <ostream>
#include <type_traits>

#include "absl/meta/type_traits.h"
#include "absl/random/internal/fast_uniform_bits.h"
#include "absl/random/internal/fastmath.h"
#include "absl/random/internal/generate_real.h"
#include "absl/random/internal/iostream_state_saver.h"
30*9356374aSAndroid Build Coastguard Worker
31*9356374aSAndroid Build Coastguard Worker namespace absl {
32*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_BEGIN
33*9356374aSAndroid Build Coastguard Worker
34*9356374aSAndroid Build Coastguard Worker // absl::beta_distribution:
35*9356374aSAndroid Build Coastguard Worker // Generate a floating-point variate conforming to a Beta distribution:
36*9356374aSAndroid Build Coastguard Worker // pdf(x) \propto x^(alpha-1) * (1-x)^(beta-1),
37*9356374aSAndroid Build Coastguard Worker // where the params alpha and beta are both strictly positive real values.
38*9356374aSAndroid Build Coastguard Worker //
39*9356374aSAndroid Build Coastguard Worker // The support is the open interval (0, 1), but the return value might be equal
40*9356374aSAndroid Build Coastguard Worker // to 0 or 1, due to numerical errors when alpha and beta are very different.
41*9356374aSAndroid Build Coastguard Worker //
42*9356374aSAndroid Build Coastguard Worker // Usage note: One usage is that alpha and beta are counts of number of
43*9356374aSAndroid Build Coastguard Worker // successes and failures. When the total number of trials are large, consider
44*9356374aSAndroid Build Coastguard Worker // approximating a beta distribution with a Gaussian distribution with the same
45*9356374aSAndroid Build Coastguard Worker // mean and variance. One could use the skewness, which depends only on the
46*9356374aSAndroid Build Coastguard Worker // smaller of alpha and beta when the number of trials are sufficiently large,
47*9356374aSAndroid Build Coastguard Worker // to quantify how far a beta distribution is from the normal distribution.
48*9356374aSAndroid Build Coastguard Worker template <typename RealType = double>
49*9356374aSAndroid Build Coastguard Worker class beta_distribution {
50*9356374aSAndroid Build Coastguard Worker public:
51*9356374aSAndroid Build Coastguard Worker using result_type = RealType;
52*9356374aSAndroid Build Coastguard Worker
53*9356374aSAndroid Build Coastguard Worker class param_type {
54*9356374aSAndroid Build Coastguard Worker public:
55*9356374aSAndroid Build Coastguard Worker using distribution_type = beta_distribution;
56*9356374aSAndroid Build Coastguard Worker
param_type(result_type alpha,result_type beta)57*9356374aSAndroid Build Coastguard Worker explicit param_type(result_type alpha, result_type beta)
58*9356374aSAndroid Build Coastguard Worker : alpha_(alpha), beta_(beta) {
59*9356374aSAndroid Build Coastguard Worker assert(alpha >= 0);
60*9356374aSAndroid Build Coastguard Worker assert(beta >= 0);
61*9356374aSAndroid Build Coastguard Worker assert(alpha <= (std::numeric_limits<result_type>::max)());
62*9356374aSAndroid Build Coastguard Worker assert(beta <= (std::numeric_limits<result_type>::max)());
63*9356374aSAndroid Build Coastguard Worker if (alpha == 0 || beta == 0) {
64*9356374aSAndroid Build Coastguard Worker method_ = DEGENERATE_SMALL;
65*9356374aSAndroid Build Coastguard Worker x_ = (alpha >= beta) ? 1 : 0;
66*9356374aSAndroid Build Coastguard Worker return;
67*9356374aSAndroid Build Coastguard Worker }
68*9356374aSAndroid Build Coastguard Worker // a_ = min(beta, alpha), b_ = max(beta, alpha).
69*9356374aSAndroid Build Coastguard Worker if (beta < alpha) {
70*9356374aSAndroid Build Coastguard Worker inverted_ = true;
71*9356374aSAndroid Build Coastguard Worker a_ = beta;
72*9356374aSAndroid Build Coastguard Worker b_ = alpha;
73*9356374aSAndroid Build Coastguard Worker } else {
74*9356374aSAndroid Build Coastguard Worker inverted_ = false;
75*9356374aSAndroid Build Coastguard Worker a_ = alpha;
76*9356374aSAndroid Build Coastguard Worker b_ = beta;
77*9356374aSAndroid Build Coastguard Worker }
78*9356374aSAndroid Build Coastguard Worker if (a_ <= 1 && b_ >= ThresholdForLargeA()) {
79*9356374aSAndroid Build Coastguard Worker method_ = DEGENERATE_SMALL;
80*9356374aSAndroid Build Coastguard Worker x_ = inverted_ ? result_type(1) : result_type(0);
81*9356374aSAndroid Build Coastguard Worker return;
82*9356374aSAndroid Build Coastguard Worker }
83*9356374aSAndroid Build Coastguard Worker // For threshold values, see also:
84*9356374aSAndroid Build Coastguard Worker // Evaluation of Beta Generation Algorithms, Ying-Chao Hung, et. al.
85*9356374aSAndroid Build Coastguard Worker // February, 2009.
86*9356374aSAndroid Build Coastguard Worker if ((b_ < 1.0 && a_ + b_ <= 1.2) || a_ <= ThresholdForSmallA()) {
87*9356374aSAndroid Build Coastguard Worker // Choose Joehnk over Cheng when it's faster or when Cheng encounters
88*9356374aSAndroid Build Coastguard Worker // numerical issues.
89*9356374aSAndroid Build Coastguard Worker method_ = JOEHNK;
90*9356374aSAndroid Build Coastguard Worker a_ = result_type(1) / alpha_;
91*9356374aSAndroid Build Coastguard Worker b_ = result_type(1) / beta_;
92*9356374aSAndroid Build Coastguard Worker if (std::isinf(a_) || std::isinf(b_)) {
93*9356374aSAndroid Build Coastguard Worker method_ = DEGENERATE_SMALL;
94*9356374aSAndroid Build Coastguard Worker x_ = inverted_ ? result_type(1) : result_type(0);
95*9356374aSAndroid Build Coastguard Worker }
96*9356374aSAndroid Build Coastguard Worker return;
97*9356374aSAndroid Build Coastguard Worker }
98*9356374aSAndroid Build Coastguard Worker if (a_ >= ThresholdForLargeA()) {
99*9356374aSAndroid Build Coastguard Worker method_ = DEGENERATE_LARGE;
100*9356374aSAndroid Build Coastguard Worker // Note: on PPC for long double, evaluating
101*9356374aSAndroid Build Coastguard Worker // `std::numeric_limits::max() / ThresholdForLargeA` results in NaN.
102*9356374aSAndroid Build Coastguard Worker result_type r = a_ / b_;
103*9356374aSAndroid Build Coastguard Worker x_ = (inverted_ ? result_type(1) : r) / (1 + r);
104*9356374aSAndroid Build Coastguard Worker return;
105*9356374aSAndroid Build Coastguard Worker }
106*9356374aSAndroid Build Coastguard Worker x_ = a_ + b_;
107*9356374aSAndroid Build Coastguard Worker log_x_ = std::log(x_);
108*9356374aSAndroid Build Coastguard Worker if (a_ <= 1) {
109*9356374aSAndroid Build Coastguard Worker method_ = CHENG_BA;
110*9356374aSAndroid Build Coastguard Worker y_ = result_type(1) / a_;
111*9356374aSAndroid Build Coastguard Worker gamma_ = a_ + a_;
112*9356374aSAndroid Build Coastguard Worker return;
113*9356374aSAndroid Build Coastguard Worker }
114*9356374aSAndroid Build Coastguard Worker method_ = CHENG_BB;
115*9356374aSAndroid Build Coastguard Worker result_type r = (a_ - 1) / (b_ - 1);
116*9356374aSAndroid Build Coastguard Worker y_ = std::sqrt((1 + r) / (b_ * r * 2 - r + 1));
117*9356374aSAndroid Build Coastguard Worker gamma_ = a_ + result_type(1) / y_;
118*9356374aSAndroid Build Coastguard Worker }
119*9356374aSAndroid Build Coastguard Worker
alpha()120*9356374aSAndroid Build Coastguard Worker result_type alpha() const { return alpha_; }
beta()121*9356374aSAndroid Build Coastguard Worker result_type beta() const { return beta_; }
122*9356374aSAndroid Build Coastguard Worker
123*9356374aSAndroid Build Coastguard Worker friend bool operator==(const param_type& a, const param_type& b) {
124*9356374aSAndroid Build Coastguard Worker return a.alpha_ == b.alpha_ && a.beta_ == b.beta_;
125*9356374aSAndroid Build Coastguard Worker }
126*9356374aSAndroid Build Coastguard Worker
127*9356374aSAndroid Build Coastguard Worker friend bool operator!=(const param_type& a, const param_type& b) {
128*9356374aSAndroid Build Coastguard Worker return !(a == b);
129*9356374aSAndroid Build Coastguard Worker }
130*9356374aSAndroid Build Coastguard Worker
131*9356374aSAndroid Build Coastguard Worker private:
132*9356374aSAndroid Build Coastguard Worker friend class beta_distribution;
133*9356374aSAndroid Build Coastguard Worker
134*9356374aSAndroid Build Coastguard Worker #ifdef _MSC_VER
135*9356374aSAndroid Build Coastguard Worker // MSVC does not have constexpr implementations for std::log and std::exp
136*9356374aSAndroid Build Coastguard Worker // so they are computed at runtime.
137*9356374aSAndroid Build Coastguard Worker #define ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR
138*9356374aSAndroid Build Coastguard Worker #else
139*9356374aSAndroid Build Coastguard Worker #define ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR constexpr
140*9356374aSAndroid Build Coastguard Worker #endif
141*9356374aSAndroid Build Coastguard Worker
142*9356374aSAndroid Build Coastguard Worker // The threshold for whether std::exp(1/a) is finite.
143*9356374aSAndroid Build Coastguard Worker // Note that this value is quite large, and a smaller a_ is NOT abnormal.
144*9356374aSAndroid Build Coastguard Worker static ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR result_type
ThresholdForSmallA()145*9356374aSAndroid Build Coastguard Worker ThresholdForSmallA() {
146*9356374aSAndroid Build Coastguard Worker return result_type(1) /
147*9356374aSAndroid Build Coastguard Worker std::log((std::numeric_limits<result_type>::max)());
148*9356374aSAndroid Build Coastguard Worker }
149*9356374aSAndroid Build Coastguard Worker
150*9356374aSAndroid Build Coastguard Worker // The threshold for whether a * std::log(a) is finite.
151*9356374aSAndroid Build Coastguard Worker static ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR result_type
ThresholdForLargeA()152*9356374aSAndroid Build Coastguard Worker ThresholdForLargeA() {
153*9356374aSAndroid Build Coastguard Worker return std::exp(
154*9356374aSAndroid Build Coastguard Worker std::log((std::numeric_limits<result_type>::max)()) -
155*9356374aSAndroid Build Coastguard Worker std::log(std::log((std::numeric_limits<result_type>::max)())) -
156*9356374aSAndroid Build Coastguard Worker ThresholdPadding());
157*9356374aSAndroid Build Coastguard Worker }
158*9356374aSAndroid Build Coastguard Worker
159*9356374aSAndroid Build Coastguard Worker #undef ABSL_RANDOM_INTERNAL_LOG_EXP_CONSTEXPR
160*9356374aSAndroid Build Coastguard Worker
161*9356374aSAndroid Build Coastguard Worker // Pad the threshold for large A for long double on PPC. This is done via a
162*9356374aSAndroid Build Coastguard Worker // template specialization below.
ThresholdPadding()163*9356374aSAndroid Build Coastguard Worker static constexpr result_type ThresholdPadding() { return 0; }
164*9356374aSAndroid Build Coastguard Worker
165*9356374aSAndroid Build Coastguard Worker enum Method {
166*9356374aSAndroid Build Coastguard Worker JOEHNK, // Uses algorithm Joehnk
167*9356374aSAndroid Build Coastguard Worker CHENG_BA, // Uses algorithm BA in Cheng
168*9356374aSAndroid Build Coastguard Worker CHENG_BB, // Uses algorithm BB in Cheng
169*9356374aSAndroid Build Coastguard Worker
170*9356374aSAndroid Build Coastguard Worker // Note: See also:
171*9356374aSAndroid Build Coastguard Worker // Hung et al. Evaluation of beta generation algorithms. Communications
172*9356374aSAndroid Build Coastguard Worker // in Statistics-Simulation and Computation 38.4 (2009): 750-770.
173*9356374aSAndroid Build Coastguard Worker // especially:
174*9356374aSAndroid Build Coastguard Worker // Zechner, Heinz, and Ernst Stadlober. Generating beta variates via
175*9356374aSAndroid Build Coastguard Worker // patchwork rejection. Computing 50.1 (1993): 1-18.
176*9356374aSAndroid Build Coastguard Worker
177*9356374aSAndroid Build Coastguard Worker DEGENERATE_SMALL, // a_ is abnormally small.
178*9356374aSAndroid Build Coastguard Worker DEGENERATE_LARGE, // a_ is abnormally large.
179*9356374aSAndroid Build Coastguard Worker };
180*9356374aSAndroid Build Coastguard Worker
181*9356374aSAndroid Build Coastguard Worker result_type alpha_;
182*9356374aSAndroid Build Coastguard Worker result_type beta_;
183*9356374aSAndroid Build Coastguard Worker
184*9356374aSAndroid Build Coastguard Worker result_type a_{}; // the smaller of {alpha, beta}, or 1.0/alpha_ in JOEHNK
185*9356374aSAndroid Build Coastguard Worker result_type b_{}; // the larger of {alpha, beta}, or 1.0/beta_ in JOEHNK
186*9356374aSAndroid Build Coastguard Worker result_type x_{}; // alpha + beta, or the result in degenerate cases
187*9356374aSAndroid Build Coastguard Worker result_type log_x_{}; // log(x_)
188*9356374aSAndroid Build Coastguard Worker result_type y_{}; // "beta" in Cheng
189*9356374aSAndroid Build Coastguard Worker result_type gamma_{}; // "gamma" in Cheng
190*9356374aSAndroid Build Coastguard Worker
191*9356374aSAndroid Build Coastguard Worker Method method_{};
192*9356374aSAndroid Build Coastguard Worker
193*9356374aSAndroid Build Coastguard Worker // Placing this last for optimal alignment.
194*9356374aSAndroid Build Coastguard Worker // Whether alpha_ != a_, i.e. true iff alpha_ > beta_.
195*9356374aSAndroid Build Coastguard Worker bool inverted_{};
196*9356374aSAndroid Build Coastguard Worker
197*9356374aSAndroid Build Coastguard Worker static_assert(std::is_floating_point<RealType>::value,
198*9356374aSAndroid Build Coastguard Worker "Class-template absl::beta_distribution<> must be "
199*9356374aSAndroid Build Coastguard Worker "parameterized using a floating-point type.");
200*9356374aSAndroid Build Coastguard Worker };
201*9356374aSAndroid Build Coastguard Worker
beta_distribution()202*9356374aSAndroid Build Coastguard Worker beta_distribution() : beta_distribution(1) {}
203*9356374aSAndroid Build Coastguard Worker
204*9356374aSAndroid Build Coastguard Worker explicit beta_distribution(result_type alpha, result_type beta = 1)
param_(alpha,beta)205*9356374aSAndroid Build Coastguard Worker : param_(alpha, beta) {}
206*9356374aSAndroid Build Coastguard Worker
beta_distribution(const param_type & p)207*9356374aSAndroid Build Coastguard Worker explicit beta_distribution(const param_type& p) : param_(p) {}
208*9356374aSAndroid Build Coastguard Worker
reset()209*9356374aSAndroid Build Coastguard Worker void reset() {}
210*9356374aSAndroid Build Coastguard Worker
211*9356374aSAndroid Build Coastguard Worker // Generating functions
212*9356374aSAndroid Build Coastguard Worker template <typename URBG>
operator()213*9356374aSAndroid Build Coastguard Worker result_type operator()(URBG& g) { // NOLINT(runtime/references)
214*9356374aSAndroid Build Coastguard Worker return (*this)(g, param_);
215*9356374aSAndroid Build Coastguard Worker }
216*9356374aSAndroid Build Coastguard Worker
217*9356374aSAndroid Build Coastguard Worker template <typename URBG>
218*9356374aSAndroid Build Coastguard Worker result_type operator()(URBG& g, // NOLINT(runtime/references)
219*9356374aSAndroid Build Coastguard Worker const param_type& p);
220*9356374aSAndroid Build Coastguard Worker
param()221*9356374aSAndroid Build Coastguard Worker param_type param() const { return param_; }
param(const param_type & p)222*9356374aSAndroid Build Coastguard Worker void param(const param_type& p) { param_ = p; }
223*9356374aSAndroid Build Coastguard Worker
result_type(min)224*9356374aSAndroid Build Coastguard Worker result_type(min)() const { return 0; }
result_type(max)225*9356374aSAndroid Build Coastguard Worker result_type(max)() const { return 1; }
226*9356374aSAndroid Build Coastguard Worker
alpha()227*9356374aSAndroid Build Coastguard Worker result_type alpha() const { return param_.alpha(); }
beta()228*9356374aSAndroid Build Coastguard Worker result_type beta() const { return param_.beta(); }
229*9356374aSAndroid Build Coastguard Worker
230*9356374aSAndroid Build Coastguard Worker friend bool operator==(const beta_distribution& a,
231*9356374aSAndroid Build Coastguard Worker const beta_distribution& b) {
232*9356374aSAndroid Build Coastguard Worker return a.param_ == b.param_;
233*9356374aSAndroid Build Coastguard Worker }
234*9356374aSAndroid Build Coastguard Worker friend bool operator!=(const beta_distribution& a,
235*9356374aSAndroid Build Coastguard Worker const beta_distribution& b) {
236*9356374aSAndroid Build Coastguard Worker return a.param_ != b.param_;
237*9356374aSAndroid Build Coastguard Worker }
238*9356374aSAndroid Build Coastguard Worker
239*9356374aSAndroid Build Coastguard Worker private:
240*9356374aSAndroid Build Coastguard Worker template <typename URBG>
241*9356374aSAndroid Build Coastguard Worker result_type AlgorithmJoehnk(URBG& g, // NOLINT(runtime/references)
242*9356374aSAndroid Build Coastguard Worker const param_type& p);
243*9356374aSAndroid Build Coastguard Worker
244*9356374aSAndroid Build Coastguard Worker template <typename URBG>
245*9356374aSAndroid Build Coastguard Worker result_type AlgorithmCheng(URBG& g, // NOLINT(runtime/references)
246*9356374aSAndroid Build Coastguard Worker const param_type& p);
247*9356374aSAndroid Build Coastguard Worker
248*9356374aSAndroid Build Coastguard Worker template <typename URBG>
DegenerateCase(URBG & g,const param_type & p)249*9356374aSAndroid Build Coastguard Worker result_type DegenerateCase(URBG& g, // NOLINT(runtime/references)
250*9356374aSAndroid Build Coastguard Worker const param_type& p) {
251*9356374aSAndroid Build Coastguard Worker if (p.method_ == param_type::DEGENERATE_SMALL && p.alpha_ == p.beta_) {
252*9356374aSAndroid Build Coastguard Worker // Returns 0 or 1 with equal probability.
253*9356374aSAndroid Build Coastguard Worker random_internal::FastUniformBits<uint8_t> fast_u8;
254*9356374aSAndroid Build Coastguard Worker return static_cast<result_type>((fast_u8(g) & 0x10) !=
255*9356374aSAndroid Build Coastguard Worker 0); // pick any single bit.
256*9356374aSAndroid Build Coastguard Worker }
257*9356374aSAndroid Build Coastguard Worker return p.x_;
258*9356374aSAndroid Build Coastguard Worker }
259*9356374aSAndroid Build Coastguard Worker
260*9356374aSAndroid Build Coastguard Worker param_type param_;
261*9356374aSAndroid Build Coastguard Worker random_internal::FastUniformBits<uint64_t> fast_u64_;
262*9356374aSAndroid Build Coastguard Worker };
263*9356374aSAndroid Build Coastguard Worker
#if defined(__PPC__) || defined(__ppc__) || defined(__powerpc__) || \
    defined(__PPC64__) || defined(__powerpc64__)
// On PowerPC, long double needs a more conservative large-parameter cutoff:
// a padding of 10 lowers ThresholdForLargeA() by a factor of e^10 for that
// type only. Every other type keeps the generic padding of 0.
template <>
constexpr long double
beta_distribution<long double>::param_type::ThresholdPadding() {
  return 10.0L;
}
#endif  // PowerPC
273*9356374aSAndroid Build Coastguard Worker
274*9356374aSAndroid Build Coastguard Worker template <typename RealType>
275*9356374aSAndroid Build Coastguard Worker template <typename URBG>
276*9356374aSAndroid Build Coastguard Worker typename beta_distribution<RealType>::result_type
AlgorithmJoehnk(URBG & g,const param_type & p)277*9356374aSAndroid Build Coastguard Worker beta_distribution<RealType>::AlgorithmJoehnk(
278*9356374aSAndroid Build Coastguard Worker URBG& g, // NOLINT(runtime/references)
279*9356374aSAndroid Build Coastguard Worker const param_type& p) {
280*9356374aSAndroid Build Coastguard Worker using random_internal::GeneratePositiveTag;
281*9356374aSAndroid Build Coastguard Worker using random_internal::GenerateRealFromBits;
282*9356374aSAndroid Build Coastguard Worker using real_type =
283*9356374aSAndroid Build Coastguard Worker absl::conditional_t<std::is_same<RealType, float>::value, float, double>;
284*9356374aSAndroid Build Coastguard Worker
285*9356374aSAndroid Build Coastguard Worker // Based on Joehnk, M. D. Erzeugung von betaverteilten und gammaverteilten
286*9356374aSAndroid Build Coastguard Worker // Zufallszahlen. Metrika 8.1 (1964): 5-15.
287*9356374aSAndroid Build Coastguard Worker // This method is described in Knuth, Vol 2 (Third Edition), pp 134.
288*9356374aSAndroid Build Coastguard Worker
289*9356374aSAndroid Build Coastguard Worker result_type u, v, x, y, z;
290*9356374aSAndroid Build Coastguard Worker for (;;) {
291*9356374aSAndroid Build Coastguard Worker u = GenerateRealFromBits<real_type, GeneratePositiveTag, false>(
292*9356374aSAndroid Build Coastguard Worker fast_u64_(g));
293*9356374aSAndroid Build Coastguard Worker v = GenerateRealFromBits<real_type, GeneratePositiveTag, false>(
294*9356374aSAndroid Build Coastguard Worker fast_u64_(g));
295*9356374aSAndroid Build Coastguard Worker
296*9356374aSAndroid Build Coastguard Worker // Direct method. std::pow is slow for float, so rely on the optimizer to
297*9356374aSAndroid Build Coastguard Worker // remove the std::pow() path for that case.
298*9356374aSAndroid Build Coastguard Worker if (!std::is_same<float, result_type>::value) {
299*9356374aSAndroid Build Coastguard Worker x = std::pow(u, p.a_);
300*9356374aSAndroid Build Coastguard Worker y = std::pow(v, p.b_);
301*9356374aSAndroid Build Coastguard Worker z = x + y;
302*9356374aSAndroid Build Coastguard Worker if (z > 1) {
303*9356374aSAndroid Build Coastguard Worker // Reject if and only if `x + y > 1.0`
304*9356374aSAndroid Build Coastguard Worker continue;
305*9356374aSAndroid Build Coastguard Worker }
306*9356374aSAndroid Build Coastguard Worker if (z > 0) {
307*9356374aSAndroid Build Coastguard Worker // When both alpha and beta are small, x and y are both close to 0, so
308*9356374aSAndroid Build Coastguard Worker // divide by (x+y) directly may result in nan.
309*9356374aSAndroid Build Coastguard Worker return x / z;
310*9356374aSAndroid Build Coastguard Worker }
311*9356374aSAndroid Build Coastguard Worker }
312*9356374aSAndroid Build Coastguard Worker
313*9356374aSAndroid Build Coastguard Worker // Log transform.
314*9356374aSAndroid Build Coastguard Worker // x = log( pow(u, p.a_) ), y = log( pow(v, p.b_) )
315*9356374aSAndroid Build Coastguard Worker // since u, v <= 1.0, x, y < 0.
316*9356374aSAndroid Build Coastguard Worker x = std::log(u) * p.a_;
317*9356374aSAndroid Build Coastguard Worker y = std::log(v) * p.b_;
318*9356374aSAndroid Build Coastguard Worker if (!std::isfinite(x) || !std::isfinite(y)) {
319*9356374aSAndroid Build Coastguard Worker continue;
320*9356374aSAndroid Build Coastguard Worker }
321*9356374aSAndroid Build Coastguard Worker // z = log( pow(u, a) + pow(v, b) )
322*9356374aSAndroid Build Coastguard Worker z = x > y ? (x + std::log(1 + std::exp(y - x)))
323*9356374aSAndroid Build Coastguard Worker : (y + std::log(1 + std::exp(x - y)));
324*9356374aSAndroid Build Coastguard Worker // Reject iff log(x+y) > 0.
325*9356374aSAndroid Build Coastguard Worker if (z > 0) {
326*9356374aSAndroid Build Coastguard Worker continue;
327*9356374aSAndroid Build Coastguard Worker }
328*9356374aSAndroid Build Coastguard Worker return std::exp(x - z);
329*9356374aSAndroid Build Coastguard Worker }
330*9356374aSAndroid Build Coastguard Worker }
331*9356374aSAndroid Build Coastguard Worker
332*9356374aSAndroid Build Coastguard Worker template <typename RealType>
333*9356374aSAndroid Build Coastguard Worker template <typename URBG>
334*9356374aSAndroid Build Coastguard Worker typename beta_distribution<RealType>::result_type
AlgorithmCheng(URBG & g,const param_type & p)335*9356374aSAndroid Build Coastguard Worker beta_distribution<RealType>::AlgorithmCheng(
336*9356374aSAndroid Build Coastguard Worker URBG& g, // NOLINT(runtime/references)
337*9356374aSAndroid Build Coastguard Worker const param_type& p) {
338*9356374aSAndroid Build Coastguard Worker using random_internal::GeneratePositiveTag;
339*9356374aSAndroid Build Coastguard Worker using random_internal::GenerateRealFromBits;
340*9356374aSAndroid Build Coastguard Worker using real_type =
341*9356374aSAndroid Build Coastguard Worker absl::conditional_t<std::is_same<RealType, float>::value, float, double>;
342*9356374aSAndroid Build Coastguard Worker
343*9356374aSAndroid Build Coastguard Worker // Based on Cheng, Russell CH. Generating beta variates with nonintegral
344*9356374aSAndroid Build Coastguard Worker // shape parameters. Communications of the ACM 21.4 (1978): 317-322.
345*9356374aSAndroid Build Coastguard Worker // (https://dl.acm.org/citation.cfm?id=359482).
346*9356374aSAndroid Build Coastguard Worker static constexpr result_type kLogFour =
347*9356374aSAndroid Build Coastguard Worker result_type(1.3862943611198906188344642429163531361); // log(4)
348*9356374aSAndroid Build Coastguard Worker static constexpr result_type kS =
349*9356374aSAndroid Build Coastguard Worker result_type(2.6094379124341003746007593332261876); // 1+log(5)
350*9356374aSAndroid Build Coastguard Worker
351*9356374aSAndroid Build Coastguard Worker const bool use_algorithm_ba = (p.method_ == param_type::CHENG_BA);
352*9356374aSAndroid Build Coastguard Worker result_type u1, u2, v, w, z, r, s, t, bw_inv, lhs;
353*9356374aSAndroid Build Coastguard Worker for (;;) {
354*9356374aSAndroid Build Coastguard Worker u1 = GenerateRealFromBits<real_type, GeneratePositiveTag, false>(
355*9356374aSAndroid Build Coastguard Worker fast_u64_(g));
356*9356374aSAndroid Build Coastguard Worker u2 = GenerateRealFromBits<real_type, GeneratePositiveTag, false>(
357*9356374aSAndroid Build Coastguard Worker fast_u64_(g));
358*9356374aSAndroid Build Coastguard Worker v = p.y_ * std::log(u1 / (1 - u1));
359*9356374aSAndroid Build Coastguard Worker w = p.a_ * std::exp(v);
360*9356374aSAndroid Build Coastguard Worker bw_inv = result_type(1) / (p.b_ + w);
361*9356374aSAndroid Build Coastguard Worker r = p.gamma_ * v - kLogFour;
362*9356374aSAndroid Build Coastguard Worker s = p.a_ + r - w;
363*9356374aSAndroid Build Coastguard Worker z = u1 * u1 * u2;
364*9356374aSAndroid Build Coastguard Worker if (!use_algorithm_ba && s + kS >= 5 * z) {
365*9356374aSAndroid Build Coastguard Worker break;
366*9356374aSAndroid Build Coastguard Worker }
367*9356374aSAndroid Build Coastguard Worker t = std::log(z);
368*9356374aSAndroid Build Coastguard Worker if (!use_algorithm_ba && s >= t) {
369*9356374aSAndroid Build Coastguard Worker break;
370*9356374aSAndroid Build Coastguard Worker }
371*9356374aSAndroid Build Coastguard Worker lhs = p.x_ * (p.log_x_ + std::log(bw_inv)) + r;
372*9356374aSAndroid Build Coastguard Worker if (lhs >= t) {
373*9356374aSAndroid Build Coastguard Worker break;
374*9356374aSAndroid Build Coastguard Worker }
375*9356374aSAndroid Build Coastguard Worker }
376*9356374aSAndroid Build Coastguard Worker return p.inverted_ ? (1 - w * bw_inv) : w * bw_inv;
377*9356374aSAndroid Build Coastguard Worker }
378*9356374aSAndroid Build Coastguard Worker
379*9356374aSAndroid Build Coastguard Worker template <typename RealType>
380*9356374aSAndroid Build Coastguard Worker template <typename URBG>
381*9356374aSAndroid Build Coastguard Worker typename beta_distribution<RealType>::result_type
operator()382*9356374aSAndroid Build Coastguard Worker beta_distribution<RealType>::operator()(URBG& g, // NOLINT(runtime/references)
383*9356374aSAndroid Build Coastguard Worker const param_type& p) {
384*9356374aSAndroid Build Coastguard Worker switch (p.method_) {
385*9356374aSAndroid Build Coastguard Worker case param_type::JOEHNK:
386*9356374aSAndroid Build Coastguard Worker return AlgorithmJoehnk(g, p);
387*9356374aSAndroid Build Coastguard Worker case param_type::CHENG_BA:
388*9356374aSAndroid Build Coastguard Worker ABSL_FALLTHROUGH_INTENDED;
389*9356374aSAndroid Build Coastguard Worker case param_type::CHENG_BB:
390*9356374aSAndroid Build Coastguard Worker return AlgorithmCheng(g, p);
391*9356374aSAndroid Build Coastguard Worker default:
392*9356374aSAndroid Build Coastguard Worker return DegenerateCase(g, p);
393*9356374aSAndroid Build Coastguard Worker }
394*9356374aSAndroid Build Coastguard Worker }
395*9356374aSAndroid Build Coastguard Worker
396*9356374aSAndroid Build Coastguard Worker template <typename CharT, typename Traits, typename RealType>
397*9356374aSAndroid Build Coastguard Worker std::basic_ostream<CharT, Traits>& operator<<(
398*9356374aSAndroid Build Coastguard Worker std::basic_ostream<CharT, Traits>& os, // NOLINT(runtime/references)
399*9356374aSAndroid Build Coastguard Worker const beta_distribution<RealType>& x) {
400*9356374aSAndroid Build Coastguard Worker auto saver = random_internal::make_ostream_state_saver(os);
401*9356374aSAndroid Build Coastguard Worker os.precision(random_internal::stream_precision_helper<RealType>::kPrecision);
402*9356374aSAndroid Build Coastguard Worker os << x.alpha() << os.fill() << x.beta();
403*9356374aSAndroid Build Coastguard Worker return os;
404*9356374aSAndroid Build Coastguard Worker }
405*9356374aSAndroid Build Coastguard Worker
406*9356374aSAndroid Build Coastguard Worker template <typename CharT, typename Traits, typename RealType>
407*9356374aSAndroid Build Coastguard Worker std::basic_istream<CharT, Traits>& operator>>(
408*9356374aSAndroid Build Coastguard Worker std::basic_istream<CharT, Traits>& is, // NOLINT(runtime/references)
409*9356374aSAndroid Build Coastguard Worker beta_distribution<RealType>& x) { // NOLINT(runtime/references)
410*9356374aSAndroid Build Coastguard Worker using result_type = typename beta_distribution<RealType>::result_type;
411*9356374aSAndroid Build Coastguard Worker using param_type = typename beta_distribution<RealType>::param_type;
412*9356374aSAndroid Build Coastguard Worker result_type alpha, beta;
413*9356374aSAndroid Build Coastguard Worker
414*9356374aSAndroid Build Coastguard Worker auto saver = random_internal::make_istream_state_saver(is);
415*9356374aSAndroid Build Coastguard Worker alpha = random_internal::read_floating_point<result_type>(is);
416*9356374aSAndroid Build Coastguard Worker if (is.fail()) return is;
417*9356374aSAndroid Build Coastguard Worker beta = random_internal::read_floating_point<result_type>(is);
418*9356374aSAndroid Build Coastguard Worker if (!is.fail()) {
419*9356374aSAndroid Build Coastguard Worker x.param(param_type(alpha, beta));
420*9356374aSAndroid Build Coastguard Worker }
421*9356374aSAndroid Build Coastguard Worker return is;
422*9356374aSAndroid Build Coastguard Worker }
423*9356374aSAndroid Build Coastguard Worker
424*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_END
425*9356374aSAndroid Build Coastguard Worker } // namespace absl
426*9356374aSAndroid Build Coastguard Worker
427*9356374aSAndroid Build Coastguard Worker #endif // ABSL_RANDOM_BETA_DISTRIBUTION_H_
428