xref: /aosp_15_r20/external/abseil-cpp/absl/random/poisson_distribution.h (revision 9356374a3709195abf420251b3e825997ff56c0f)
1*9356374aSAndroid Build Coastguard Worker // Copyright 2017 The Abseil Authors.
2*9356374aSAndroid Build Coastguard Worker //
3*9356374aSAndroid Build Coastguard Worker // Licensed under the Apache License, Version 2.0 (the "License");
4*9356374aSAndroid Build Coastguard Worker // you may not use this file except in compliance with the License.
5*9356374aSAndroid Build Coastguard Worker // You may obtain a copy of the License at
6*9356374aSAndroid Build Coastguard Worker //
7*9356374aSAndroid Build Coastguard Worker //      https://www.apache.org/licenses/LICENSE-2.0
8*9356374aSAndroid Build Coastguard Worker //
9*9356374aSAndroid Build Coastguard Worker // Unless required by applicable law or agreed to in writing, software
10*9356374aSAndroid Build Coastguard Worker // distributed under the License is distributed on an "AS IS" BASIS,
11*9356374aSAndroid Build Coastguard Worker // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*9356374aSAndroid Build Coastguard Worker // See the License for the specific language governing permissions and
13*9356374aSAndroid Build Coastguard Worker // limitations under the License.
14*9356374aSAndroid Build Coastguard Worker 
15*9356374aSAndroid Build Coastguard Worker #ifndef ABSL_RANDOM_POISSON_DISTRIBUTION_H_
16*9356374aSAndroid Build Coastguard Worker #define ABSL_RANDOM_POISSON_DISTRIBUTION_H_
17*9356374aSAndroid Build Coastguard Worker 
18*9356374aSAndroid Build Coastguard Worker #include <cassert>
19*9356374aSAndroid Build Coastguard Worker #include <cmath>
20*9356374aSAndroid Build Coastguard Worker #include <istream>
21*9356374aSAndroid Build Coastguard Worker #include <limits>
22*9356374aSAndroid Build Coastguard Worker #include <ostream>
23*9356374aSAndroid Build Coastguard Worker #include <type_traits>
24*9356374aSAndroid Build Coastguard Worker 
25*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/fast_uniform_bits.h"
26*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/fastmath.h"
27*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/generate_real.h"
28*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/iostream_state_saver.h"
29*9356374aSAndroid Build Coastguard Worker #include "absl/random/internal/traits.h"
30*9356374aSAndroid Build Coastguard Worker 
31*9356374aSAndroid Build Coastguard Worker namespace absl {
32*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_BEGIN
33*9356374aSAndroid Build Coastguard Worker 
34*9356374aSAndroid Build Coastguard Worker // absl::poisson_distribution:
35*9356374aSAndroid Build Coastguard Worker // Generates discrete variates conforming to a Poisson distribution.
36*9356374aSAndroid Build Coastguard Worker //   p(n) = (mean^n / n!) exp(-mean)
37*9356374aSAndroid Build Coastguard Worker //
38*9356374aSAndroid Build Coastguard Worker // Depending on the parameter, the distribution selects one of the following
39*9356374aSAndroid Build Coastguard Worker // algorithms:
40*9356374aSAndroid Build Coastguard Worker // * The standard algorithm, attributed to Knuth, extended using a split method
41*9356374aSAndroid Build Coastguard Worker // for larger values
42*9356374aSAndroid Build Coastguard Worker // * The "Ratio of Uniforms as a convenient method for sampling from classical
43*9356374aSAndroid Build Coastguard Worker // discrete distributions", Stadlober, 1989.
44*9356374aSAndroid Build Coastguard Worker // http://www.sciencedirect.com/science/article/pii/0377042790903495
45*9356374aSAndroid Build Coastguard Worker //
46*9356374aSAndroid Build Coastguard Worker // NOTE: param_type.mean() is a double, which permits values larger than
47*9356374aSAndroid Build Coastguard Worker // poisson_distribution<IntType>::max(), however this should be avoided and
48*9356374aSAndroid Build Coastguard Worker // the distribution results are limited to the max() value.
49*9356374aSAndroid Build Coastguard Worker //
50*9356374aSAndroid Build Coastguard Worker // The goals of this implementation are to provide good performance while still
51*9356374aSAndroid Build Coastguard Worker // beig thread-safe: This limits the implementation to not using lgamma provided
52*9356374aSAndroid Build Coastguard Worker // by <math.h>.
53*9356374aSAndroid Build Coastguard Worker //
54*9356374aSAndroid Build Coastguard Worker template <typename IntType = int>
55*9356374aSAndroid Build Coastguard Worker class poisson_distribution {
56*9356374aSAndroid Build Coastguard Worker  public:
57*9356374aSAndroid Build Coastguard Worker   using result_type = IntType;
58*9356374aSAndroid Build Coastguard Worker 
59*9356374aSAndroid Build Coastguard Worker   class param_type {
60*9356374aSAndroid Build Coastguard Worker    public:
61*9356374aSAndroid Build Coastguard Worker     using distribution_type = poisson_distribution;
62*9356374aSAndroid Build Coastguard Worker     explicit param_type(double mean = 1.0);
63*9356374aSAndroid Build Coastguard Worker 
mean()64*9356374aSAndroid Build Coastguard Worker     double mean() const { return mean_; }
65*9356374aSAndroid Build Coastguard Worker 
66*9356374aSAndroid Build Coastguard Worker     friend bool operator==(const param_type& a, const param_type& b) {
67*9356374aSAndroid Build Coastguard Worker       return a.mean_ == b.mean_;
68*9356374aSAndroid Build Coastguard Worker     }
69*9356374aSAndroid Build Coastguard Worker 
70*9356374aSAndroid Build Coastguard Worker     friend bool operator!=(const param_type& a, const param_type& b) {
71*9356374aSAndroid Build Coastguard Worker       return !(a == b);
72*9356374aSAndroid Build Coastguard Worker     }
73*9356374aSAndroid Build Coastguard Worker 
74*9356374aSAndroid Build Coastguard Worker    private:
75*9356374aSAndroid Build Coastguard Worker     friend class poisson_distribution;
76*9356374aSAndroid Build Coastguard Worker 
77*9356374aSAndroid Build Coastguard Worker     double mean_;
78*9356374aSAndroid Build Coastguard Worker     double emu_;  // e ^ -mean_
79*9356374aSAndroid Build Coastguard Worker     double lmu_;  // ln(mean_)
80*9356374aSAndroid Build Coastguard Worker     double s_;
81*9356374aSAndroid Build Coastguard Worker     double log_k_;
82*9356374aSAndroid Build Coastguard Worker     int split_;
83*9356374aSAndroid Build Coastguard Worker 
84*9356374aSAndroid Build Coastguard Worker     static_assert(random_internal::IsIntegral<IntType>::value,
85*9356374aSAndroid Build Coastguard Worker                   "Class-template absl::poisson_distribution<> must be "
86*9356374aSAndroid Build Coastguard Worker                   "parameterized using an integral type.");
87*9356374aSAndroid Build Coastguard Worker   };
88*9356374aSAndroid Build Coastguard Worker 
poisson_distribution()89*9356374aSAndroid Build Coastguard Worker   poisson_distribution() : poisson_distribution(1.0) {}
90*9356374aSAndroid Build Coastguard Worker 
poisson_distribution(double mean)91*9356374aSAndroid Build Coastguard Worker   explicit poisson_distribution(double mean) : param_(mean) {}
92*9356374aSAndroid Build Coastguard Worker 
poisson_distribution(const param_type & p)93*9356374aSAndroid Build Coastguard Worker   explicit poisson_distribution(const param_type& p) : param_(p) {}
94*9356374aSAndroid Build Coastguard Worker 
reset()95*9356374aSAndroid Build Coastguard Worker   void reset() {}
96*9356374aSAndroid Build Coastguard Worker 
97*9356374aSAndroid Build Coastguard Worker   // generating functions
98*9356374aSAndroid Build Coastguard Worker   template <typename URBG>
operator()99*9356374aSAndroid Build Coastguard Worker   result_type operator()(URBG& g) {  // NOLINT(runtime/references)
100*9356374aSAndroid Build Coastguard Worker     return (*this)(g, param_);
101*9356374aSAndroid Build Coastguard Worker   }
102*9356374aSAndroid Build Coastguard Worker 
103*9356374aSAndroid Build Coastguard Worker   template <typename URBG>
104*9356374aSAndroid Build Coastguard Worker   result_type operator()(URBG& g,  // NOLINT(runtime/references)
105*9356374aSAndroid Build Coastguard Worker                          const param_type& p);
106*9356374aSAndroid Build Coastguard Worker 
param()107*9356374aSAndroid Build Coastguard Worker   param_type param() const { return param_; }
param(const param_type & p)108*9356374aSAndroid Build Coastguard Worker   void param(const param_type& p) { param_ = p; }
109*9356374aSAndroid Build Coastguard Worker 
result_type(min)110*9356374aSAndroid Build Coastguard Worker   result_type(min)() const { return 0; }
result_type(max)111*9356374aSAndroid Build Coastguard Worker   result_type(max)() const { return (std::numeric_limits<result_type>::max)(); }
112*9356374aSAndroid Build Coastguard Worker 
mean()113*9356374aSAndroid Build Coastguard Worker   double mean() const { return param_.mean(); }
114*9356374aSAndroid Build Coastguard Worker 
115*9356374aSAndroid Build Coastguard Worker   friend bool operator==(const poisson_distribution& a,
116*9356374aSAndroid Build Coastguard Worker                          const poisson_distribution& b) {
117*9356374aSAndroid Build Coastguard Worker     return a.param_ == b.param_;
118*9356374aSAndroid Build Coastguard Worker   }
119*9356374aSAndroid Build Coastguard Worker   friend bool operator!=(const poisson_distribution& a,
120*9356374aSAndroid Build Coastguard Worker                          const poisson_distribution& b) {
121*9356374aSAndroid Build Coastguard Worker     return a.param_ != b.param_;
122*9356374aSAndroid Build Coastguard Worker   }
123*9356374aSAndroid Build Coastguard Worker 
124*9356374aSAndroid Build Coastguard Worker  private:
125*9356374aSAndroid Build Coastguard Worker   param_type param_;
126*9356374aSAndroid Build Coastguard Worker   random_internal::FastUniformBits<uint64_t> fast_u64_;
127*9356374aSAndroid Build Coastguard Worker };
128*9356374aSAndroid Build Coastguard Worker 
129*9356374aSAndroid Build Coastguard Worker // -----------------------------------------------------------------------------
130*9356374aSAndroid Build Coastguard Worker // Implementation details follow
131*9356374aSAndroid Build Coastguard Worker // -----------------------------------------------------------------------------
132*9356374aSAndroid Build Coastguard Worker 
133*9356374aSAndroid Build Coastguard Worker template <typename IntType>
param_type(double mean)134*9356374aSAndroid Build Coastguard Worker poisson_distribution<IntType>::param_type::param_type(double mean)
135*9356374aSAndroid Build Coastguard Worker     : mean_(mean), split_(0) {
136*9356374aSAndroid Build Coastguard Worker   assert(mean >= 0);
137*9356374aSAndroid Build Coastguard Worker   assert(mean <=
138*9356374aSAndroid Build Coastguard Worker          static_cast<double>((std::numeric_limits<result_type>::max)()));
139*9356374aSAndroid Build Coastguard Worker   // As a defensive measure, avoid large values of the mean.  The rejection
140*9356374aSAndroid Build Coastguard Worker   // algorithm used does not support very large values well.  It my be worth
141*9356374aSAndroid Build Coastguard Worker   // changing algorithms to better deal with these cases.
142*9356374aSAndroid Build Coastguard Worker   assert(mean <= 1e10);
143*9356374aSAndroid Build Coastguard Worker   if (mean_ < 10) {
144*9356374aSAndroid Build Coastguard Worker     // For small lambda, use the knuth method.
145*9356374aSAndroid Build Coastguard Worker     split_ = 1;
146*9356374aSAndroid Build Coastguard Worker     emu_ = std::exp(-mean_);
147*9356374aSAndroid Build Coastguard Worker   } else if (mean_ <= 50) {
148*9356374aSAndroid Build Coastguard Worker     // Use split-knuth method.
149*9356374aSAndroid Build Coastguard Worker     split_ = 1 + static_cast<int>(mean_ / 10.0);
150*9356374aSAndroid Build Coastguard Worker     emu_ = std::exp(-mean_ / static_cast<double>(split_));
151*9356374aSAndroid Build Coastguard Worker   } else {
152*9356374aSAndroid Build Coastguard Worker     // Use ratio of uniforms method.
153*9356374aSAndroid Build Coastguard Worker     constexpr double k2E = 0.7357588823428846;
154*9356374aSAndroid Build Coastguard Worker     constexpr double kSA = 0.4494580810294493;
155*9356374aSAndroid Build Coastguard Worker 
156*9356374aSAndroid Build Coastguard Worker     lmu_ = std::log(mean_);
157*9356374aSAndroid Build Coastguard Worker     double a = mean_ + 0.5;
158*9356374aSAndroid Build Coastguard Worker     s_ = kSA + std::sqrt(k2E * a);
159*9356374aSAndroid Build Coastguard Worker     const double mode = std::ceil(mean_) - 1;
160*9356374aSAndroid Build Coastguard Worker     log_k_ = lmu_ * mode - absl::random_internal::StirlingLogFactorial(mode);
161*9356374aSAndroid Build Coastguard Worker   }
162*9356374aSAndroid Build Coastguard Worker }
163*9356374aSAndroid Build Coastguard Worker 
164*9356374aSAndroid Build Coastguard Worker template <typename IntType>
165*9356374aSAndroid Build Coastguard Worker template <typename URBG>
166*9356374aSAndroid Build Coastguard Worker typename poisson_distribution<IntType>::result_type
operator()167*9356374aSAndroid Build Coastguard Worker poisson_distribution<IntType>::operator()(
168*9356374aSAndroid Build Coastguard Worker     URBG& g,  // NOLINT(runtime/references)
169*9356374aSAndroid Build Coastguard Worker     const param_type& p) {
170*9356374aSAndroid Build Coastguard Worker   using random_internal::GeneratePositiveTag;
171*9356374aSAndroid Build Coastguard Worker   using random_internal::GenerateRealFromBits;
172*9356374aSAndroid Build Coastguard Worker   using random_internal::GenerateSignedTag;
173*9356374aSAndroid Build Coastguard Worker 
174*9356374aSAndroid Build Coastguard Worker   if (p.split_ != 0) {
175*9356374aSAndroid Build Coastguard Worker     // Use Knuth's algorithm with range splitting to avoid floating-point
176*9356374aSAndroid Build Coastguard Worker     // errors. Knuth's algorithm is: Ui is a sequence of uniform variates on
177*9356374aSAndroid Build Coastguard Worker     // (0,1); return the number of variates required for product(Ui) <
178*9356374aSAndroid Build Coastguard Worker     // exp(-lambda).
179*9356374aSAndroid Build Coastguard Worker     //
180*9356374aSAndroid Build Coastguard Worker     // The expected number of variates required for Knuth's method can be
181*9356374aSAndroid Build Coastguard Worker     // computed as follows:
182*9356374aSAndroid Build Coastguard Worker     // The expected value of U is 0.5, so solving for 0.5^n < exp(-lambda) gives
183*9356374aSAndroid Build Coastguard Worker     // the expected number of uniform variates
184*9356374aSAndroid Build Coastguard Worker     // required for a given lambda, which is:
185*9356374aSAndroid Build Coastguard Worker     //  lambda = [2, 5,  9, 10, 11, 12, 13, 14, 15, 16, 17]
186*9356374aSAndroid Build Coastguard Worker     //  n      = [3, 8, 13, 15, 16, 18, 19, 21, 22, 24, 25]
187*9356374aSAndroid Build Coastguard Worker     //
188*9356374aSAndroid Build Coastguard Worker     result_type n = 0;
189*9356374aSAndroid Build Coastguard Worker     for (int split = p.split_; split > 0; --split) {
190*9356374aSAndroid Build Coastguard Worker       double r = 1.0;
191*9356374aSAndroid Build Coastguard Worker       do {
192*9356374aSAndroid Build Coastguard Worker         r *= GenerateRealFromBits<double, GeneratePositiveTag, true>(
193*9356374aSAndroid Build Coastguard Worker             fast_u64_(g));  // U(-1, 0)
194*9356374aSAndroid Build Coastguard Worker         ++n;
195*9356374aSAndroid Build Coastguard Worker       } while (r > p.emu_);
196*9356374aSAndroid Build Coastguard Worker       --n;
197*9356374aSAndroid Build Coastguard Worker     }
198*9356374aSAndroid Build Coastguard Worker     return n;
199*9356374aSAndroid Build Coastguard Worker   }
200*9356374aSAndroid Build Coastguard Worker 
201*9356374aSAndroid Build Coastguard Worker   // Use ratio of uniforms method.
202*9356374aSAndroid Build Coastguard Worker   //
203*9356374aSAndroid Build Coastguard Worker   // Let u ~ Uniform(0, 1), v ~ Uniform(-1, 1),
204*9356374aSAndroid Build Coastguard Worker   //     a = lambda + 1/2,
205*9356374aSAndroid Build Coastguard Worker   //     s = 1.5 - sqrt(3/e) + sqrt(2(lambda + 1/2)/e),
206*9356374aSAndroid Build Coastguard Worker   //     x = s * v/u + a.
207*9356374aSAndroid Build Coastguard Worker   // P(floor(x) = k | u^2 < f(floor(x))/k), where
208*9356374aSAndroid Build Coastguard Worker   // f(m) = lambda^m exp(-lambda)/ m!, for 0 <= m, and f(m) = 0 otherwise,
209*9356374aSAndroid Build Coastguard Worker   // and k = max(f).
210*9356374aSAndroid Build Coastguard Worker   const double a = p.mean_ + 0.5;
211*9356374aSAndroid Build Coastguard Worker   for (;;) {
212*9356374aSAndroid Build Coastguard Worker     const double u = GenerateRealFromBits<double, GeneratePositiveTag, false>(
213*9356374aSAndroid Build Coastguard Worker         fast_u64_(g));  // U(0, 1)
214*9356374aSAndroid Build Coastguard Worker     const double v = GenerateRealFromBits<double, GenerateSignedTag, false>(
215*9356374aSAndroid Build Coastguard Worker         fast_u64_(g));  // U(-1, 1)
216*9356374aSAndroid Build Coastguard Worker 
217*9356374aSAndroid Build Coastguard Worker     const double x = std::floor(p.s_ * v / u + a);
218*9356374aSAndroid Build Coastguard Worker     if (x < 0) continue;  // f(negative) = 0
219*9356374aSAndroid Build Coastguard Worker     const double rhs = x * p.lmu_;
220*9356374aSAndroid Build Coastguard Worker     // clang-format off
221*9356374aSAndroid Build Coastguard Worker     double s = (x <= 1.0) ? 0.0
222*9356374aSAndroid Build Coastguard Worker              : (x == 2.0) ? 0.693147180559945
223*9356374aSAndroid Build Coastguard Worker              : absl::random_internal::StirlingLogFactorial(x);
224*9356374aSAndroid Build Coastguard Worker     // clang-format on
225*9356374aSAndroid Build Coastguard Worker     const double lhs = 2.0 * std::log(u) + p.log_k_ + s;
226*9356374aSAndroid Build Coastguard Worker     if (lhs < rhs) {
227*9356374aSAndroid Build Coastguard Worker       return x > static_cast<double>((max)())
228*9356374aSAndroid Build Coastguard Worker                  ? (max)()
229*9356374aSAndroid Build Coastguard Worker                  : static_cast<result_type>(x);  // f(x)/k >= u^2
230*9356374aSAndroid Build Coastguard Worker     }
231*9356374aSAndroid Build Coastguard Worker   }
232*9356374aSAndroid Build Coastguard Worker }
233*9356374aSAndroid Build Coastguard Worker 
234*9356374aSAndroid Build Coastguard Worker template <typename CharT, typename Traits, typename IntType>
235*9356374aSAndroid Build Coastguard Worker std::basic_ostream<CharT, Traits>& operator<<(
236*9356374aSAndroid Build Coastguard Worker     std::basic_ostream<CharT, Traits>& os,  // NOLINT(runtime/references)
237*9356374aSAndroid Build Coastguard Worker     const poisson_distribution<IntType>& x) {
238*9356374aSAndroid Build Coastguard Worker   auto saver = random_internal::make_ostream_state_saver(os);
239*9356374aSAndroid Build Coastguard Worker   os.precision(random_internal::stream_precision_helper<double>::kPrecision);
240*9356374aSAndroid Build Coastguard Worker   os << x.mean();
241*9356374aSAndroid Build Coastguard Worker   return os;
242*9356374aSAndroid Build Coastguard Worker }
243*9356374aSAndroid Build Coastguard Worker 
244*9356374aSAndroid Build Coastguard Worker template <typename CharT, typename Traits, typename IntType>
245*9356374aSAndroid Build Coastguard Worker std::basic_istream<CharT, Traits>& operator>>(
246*9356374aSAndroid Build Coastguard Worker     std::basic_istream<CharT, Traits>& is,  // NOLINT(runtime/references)
247*9356374aSAndroid Build Coastguard Worker     poisson_distribution<IntType>& x) {     // NOLINT(runtime/references)
248*9356374aSAndroid Build Coastguard Worker   using param_type = typename poisson_distribution<IntType>::param_type;
249*9356374aSAndroid Build Coastguard Worker 
250*9356374aSAndroid Build Coastguard Worker   auto saver = random_internal::make_istream_state_saver(is);
251*9356374aSAndroid Build Coastguard Worker   double mean = random_internal::read_floating_point<double>(is);
252*9356374aSAndroid Build Coastguard Worker   if (!is.fail()) {
253*9356374aSAndroid Build Coastguard Worker     x.param(param_type(mean));
254*9356374aSAndroid Build Coastguard Worker   }
255*9356374aSAndroid Build Coastguard Worker   return is;
256*9356374aSAndroid Build Coastguard Worker }
257*9356374aSAndroid Build Coastguard Worker 
258*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_END
259*9356374aSAndroid Build Coastguard Worker }  // namespace absl
260*9356374aSAndroid Build Coastguard Worker 
261*9356374aSAndroid Build Coastguard Worker #endif  // ABSL_RANDOM_POISSON_DISTRIBUTION_H_
262