xref: /aosp_15_r20/external/abseil-cpp/absl/random/beta_distribution_test.cc (revision 9356374a3709195abf420251b3e825997ff56c0f)
1*9356374aSAndroid Build Coastguard Worker // Copyright 2017 The Abseil Authors.
2*9356374aSAndroid Build Coastguard Worker //
3*9356374aSAndroid Build Coastguard Worker // Licensed under the Apache License, Version 2.0 (the "License");
4*9356374aSAndroid Build Coastguard Worker // you may not use this file except in compliance with the License.
5*9356374aSAndroid Build Coastguard Worker // You may obtain a copy of the License at
6*9356374aSAndroid Build Coastguard Worker //
7*9356374aSAndroid Build Coastguard Worker //      https://www.apache.org/licenses/LICENSE-2.0
8*9356374aSAndroid Build Coastguard Worker //
9*9356374aSAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
10*9356374aSAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS,
11*9356374aSAndroid Build Coastguard Worker // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*9356374aSAndroid Build Coastguard Worker // See the License for the specific language governing permissions and
13*9356374aSAndroid Build Coastguard Worker // limitations under the License.
14*9356374aSAndroid Build Coastguard Worker 
15*9356374aSAndroid Build Coastguard Worker #include "absl/random/beta_distribution.h"
16*9356374aSAndroid Build Coastguard Worker 
17*9356374aSAndroid Build Coastguard Worker #include <algorithm>
18*9356374aSAndroid Build Coastguard Worker #include <cfloat>
19*9356374aSAndroid Build Coastguard Worker #include <cstddef>
20*9356374aSAndroid Build Coastguard Worker #include <cstdint>
21*9356374aSAndroid Build Coastguard Worker #include <iterator>
22*9356374aSAndroid Build Coastguard Worker #include <random>
23*9356374aSAndroid Build Coastguard Worker #include <sstream>
24*9356374aSAndroid Build Coastguard Worker #include <string>
25*9356374aSAndroid Build Coastguard Worker #include <type_traits>
26*9356374aSAndroid Build Coastguard Worker #include <unordered_map>
27*9356374aSAndroid Build Coastguard Worker #include <vector>
28*9356374aSAndroid Build Coastguard Worker 
29*9356374aSAndroid Build Coastguard Worker #include "gmock/gmock.h"
30*9356374aSAndroid Build Coastguard Worker #include "gtest/gtest.h"
31*9356374aSAndroid Build Coastguard Worker #include "absl/log/log.h"
32*9356374aSAndroid Build Coastguard Worker #include "absl/numeric/internal/representation.h"
33*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/chi_square.h"
34*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/distribution_test_util.h"
35*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/pcg_engine.h"
36*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/sequence_urbg.h"
37*9356374aSAndroid Build Coastguard Worker #include "absl/random/random.h"
38*9356374aSAndroid Build Coastguard Worker #include "absl/strings/str_cat.h"
39*9356374aSAndroid Build Coastguard Worker #include "absl/strings/str_format.h"
40*9356374aSAndroid Build Coastguard Worker #include "absl/strings/str_replace.h"
41*9356374aSAndroid Build Coastguard Worker #include "absl/strings/strip.h"
42*9356374aSAndroid Build Coastguard Worker 
43*9356374aSAndroid Build Coastguard Worker namespace {
44*9356374aSAndroid Build Coastguard Worker 
45*9356374aSAndroid Build Coastguard Worker template <typename IntType>
46*9356374aSAndroid Build Coastguard Worker class BetaDistributionInterfaceTest : public ::testing::Test {};
47*9356374aSAndroid Build Coastguard Worker 
// Reports whether the typed tests should also be instantiated for long double.
//
// long double arithmetic is not supported well by either GCC or Clang on most
// platforms, specifically not when it is implemented in terms of
// double-double; see
//   https://gcc.gnu.org/bugzilla/show_bug.cgi?id=99048,
//   https://bugs.llvm.org/show_bug.cgi?id=49131, and
//   https://bugs.llvm.org/show_bug.cgi?id=49132.
// The conservative choice is therefore to disable long-double tests pretty
// much everywhere except on x64, and there only if long double is not
// implemented as double-double.
//
// NOTE(review): the guard below needs __i686__ and __x86_64__ to be defined at
// the same time. Those look like mutually exclusive target macros, in which
// case this function always returns false and long double is never exercised.
// Confirm whether `||` (or __x86_64__ alone) was intended before changing it.
constexpr bool ShouldExerciseLongDoubleTests() {
#if !(defined(__i686__) && defined(__x86_64__))
  return false;
#else
  return !absl::numeric_internal::IsDoubleDouble();
#endif
}
63*9356374aSAndroid Build Coastguard Worker 
64*9356374aSAndroid Build Coastguard Worker using RealTypes = std::conditional<ShouldExerciseLongDoubleTests(),
65*9356374aSAndroid Build Coastguard Worker                                    ::testing::Types<float, double, long double>,
66*9356374aSAndroid Build Coastguard Worker                                    ::testing::Types<float, double>>::type;
67*9356374aSAndroid Build Coastguard Worker TYPED_TEST_SUITE(BetaDistributionInterfaceTest, RealTypes);
68*9356374aSAndroid Build Coastguard Worker 
TYPED_TEST(BetaDistributionInterfaceTest,SerializeTest)69*9356374aSAndroid Build Coastguard Worker TYPED_TEST(BetaDistributionInterfaceTest, SerializeTest) {
70*9356374aSAndroid Build Coastguard Worker   // The threshold for whether std::exp(1/a) is finite.
71*9356374aSAndroid Build Coastguard Worker   const TypeParam kSmallA =
72*9356374aSAndroid Build Coastguard Worker       1.0f / std::log((std::numeric_limits<TypeParam>::max)());
73*9356374aSAndroid Build Coastguard Worker   // The threshold for whether a * std::log(a) is finite.
74*9356374aSAndroid Build Coastguard Worker   const TypeParam kLargeA =
75*9356374aSAndroid Build Coastguard Worker       std::exp(std::log((std::numeric_limits<TypeParam>::max)()) -
76*9356374aSAndroid Build Coastguard Worker                std::log(std::log((std::numeric_limits<TypeParam>::max)())));
77*9356374aSAndroid Build Coastguard Worker   using param_type = typename absl::beta_distribution<TypeParam>::param_type;
78*9356374aSAndroid Build Coastguard Worker 
79*9356374aSAndroid Build Coastguard Worker   constexpr int kCount = 1000;
80*9356374aSAndroid Build Coastguard Worker   absl::InsecureBitGen gen;
81*9356374aSAndroid Build Coastguard Worker   const TypeParam kValues[] = {
82*9356374aSAndroid Build Coastguard Worker       TypeParam(1e-20), TypeParam(1e-12), TypeParam(1e-8), TypeParam(1e-4),
83*9356374aSAndroid Build Coastguard Worker       TypeParam(1e-3), TypeParam(0.1), TypeParam(0.25),
84*9356374aSAndroid Build Coastguard Worker       std::nextafter(TypeParam(0.5), TypeParam(0)),  // 0.5 - epsilon
85*9356374aSAndroid Build Coastguard Worker       std::nextafter(TypeParam(0.5), TypeParam(1)),  // 0.5 + epsilon
86*9356374aSAndroid Build Coastguard Worker       TypeParam(0.5), TypeParam(1.0),                //
87*9356374aSAndroid Build Coastguard Worker       std::nextafter(TypeParam(1), TypeParam(0)),    // 1 - epsilon
88*9356374aSAndroid Build Coastguard Worker       std::nextafter(TypeParam(1), TypeParam(2)),    // 1 + epsilon
89*9356374aSAndroid Build Coastguard Worker       TypeParam(12.5), TypeParam(1e2), TypeParam(1e8), TypeParam(1e12),
90*9356374aSAndroid Build Coastguard Worker       TypeParam(1e20),                        //
91*9356374aSAndroid Build Coastguard Worker       kSmallA,                                //
92*9356374aSAndroid Build Coastguard Worker       std::nextafter(kSmallA, TypeParam(0)),  //
93*9356374aSAndroid Build Coastguard Worker       std::nextafter(kSmallA, TypeParam(1)),  //
94*9356374aSAndroid Build Coastguard Worker       kLargeA,                                //
95*9356374aSAndroid Build Coastguard Worker       std::nextafter(kLargeA, TypeParam(0)),  //
96*9356374aSAndroid Build Coastguard Worker       std::nextafter(kLargeA, std::numeric_limits<TypeParam>::max()),
97*9356374aSAndroid Build Coastguard Worker       // Boundary cases.
98*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::max(),
99*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::epsilon(),
100*9356374aSAndroid Build Coastguard Worker       std::nextafter(std::numeric_limits<TypeParam>::min(),
101*9356374aSAndroid Build Coastguard Worker                      TypeParam(1)),                  // min + epsilon
102*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::min(),         // smallest normal
103*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::denorm_min(),  // smallest denorm
104*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::min() / 2,     // denorm
105*9356374aSAndroid Build Coastguard Worker       std::nextafter(std::numeric_limits<TypeParam>::min(),
106*9356374aSAndroid Build Coastguard Worker                      TypeParam(0)),  // denorm_max
107*9356374aSAndroid Build Coastguard Worker   };
108*9356374aSAndroid Build Coastguard Worker   for (TypeParam alpha : kValues) {
109*9356374aSAndroid Build Coastguard Worker     for (TypeParam beta : kValues) {
110*9356374aSAndroid Build Coastguard Worker       LOG(INFO) << absl::StreamFormat("Smoke test for Beta(%a, %a)", alpha,
111*9356374aSAndroid Build Coastguard Worker                                       beta);
112*9356374aSAndroid Build Coastguard Worker 
113*9356374aSAndroid Build Coastguard Worker       param_type param(alpha, beta);
114*9356374aSAndroid Build Coastguard Worker       absl::beta_distribution<TypeParam> before(alpha, beta);
115*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(before.alpha(), param.alpha());
116*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(before.beta(), param.beta());
117*9356374aSAndroid Build Coastguard Worker 
118*9356374aSAndroid Build Coastguard Worker       {
119*9356374aSAndroid Build Coastguard Worker         absl::beta_distribution<TypeParam> via_param(param);
120*9356374aSAndroid Build Coastguard Worker         EXPECT_EQ(via_param, before);
121*9356374aSAndroid Build Coastguard Worker         EXPECT_EQ(via_param.param(), before.param());
122*9356374aSAndroid Build Coastguard Worker       }
123*9356374aSAndroid Build Coastguard Worker 
124*9356374aSAndroid Build Coastguard Worker       // Smoke test.
125*9356374aSAndroid Build Coastguard Worker       for (int i = 0; i < kCount; ++i) {
126*9356374aSAndroid Build Coastguard Worker         auto sample = before(gen);
127*9356374aSAndroid Build Coastguard Worker         EXPECT_TRUE(std::isfinite(sample));
128*9356374aSAndroid Build Coastguard Worker         EXPECT_GE(sample, before.min());
129*9356374aSAndroid Build Coastguard Worker         EXPECT_LE(sample, before.max());
130*9356374aSAndroid Build Coastguard Worker       }
131*9356374aSAndroid Build Coastguard Worker 
132*9356374aSAndroid Build Coastguard Worker       // Validate stream serialization.
133*9356374aSAndroid Build Coastguard Worker       std::stringstream ss;
134*9356374aSAndroid Build Coastguard Worker       ss << before;
135*9356374aSAndroid Build Coastguard Worker       absl::beta_distribution<TypeParam> after(3.8f, 1.43f);
136*9356374aSAndroid Build Coastguard Worker       EXPECT_NE(before.alpha(), after.alpha());
137*9356374aSAndroid Build Coastguard Worker       EXPECT_NE(before.beta(), after.beta());
138*9356374aSAndroid Build Coastguard Worker       EXPECT_NE(before.param(), after.param());
139*9356374aSAndroid Build Coastguard Worker       EXPECT_NE(before, after);
140*9356374aSAndroid Build Coastguard Worker 
141*9356374aSAndroid Build Coastguard Worker       ss >> after;
142*9356374aSAndroid Build Coastguard Worker 
143*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(before.alpha(), after.alpha());
144*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(before.beta(), after.beta());
145*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(before, after)           //
146*9356374aSAndroid Build Coastguard Worker           << ss.str() << " "             //
147*9356374aSAndroid Build Coastguard Worker           << (ss.good() ? "good " : "")  //
148*9356374aSAndroid Build Coastguard Worker           << (ss.bad() ? "bad " : "")    //
149*9356374aSAndroid Build Coastguard Worker           << (ss.eof() ? "eof " : "")    //
150*9356374aSAndroid Build Coastguard Worker           << (ss.fail() ? "fail " : "");
151*9356374aSAndroid Build Coastguard Worker     }
152*9356374aSAndroid Build Coastguard Worker   }
153*9356374aSAndroid Build Coastguard Worker }
154*9356374aSAndroid Build Coastguard Worker 
TYPED_TEST(BetaDistributionInterfaceTest,DegenerateCases)155*9356374aSAndroid Build Coastguard Worker TYPED_TEST(BetaDistributionInterfaceTest, DegenerateCases) {
156*9356374aSAndroid Build Coastguard Worker   // We use a fixed bit generator for distribution accuracy tests.  This allows
157*9356374aSAndroid Build Coastguard Worker   // these tests to be deterministic, while still testing the qualify of the
158*9356374aSAndroid Build Coastguard Worker   // implementation.
159*9356374aSAndroid Build Coastguard Worker   absl::random_internal::pcg64_2018_engine rng(0x2B7E151628AED2A6);
160*9356374aSAndroid Build Coastguard Worker 
161*9356374aSAndroid Build Coastguard Worker   // Extreme cases when the params are abnormal.
162*9356374aSAndroid Build Coastguard Worker   constexpr int kCount = 1000;
163*9356374aSAndroid Build Coastguard Worker   const TypeParam kSmallValues[] = {
164*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::min(),
165*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::denorm_min(),
166*9356374aSAndroid Build Coastguard Worker       std::nextafter(std::numeric_limits<TypeParam>::min(),
167*9356374aSAndroid Build Coastguard Worker                      TypeParam(0)),  // denorm_max
168*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::epsilon(),
169*9356374aSAndroid Build Coastguard Worker   };
170*9356374aSAndroid Build Coastguard Worker   const TypeParam kLargeValues[] = {
171*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::max() * static_cast<TypeParam>(0.9999),
172*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::max() - 1,
173*9356374aSAndroid Build Coastguard Worker       std::numeric_limits<TypeParam>::max(),
174*9356374aSAndroid Build Coastguard Worker   };
175*9356374aSAndroid Build Coastguard Worker   {
176*9356374aSAndroid Build Coastguard Worker     // Small alpha and beta.
177*9356374aSAndroid Build Coastguard Worker     // Useful WolframAlpha plots:
178*9356374aSAndroid Build Coastguard Worker     //   * plot InverseBetaRegularized[x, 0.0001, 0.0001] from 0.495 to 0.505
179*9356374aSAndroid Build Coastguard Worker     //   * Beta[1.0, 0.0000001, 0.0000001]
180*9356374aSAndroid Build Coastguard Worker     //   * Beta[0.9999, 0.0000001, 0.0000001]
181*9356374aSAndroid Build Coastguard Worker     for (TypeParam alpha : kSmallValues) {
182*9356374aSAndroid Build Coastguard Worker       for (TypeParam beta : kSmallValues) {
183*9356374aSAndroid Build Coastguard Worker         int zeros = 0;
184*9356374aSAndroid Build Coastguard Worker         int ones = 0;
185*9356374aSAndroid Build Coastguard Worker         absl::beta_distribution<TypeParam> d(alpha, beta);
186*9356374aSAndroid Build Coastguard Worker         for (int i = 0; i < kCount; ++i) {
187*9356374aSAndroid Build Coastguard Worker           TypeParam x = d(rng);
188*9356374aSAndroid Build Coastguard Worker           if (x == 0.0) {
189*9356374aSAndroid Build Coastguard Worker             zeros++;
190*9356374aSAndroid Build Coastguard Worker           } else if (x == 1.0) {
191*9356374aSAndroid Build Coastguard Worker             ones++;
192*9356374aSAndroid Build Coastguard Worker           }
193*9356374aSAndroid Build Coastguard Worker         }
194*9356374aSAndroid Build Coastguard Worker         EXPECT_EQ(ones + zeros, kCount);
195*9356374aSAndroid Build Coastguard Worker         if (alpha == beta) {
196*9356374aSAndroid Build Coastguard Worker           EXPECT_NE(ones, 0);
197*9356374aSAndroid Build Coastguard Worker           EXPECT_NE(zeros, 0);
198*9356374aSAndroid Build Coastguard Worker         }
199*9356374aSAndroid Build Coastguard Worker       }
200*9356374aSAndroid Build Coastguard Worker     }
201*9356374aSAndroid Build Coastguard Worker   }
202*9356374aSAndroid Build Coastguard Worker   {
203*9356374aSAndroid Build Coastguard Worker     // Small alpha, large beta.
204*9356374aSAndroid Build Coastguard Worker     // Useful WolframAlpha plots:
205*9356374aSAndroid Build Coastguard Worker     //   * plot InverseBetaRegularized[x, 0.0001, 10000] from 0.995 to 1
206*9356374aSAndroid Build Coastguard Worker     //   * Beta[0, 0.0000001, 1000000]
207*9356374aSAndroid Build Coastguard Worker     //   * Beta[0.001, 0.0000001, 1000000]
208*9356374aSAndroid Build Coastguard Worker     //   * Beta[1, 0.0000001, 1000000]
209*9356374aSAndroid Build Coastguard Worker     for (TypeParam alpha : kSmallValues) {
210*9356374aSAndroid Build Coastguard Worker       for (TypeParam beta : kLargeValues) {
211*9356374aSAndroid Build Coastguard Worker         absl::beta_distribution<TypeParam> d(alpha, beta);
212*9356374aSAndroid Build Coastguard Worker         for (int i = 0; i < kCount; ++i) {
213*9356374aSAndroid Build Coastguard Worker           EXPECT_EQ(d(rng), 0.0);
214*9356374aSAndroid Build Coastguard Worker         }
215*9356374aSAndroid Build Coastguard Worker       }
216*9356374aSAndroid Build Coastguard Worker     }
217*9356374aSAndroid Build Coastguard Worker   }
218*9356374aSAndroid Build Coastguard Worker   {
219*9356374aSAndroid Build Coastguard Worker     // Large alpha, small beta.
220*9356374aSAndroid Build Coastguard Worker     // Useful WolframAlpha plots:
221*9356374aSAndroid Build Coastguard Worker     //   * plot InverseBetaRegularized[x, 10000, 0.0001] from 0 to 0.001
222*9356374aSAndroid Build Coastguard Worker     //   * Beta[0.99, 1000000, 0.0000001]
223*9356374aSAndroid Build Coastguard Worker     //   * Beta[1, 1000000, 0.0000001]
224*9356374aSAndroid Build Coastguard Worker     for (TypeParam alpha : kLargeValues) {
225*9356374aSAndroid Build Coastguard Worker       for (TypeParam beta : kSmallValues) {
226*9356374aSAndroid Build Coastguard Worker         absl::beta_distribution<TypeParam> d(alpha, beta);
227*9356374aSAndroid Build Coastguard Worker         for (int i = 0; i < kCount; ++i) {
228*9356374aSAndroid Build Coastguard Worker           EXPECT_EQ(d(rng), 1.0);
229*9356374aSAndroid Build Coastguard Worker         }
230*9356374aSAndroid Build Coastguard Worker       }
231*9356374aSAndroid Build Coastguard Worker     }
232*9356374aSAndroid Build Coastguard Worker   }
233*9356374aSAndroid Build Coastguard Worker   {
234*9356374aSAndroid Build Coastguard Worker     // Large alpha and beta.
235*9356374aSAndroid Build Coastguard Worker     absl::beta_distribution<TypeParam> d(std::numeric_limits<TypeParam>::max(),
236*9356374aSAndroid Build Coastguard Worker                                          std::numeric_limits<TypeParam>::max());
237*9356374aSAndroid Build Coastguard Worker     for (int i = 0; i < kCount; ++i) {
238*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(d(rng), 0.5);
239*9356374aSAndroid Build Coastguard Worker     }
240*9356374aSAndroid Build Coastguard Worker   }
241*9356374aSAndroid Build Coastguard Worker   {
242*9356374aSAndroid Build Coastguard Worker     // Large alpha and beta but unequal.
243*9356374aSAndroid Build Coastguard Worker     absl::beta_distribution<TypeParam> d(
244*9356374aSAndroid Build Coastguard Worker         std::numeric_limits<TypeParam>::max(),
245*9356374aSAndroid Build Coastguard Worker         std::numeric_limits<TypeParam>::max() * 0.9999);
246*9356374aSAndroid Build Coastguard Worker     for (int i = 0; i < kCount; ++i) {
247*9356374aSAndroid Build Coastguard Worker       TypeParam x = d(rng);
248*9356374aSAndroid Build Coastguard Worker       EXPECT_NE(x, 0.5f);
249*9356374aSAndroid Build Coastguard Worker       EXPECT_FLOAT_EQ(x, 0.500025f);
250*9356374aSAndroid Build Coastguard Worker     }
251*9356374aSAndroid Build Coastguard Worker   }
252*9356374aSAndroid Build Coastguard Worker }
253*9356374aSAndroid Build Coastguard Worker 
254*9356374aSAndroid Build Coastguard Worker class BetaDistributionModel {
255*9356374aSAndroid Build Coastguard Worker  public:
BetaDistributionModel(::testing::tuple<double,double> p)256*9356374aSAndroid Build Coastguard Worker   explicit BetaDistributionModel(::testing::tuple<double, double> p)
257*9356374aSAndroid Build Coastguard Worker       : alpha_(::testing::get<0>(p)), beta_(::testing::get<1>(p)) {}
258*9356374aSAndroid Build Coastguard Worker 
Mean() const259*9356374aSAndroid Build Coastguard Worker   double Mean() const { return alpha_ / (alpha_ + beta_); }
260*9356374aSAndroid Build Coastguard Worker 
Variance() const261*9356374aSAndroid Build Coastguard Worker   double Variance() const {
262*9356374aSAndroid Build Coastguard Worker     return alpha_ * beta_ / (alpha_ + beta_ + 1) / (alpha_ + beta_) /
263*9356374aSAndroid Build Coastguard Worker            (alpha_ + beta_);
264*9356374aSAndroid Build Coastguard Worker   }
265*9356374aSAndroid Build Coastguard Worker 
Kurtosis() const266*9356374aSAndroid Build Coastguard Worker   double Kurtosis() const {
267*9356374aSAndroid Build Coastguard Worker     return 3 + 6 *
268*9356374aSAndroid Build Coastguard Worker                    ((alpha_ - beta_) * (alpha_ - beta_) * (alpha_ + beta_ + 1) -
269*9356374aSAndroid Build Coastguard Worker                     alpha_ * beta_ * (2 + alpha_ + beta_)) /
270*9356374aSAndroid Build Coastguard Worker                    alpha_ / beta_ / (alpha_ + beta_ + 2) / (alpha_ + beta_ + 3);
271*9356374aSAndroid Build Coastguard Worker   }
272*9356374aSAndroid Build Coastguard Worker 
273*9356374aSAndroid Build Coastguard Worker  protected:
274*9356374aSAndroid Build Coastguard Worker   const double alpha_;
275*9356374aSAndroid Build Coastguard Worker   const double beta_;
276*9356374aSAndroid Build Coastguard Worker };
277*9356374aSAndroid Build Coastguard Worker 
278*9356374aSAndroid Build Coastguard Worker class BetaDistributionTest
279*9356374aSAndroid Build Coastguard Worker     : public ::testing::TestWithParam<::testing::tuple<double, double>>,
280*9356374aSAndroid Build Coastguard Worker       public BetaDistributionModel {
281*9356374aSAndroid Build Coastguard Worker  public:
BetaDistributionTest()282*9356374aSAndroid Build Coastguard Worker   BetaDistributionTest() : BetaDistributionModel(GetParam()) {}
283*9356374aSAndroid Build Coastguard Worker 
284*9356374aSAndroid Build Coastguard Worker  protected:
285*9356374aSAndroid Build Coastguard Worker   template <class D>
286*9356374aSAndroid Build Coastguard Worker   bool SingleZTestOnMeanAndVariance(double p, size_t samples);
287*9356374aSAndroid Build Coastguard Worker 
288*9356374aSAndroid Build Coastguard Worker   template <class D>
289*9356374aSAndroid Build Coastguard Worker   bool SingleChiSquaredTest(double p, size_t samples, size_t buckets);
290*9356374aSAndroid Build Coastguard Worker 
291*9356374aSAndroid Build Coastguard Worker   absl::InsecureBitGen rng_;
292*9356374aSAndroid Build Coastguard Worker };
293*9356374aSAndroid Build Coastguard Worker 
294*9356374aSAndroid Build Coastguard Worker template <class D>
SingleZTestOnMeanAndVariance(double p,size_t samples)295*9356374aSAndroid Build Coastguard Worker bool BetaDistributionTest::SingleZTestOnMeanAndVariance(double p,
296*9356374aSAndroid Build Coastguard Worker                                                         size_t samples) {
297*9356374aSAndroid Build Coastguard Worker   D dis(alpha_, beta_);
298*9356374aSAndroid Build Coastguard Worker 
299*9356374aSAndroid Build Coastguard Worker   std::vector<double> data;
300*9356374aSAndroid Build Coastguard Worker   data.reserve(samples);
301*9356374aSAndroid Build Coastguard Worker   for (size_t i = 0; i < samples; i++) {
302*9356374aSAndroid Build Coastguard Worker     const double variate = dis(rng_);
303*9356374aSAndroid Build Coastguard Worker     EXPECT_FALSE(std::isnan(variate));
304*9356374aSAndroid Build Coastguard Worker     // Note that equality is allowed on both sides.
305*9356374aSAndroid Build Coastguard Worker     EXPECT_GE(variate, 0.0);
306*9356374aSAndroid Build Coastguard Worker     EXPECT_LE(variate, 1.0);
307*9356374aSAndroid Build Coastguard Worker     data.push_back(variate);
308*9356374aSAndroid Build Coastguard Worker   }
309*9356374aSAndroid Build Coastguard Worker 
310*9356374aSAndroid Build Coastguard Worker   // We validate that the sample mean and sample variance are indeed from a
311*9356374aSAndroid Build Coastguard Worker   // Beta distribution with the given shape parameters.
312*9356374aSAndroid Build Coastguard Worker   const auto m = absl::random_internal::ComputeDistributionMoments(data);
313*9356374aSAndroid Build Coastguard Worker 
314*9356374aSAndroid Build Coastguard Worker   // The variance of the sample mean is variance / n.
315*9356374aSAndroid Build Coastguard Worker   const double mean_stddev = std::sqrt(Variance() / static_cast<double>(m.n));
316*9356374aSAndroid Build Coastguard Worker 
317*9356374aSAndroid Build Coastguard Worker   // The variance of the sample variance is (approximately):
318*9356374aSAndroid Build Coastguard Worker   //   (kurtosis - 1) * variance^2 / n
319*9356374aSAndroid Build Coastguard Worker   const double variance_stddev = std::sqrt(
320*9356374aSAndroid Build Coastguard Worker       (Kurtosis() - 1) * Variance() * Variance() / static_cast<double>(m.n));
321*9356374aSAndroid Build Coastguard Worker   // z score for the sample variance.
322*9356374aSAndroid Build Coastguard Worker   const double z_variance = (m.variance - Variance()) / variance_stddev;
323*9356374aSAndroid Build Coastguard Worker 
324*9356374aSAndroid Build Coastguard Worker   const double max_err = absl::random_internal::MaxErrorTolerance(p);
325*9356374aSAndroid Build Coastguard Worker   const double z_mean = absl::random_internal::ZScore(Mean(), m);
326*9356374aSAndroid Build Coastguard Worker   const bool pass =
327*9356374aSAndroid Build Coastguard Worker       absl::random_internal::Near("z", z_mean, 0.0, max_err) &&
328*9356374aSAndroid Build Coastguard Worker       absl::random_internal::Near("z_variance", z_variance, 0.0, max_err);
329*9356374aSAndroid Build Coastguard Worker   if (!pass) {
330*9356374aSAndroid Build Coastguard Worker     LOG(INFO) << "Beta(" << alpha_ << ", " << beta_ << "), mean: sample "
331*9356374aSAndroid Build Coastguard Worker               << m.mean << ", expect " << Mean() << ", which is "
332*9356374aSAndroid Build Coastguard Worker               << std::abs(m.mean - Mean()) / mean_stddev
333*9356374aSAndroid Build Coastguard Worker               << " stddevs away, variance: sample " << m.variance << ", expect "
334*9356374aSAndroid Build Coastguard Worker               << Variance() << ", which is "
335*9356374aSAndroid Build Coastguard Worker               << std::abs(m.variance - Variance()) / variance_stddev
336*9356374aSAndroid Build Coastguard Worker               << " stddevs away.";
337*9356374aSAndroid Build Coastguard Worker   }
338*9356374aSAndroid Build Coastguard Worker   return pass;
339*9356374aSAndroid Build Coastguard Worker }
340*9356374aSAndroid Build Coastguard Worker 
341*9356374aSAndroid Build Coastguard Worker template <class D>
SingleChiSquaredTest(double p,size_t samples,size_t buckets)342*9356374aSAndroid Build Coastguard Worker bool BetaDistributionTest::SingleChiSquaredTest(double p, size_t samples,
343*9356374aSAndroid Build Coastguard Worker                                                 size_t buckets) {
344*9356374aSAndroid Build Coastguard Worker   constexpr double kErr = 1e-7;
345*9356374aSAndroid Build Coastguard Worker   std::vector<double> cutoffs, expected;
346*9356374aSAndroid Build Coastguard Worker   const double bucket_width = 1.0 / static_cast<double>(buckets);
347*9356374aSAndroid Build Coastguard Worker   int i = 1;
348*9356374aSAndroid Build Coastguard Worker   int unmerged_buckets = 0;
349*9356374aSAndroid Build Coastguard Worker   for (; i < buckets; ++i) {
350*9356374aSAndroid Build Coastguard Worker     const double p = bucket_width * static_cast<double>(i);
351*9356374aSAndroid Build Coastguard Worker     const double boundary =
352*9356374aSAndroid Build Coastguard Worker         absl::random_internal::BetaIncompleteInv(alpha_, beta_, p);
353*9356374aSAndroid Build Coastguard Worker     // The intention is to add `boundary` to the list of `cutoffs`. It becomes
354*9356374aSAndroid Build Coastguard Worker     // problematic, however, when the boundary values are not monotone, due to
355*9356374aSAndroid Build Coastguard Worker     // numerical issues when computing the inverse regularized incomplete
356*9356374aSAndroid Build Coastguard Worker     // Beta function. In these cases, we merge that bucket with its previous
357*9356374aSAndroid Build Coastguard Worker     // neighbor and merge their expected counts.
358*9356374aSAndroid Build Coastguard Worker     if ((cutoffs.empty() && boundary < kErr) ||
359*9356374aSAndroid Build Coastguard Worker         (!cutoffs.empty() && boundary <= cutoffs.back())) {
360*9356374aSAndroid Build Coastguard Worker       unmerged_buckets++;
361*9356374aSAndroid Build Coastguard Worker       continue;
362*9356374aSAndroid Build Coastguard Worker     }
363*9356374aSAndroid Build Coastguard Worker     if (boundary >= 1.0 - 1e-10) {
364*9356374aSAndroid Build Coastguard Worker       break;
365*9356374aSAndroid Build Coastguard Worker     }
366*9356374aSAndroid Build Coastguard Worker     cutoffs.push_back(boundary);
367*9356374aSAndroid Build Coastguard Worker     expected.push_back(static_cast<double>(1 + unmerged_buckets) *
368*9356374aSAndroid Build Coastguard Worker                        bucket_width * static_cast<double>(samples));
369*9356374aSAndroid Build Coastguard Worker     unmerged_buckets = 0;
370*9356374aSAndroid Build Coastguard Worker   }
371*9356374aSAndroid Build Coastguard Worker   cutoffs.push_back(std::numeric_limits<double>::infinity());
372*9356374aSAndroid Build Coastguard Worker   // Merge all remaining buckets.
373*9356374aSAndroid Build Coastguard Worker   expected.push_back(static_cast<double>(buckets - i + 1) * bucket_width *
374*9356374aSAndroid Build Coastguard Worker                      static_cast<double>(samples));
375*9356374aSAndroid Build Coastguard Worker   // Make sure that we don't merge all the buckets, making this test
376*9356374aSAndroid Build Coastguard Worker   // meaningless.
377*9356374aSAndroid Build Coastguard Worker   EXPECT_GE(cutoffs.size(), 3) << alpha_ << ", " << beta_;
378*9356374aSAndroid Build Coastguard Worker 
379*9356374aSAndroid Build Coastguard Worker   D dis(alpha_, beta_);
380*9356374aSAndroid Build Coastguard Worker 
381*9356374aSAndroid Build Coastguard Worker   std::vector<int32_t> counts(cutoffs.size(), 0);
382*9356374aSAndroid Build Coastguard Worker   for (int i = 0; i < samples; i++) {
383*9356374aSAndroid Build Coastguard Worker     const double x = dis(rng_);
384*9356374aSAndroid Build Coastguard Worker     auto it = std::upper_bound(cutoffs.begin(), cutoffs.end(), x);
385*9356374aSAndroid Build Coastguard Worker     counts[std::distance(cutoffs.begin(), it)]++;
386*9356374aSAndroid Build Coastguard Worker   }
387*9356374aSAndroid Build Coastguard Worker 
388*9356374aSAndroid Build Coastguard Worker   // Null-hypothesis is that the distribution is beta distributed with the
389*9356374aSAndroid Build Coastguard Worker   // provided alpha, beta params (not estimated from the data).
390*9356374aSAndroid Build Coastguard Worker   const int dof = cutoffs.size() - 1;
391*9356374aSAndroid Build Coastguard Worker 
392*9356374aSAndroid Build Coastguard Worker   const double chi_square = absl::random_internal::ChiSquare(
393*9356374aSAndroid Build Coastguard Worker       counts.begin(), counts.end(), expected.begin(), expected.end());
394*9356374aSAndroid Build Coastguard Worker   const bool pass =
395*9356374aSAndroid Build Coastguard Worker       (absl::random_internal::ChiSquarePValue(chi_square, dof) >= p);
396*9356374aSAndroid Build Coastguard Worker   if (!pass) {
397*9356374aSAndroid Build Coastguard Worker     for (size_t i = 0; i < cutoffs.size(); i++) {
398*9356374aSAndroid Build Coastguard Worker       LOG(INFO) << "cutoff[" << i << "] = " << cutoffs[i] << ", actual count "
399*9356374aSAndroid Build Coastguard Worker                 << counts[i] << ", expected " << static_cast<int>(expected[i]);
400*9356374aSAndroid Build Coastguard Worker     }
401*9356374aSAndroid Build Coastguard Worker 
402*9356374aSAndroid Build Coastguard Worker     LOG(INFO) << "Beta(" << alpha_ << ", " << beta_ << ") "
403*9356374aSAndroid Build Coastguard Worker               << absl::random_internal::kChiSquared << " " << chi_square
404*9356374aSAndroid Build Coastguard Worker               << ", p = "
405*9356374aSAndroid Build Coastguard Worker               << absl::random_internal::ChiSquarePValue(chi_square, dof);
406*9356374aSAndroid Build Coastguard Worker   }
407*9356374aSAndroid Build Coastguard Worker   return pass;
408*9356374aSAndroid Build Coastguard Worker }
409*9356374aSAndroid Build Coastguard Worker 
// Repeatedly runs a Z-test on the sample mean/variance and a chi-squared
// goodness-of-fit test against the parameterized beta distribution, and
// tolerates only a small number of failed trials overall.
TEST_P(BetaDistributionTest, TestSampleStatistics) {
  using Distribution = absl::beta_distribution<double>;

  constexpr int kTrials = 20;
  constexpr double kFailureRate = 0.02;
  constexpr int kSamplesPerTrial = 10000;
  constexpr int kHistogramBuckets = 100;
  // The chi-squared test uses its own fixed threshold rather than the one
  // derived from kFailureRate.
  constexpr double kChiSquaredThreshold = 0.005;

  const double z_test_threshold =
      absl::random_internal::RequiredSuccessProbability(kFailureRate, kTrials);

  int failures = 0;
  for (int trial = 0; trial != kTrials; ++trial) {
    // Keep this order: the Z-test runs first, then the chi-squared test.
    const bool z_test_ok = SingleZTestOnMeanAndVariance<Distribution>(
        z_test_threshold, kSamplesPerTrial);
    const bool chi_squared_ok = SingleChiSquaredTest<Distribution>(
        kChiSquaredThreshold, kSamplesPerTrial, kHistogramBuckets);
    failures += (z_test_ok ? 0 : 1) + (chi_squared_ok ? 0 : 1);
  }

  // Set so that the test is not flaky at --runs_per_test=10000
  EXPECT_LE(failures, 5);
}
431*9356374aSAndroid Build Coastguard Worker 
ParamName(const::testing::TestParamInfo<::testing::tuple<double,double>> & info)432*9356374aSAndroid Build Coastguard Worker std::string ParamName(
433*9356374aSAndroid Build Coastguard Worker     const ::testing::TestParamInfo<::testing::tuple<double, double>>& info) {
434*9356374aSAndroid Build Coastguard Worker   std::string name = absl::StrCat("alpha_", ::testing::get<0>(info.param),
435*9356374aSAndroid Build Coastguard Worker                                   "__beta_", ::testing::get<1>(info.param));
436*9356374aSAndroid Build Coastguard Worker   return absl::StrReplaceAll(name, {{"+", "_"}, {"-", "_"}, {".", "_"}});
437*9356374aSAndroid Build Coastguard Worker }
438*9356374aSAndroid Build Coastguard Worker 
// Instantiates the parameterized BetaDistributionTest suite over every
// combination of the seven values listed for each tuple element (first
// element paired with second element), naming each instance via ParamName.
INSTANTIATE_TEST_SUITE_P(
    TestSampleStatisticsCombinations, BetaDistributionTest,
    ::testing::Combine(::testing::Values(0.1, 0.2, 0.9, 1.1, 2.5, 10.0, 123.4),
                       ::testing::Values(0.1, 0.2, 0.9, 1.1, 2.5, 10.0, 123.4)),
    ParamName);
444*9356374aSAndroid Build Coastguard Worker 
// Instantiates the parameterized BetaDistributionTest suite for a hand-picked
// list of parameter pairs, including strongly asymmetric pairs (0.5 vs. 1000)
// and very large values (up to 2e7), naming each instance via ParamName.
INSTANTIATE_TEST_SUITE_P(
    TestSampleStatistics_SelectedPairs, BetaDistributionTest,
    ::testing::Values(std::make_pair(0.5, 1000), std::make_pair(1000, 0.5),
                      std::make_pair(900, 1000), std::make_pair(10000, 20000),
                      std::make_pair(4e5, 2e7), std::make_pair(1e7, 1e5)),
    ParamName);
451*9356374aSAndroid Build Coastguard Worker 
452*9356374aSAndroid Build Coastguard Worker // NOTE: absl::beta_distribution is not guaranteed to be stable.
// Golden-value test: feeds a fixed sequence of 64-bit values through
// absl::beta_distribution for several parameter choices and checks both the
// number of values consumed and a quantized form of each of the 20 outputs.
TEST(BetaDistributionTest, StabilityTest) {
  // absl::beta_distribution stability relies on the stability of
  // absl::random_internal::RandU64ToDouble, std::exp, std::log, std::pow,
  // and std::sqrt.
  //
  // This test also depends on the stability of std::frexp.
  using testing::ElementsAre;
  // Fixed input sequence; it is replayed from the start (urbg.reset()) before
  // each of the scenarios below.
  absl::random_internal::sequence_urbg urbg({
      0xffff00000000e6c8ull, 0xffff0000000006c8ull, 0x800003766295CFA9ull,
      0x11C819684E734A41ull, 0x832603766295CFA9ull, 0x7fbe76c8b4395800ull,
      0xB3472DCA7B14A94Aull, 0x0003eb76f6f7f755ull, 0xFFCEA50FDB2F953Bull,
      0x13CCA830EB61BD96ull, 0x0334FE1EAA0363CFull, 0x00035C904C70A239ull,
      0x00009E0BCBAADE14ull, 0x0000000000622CA7ull, 0x4864f22c059bf29eull,
      0x247856d8b862665cull, 0xe46e86e9a1337e10ull, 0xd8c8541f3519b133ull,
      0xffe75b52c567b9e4ull, 0xfffff732e5709c5bull, 0xff1f7f0b983532acull,
      0x1ec2e8986d2362caull, 0xC332DDEFBE6C5AA5ull, 0x6558218568AB9702ull,
      0x2AEF7DAD5B6E2F84ull, 0x1521B62829076170ull, 0xECDD4775619F1510ull,
      0x814c8e35fe9a961aull, 0x0c3cd59c9b638a02ull, 0xcb3bb6478a07715cull,
      0x1224e62c978bbc7full, 0x671ef2cb04e81f6eull, 0x3c1cbd811eaf1808ull,
      0x1bbc23cfa8fac721ull, 0xa4c2cda65e596a51ull, 0xb77216fad37adf91ull,
      0x836d794457c08849ull, 0xe083df03475f49d7ull, 0xbc9feb512e6b0d6cull,
      0xb12d74fdd718c8c5ull, 0x12ff09653bfbe4caull, 0x8dd03a105bc4ee7eull,
      0x5738341045ba0d85ull, 0xf3fd722dc65ad09eull, 0xfa14fd21ea2a5705ull,
      0xffe6ea4d6edb0c73ull, 0xD07E9EFE2BF11FB4ull, 0x95DBDA4DAE909198ull,
      0xEAAD8E716B93D5A0ull, 0xD08ED1D0AFC725E0ull, 0x8E3C5B2F8E7594B7ull,
      0x8FF6E2FBF2122B64ull, 0x8888B812900DF01Cull, 0x4FAD5EA0688FC31Cull,
      0xD1CFF191B3A8C1ADull, 0x2F2F2218BE0E1777ull, 0xEA752DFE8B021FA1ull,
  });

  // Convert the real-valued result into a uint64 where we compare
  // 5 (float) or 10 (double) decimal digits plus the base-2 exponent.
  // The truncated, scaled mantissa from std::frexp occupies the high decimal
  // digits; the absolute value of the exponent occupies the low four.
  auto float_to_u64 = [](float d) {
    int exp = 0;
    auto f = std::frexp(d, &exp);
    return (static_cast<uint64_t>(1e5 * f) * 10000) + std::abs(exp);
  };
  auto double_to_u64 = [](double d) {
    int exp = 0;
    auto f = std::frexp(d, &exp);
    return (static_cast<uint64_t>(1e10 * f) * 10000) + std::abs(exp);
  };

  // Reused for every scenario; each std::generate call overwrites all 20
  // slots with freshly drawn, quantized variates.
  std::vector<uint64_t> output(20);
  {
    // Algorithm Joehnk (float)
    absl::beta_distribution<float> dist(0.1f, 0.2f);
    std::generate(std::begin(output), std::end(output),
                  [&] { return float_to_u64(dist(urbg)); });
    // Number of values pulled from urbg to produce the 20 variates.
    EXPECT_EQ(44, urbg.invocations());
    EXPECT_THAT(output,  //
                testing::ElementsAre(
                    998340000, 619030004, 500000001, 999990000, 996280000,
                    500000001, 844740004, 847210001, 999970000, 872320000,
                    585480007, 933280000, 869080042, 647670031, 528240004,
                    969980004, 626050008, 915930002, 833440033, 878040015));
  }

  urbg.reset();
  {
    // Algorithm Joehnk (double)
    absl::beta_distribution<double> dist(0.1, 0.2);
    std::generate(std::begin(output), std::end(output),
                  [&] { return double_to_u64(dist(urbg)); });
    EXPECT_EQ(44, urbg.invocations());
    EXPECT_THAT(
        output,  //
        testing::ElementsAre(
            99834713000000, 61903356870004, 50000000000001, 99999721170000,
            99628374770000, 99999999990000, 84474397860004, 84721276240001,
            99997407490000, 87232528120000, 58548364780007, 93328932910000,
            86908237770042, 64767917930031, 52824581970004, 96998544140004,
            62605946270008, 91593604380002, 83345031740033, 87804397230015));
  }

  urbg.reset();
  {
    // Algorithm Cheng 1
    absl::beta_distribution<double> dist(0.9, 2.0);
    std::generate(std::begin(output), std::end(output),
                  [&] { return double_to_u64(dist(urbg)); });
    EXPECT_EQ(62, urbg.invocations());
    EXPECT_THAT(
        output,  //
        testing::ElementsAre(
            62069004780001, 64433204450001, 53607416560000, 89644295430008,
            61434586310019, 55172615890002, 62187161490000, 56433684810003,
            80454622050005, 86418558710003, 92920514700001, 64645184680001,
            58549183380000, 84881283650005, 71078728590002, 69949694970000,
            73157461710001, 68592191300001, 70747623900000, 78584696930005));
  }

  urbg.reset();
  {
    // Algorithm Cheng 2
    absl::beta_distribution<double> dist(1.5, 2.5);
    std::generate(std::begin(output), std::end(output),
                  [&] { return double_to_u64(dist(urbg)); });
    EXPECT_EQ(54, urbg.invocations());
    EXPECT_THAT(
        output,  //
        testing::ElementsAre(
            75000029250001, 76751482860001, 53264575220000, 69193133650005,
            78028324470013, 91573587560002, 59167523770000, 60658618560002,
            80075870540000, 94141320460004, 63196592770003, 78883906300002,
            96797992590001, 76907587800001, 56645167560000, 65408302280003,
            53401156320001, 64731238570000, 83065573750001, 79788333820001));
  }
}
561*9356374aSAndroid Build Coastguard Worker 
562*9356374aSAndroid Build Coastguard Worker // This is an implementation-specific test. If any part of the implementation
563*9356374aSAndroid Build Coastguard Worker // changes, then it is likely that this test will change as well.  Also, if
564*9356374aSAndroid Build Coastguard Worker // dependencies of the distribution change, such as RandU64ToDouble, then this
565*9356374aSAndroid Build Coastguard Worker // is also likely to change.
TEST(BetaDistributionTest,AlgorithmBounds)566*9356374aSAndroid Build Coastguard Worker TEST(BetaDistributionTest, AlgorithmBounds) {
567*9356374aSAndroid Build Coastguard Worker #if (defined(__i386__) || defined(_M_IX86)) && FLT_EVAL_METHOD != 0
568*9356374aSAndroid Build Coastguard Worker   // We're using an x87-compatible FPU, and intermediate operations are
569*9356374aSAndroid Build Coastguard Worker   // performed with 80-bit floats. This produces slightly different results from
570*9356374aSAndroid Build Coastguard Worker   // what we expect below.
571*9356374aSAndroid Build Coastguard Worker   GTEST_SKIP()
572*9356374aSAndroid Build Coastguard Worker       << "Skipping the test because we detected x87 floating-point semantics";
573*9356374aSAndroid Build Coastguard Worker #endif
574*9356374aSAndroid Build Coastguard Worker 
575*9356374aSAndroid Build Coastguard Worker   {
576*9356374aSAndroid Build Coastguard Worker     absl::random_internal::sequence_urbg urbg(
577*9356374aSAndroid Build Coastguard Worker         {0x7fbe76c8b4395800ull, 0x8000000000000000ull});
578*9356374aSAndroid Build Coastguard Worker     // u=0.499, v=0.5
579*9356374aSAndroid Build Coastguard Worker     absl::beta_distribution<double> dist(1e-4, 1e-4);
580*9356374aSAndroid Build Coastguard Worker     double a = dist(urbg);
581*9356374aSAndroid Build Coastguard Worker     EXPECT_EQ(a, 2.0202860861567108529e-09);
582*9356374aSAndroid Build Coastguard Worker     EXPECT_EQ(2, urbg.invocations());
583*9356374aSAndroid Build Coastguard Worker   }
584*9356374aSAndroid Build Coastguard Worker 
585*9356374aSAndroid Build Coastguard Worker   // Test that both the float & double algorithms appropriately reject the
586*9356374aSAndroid Build Coastguard Worker   // initial draw.
587*9356374aSAndroid Build Coastguard Worker   {
588*9356374aSAndroid Build Coastguard Worker     // 1/alpha = 1/beta = 2.
589*9356374aSAndroid Build Coastguard Worker     absl::beta_distribution<float> dist(0.5, 0.5);
590*9356374aSAndroid Build Coastguard Worker 
591*9356374aSAndroid Build Coastguard Worker     // first two outputs are close to 1.0 - epsilon,
592*9356374aSAndroid Build Coastguard Worker     // thus:  (u ^ 2 + v ^ 2) > 1.0
593*9356374aSAndroid Build Coastguard Worker     absl::random_internal::sequence_urbg urbg(
594*9356374aSAndroid Build Coastguard Worker         {0xffff00000006e6c8ull, 0xffff00000007c7c8ull, 0x800003766295CFA9ull,
595*9356374aSAndroid Build Coastguard Worker          0x11C819684E734A41ull});
596*9356374aSAndroid Build Coastguard Worker     {
597*9356374aSAndroid Build Coastguard Worker       double y = absl::beta_distribution<double>(0.5, 0.5)(urbg);
598*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(4, urbg.invocations());
599*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(y, 0.9810668952633862) << y;
600*9356374aSAndroid Build Coastguard Worker     }
601*9356374aSAndroid Build Coastguard Worker 
602*9356374aSAndroid Build Coastguard Worker     // ...and:  log(u) * a ~= log(v) * b ~= -0.02
603*9356374aSAndroid Build Coastguard Worker     // thus z ~= -0.02 + log(1 + e(~0))
604*9356374aSAndroid Build Coastguard Worker     //        ~= -0.02 + 0.69
605*9356374aSAndroid Build Coastguard Worker     // thus z > 0
606*9356374aSAndroid Build Coastguard Worker     urbg.reset();
607*9356374aSAndroid Build Coastguard Worker     {
608*9356374aSAndroid Build Coastguard Worker       float x = absl::beta_distribution<float>(0.5, 0.5)(urbg);
609*9356374aSAndroid Build Coastguard Worker       EXPECT_EQ(4, urbg.invocations());
610*9356374aSAndroid Build Coastguard Worker       EXPECT_NEAR(0.98106688261032104, x, 0.0000005) << x << "f";
611*9356374aSAndroid Build Coastguard Worker     }
612*9356374aSAndroid Build Coastguard Worker   }
613*9356374aSAndroid Build Coastguard Worker }
614*9356374aSAndroid Build Coastguard Worker 
615*9356374aSAndroid Build Coastguard Worker }  // namespace
616