xref: /aosp_15_r20/external/abseil-cpp/absl/random/discrete_distribution.cc (revision 9356374a3709195abf420251b3e825997ff56c0f)
// Copyright 2017 The Abseil Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14*9356374aSAndroid Build Coastguard Worker 
15*9356374aSAndroid Build Coastguard Worker #include "absl/random/discrete_distribution.h"
16*9356374aSAndroid Build Coastguard Worker 
17*9356374aSAndroid Build Coastguard Worker namespace absl {
18*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_BEGIN
19*9356374aSAndroid Build Coastguard Worker namespace random_internal {
20*9356374aSAndroid Build Coastguard Worker 
// Initializes the distribution table for Walker's Aliasing algorithm, described
// in Knuth, Vol 2. as well as in https://en.wikipedia.org/wiki/Alias_method
//
// `probabilities` holds one weight per outcome and must be non-null and
// non-empty. The weights must be non-negative with a positive sum, matching
// the requirements of std::discrete_distribution; this is checked with
// assert() in debug builds only. If the weights do not already sum to 1.0
// (within 1e-6), they are rescaled IN PLACE so that they do.
//
// Returns a table `q` with one entry per outcome `i`. `q[i].first` is the
// share of bucket `i`, in units of 1/n, that is kept by outcome `i` itself.
// `q[i].second` is the alias outcome that owns the remainder of that bucket.
// Entries left unpaired get {1.0, i}, meaning the bucket is owned entirely by
// `i`.
std::vector<std::pair<double, size_t>> InitDiscreteDistribution(
    std::vector<double>* probabilities) {
  // The empty-case should already be handled by the constructor.
  assert(probabilities);
  assert(!probabilities->empty());

  // Step 1. Normalize the input probabilities to 1.0.
  //
  // Validate the raw weights *before* normalizing. Checking only the
  // normalized values would let an all-negative weight vector through,
  // because dividing by a negative sum flips every sign. A zero sum would
  // silently turn every weight into NaN. The summation is the same
  // left-to-right fold that std::accumulate performs, so `sum` is unchanged.
  double sum = 0.0;
  for (const double weight : *probabilities) {
    assert(weight >= 0);
    sum += weight;
  }
  assert(sum > 0);
  if (std::fabs(sum - 1.0) > 1e-6) {
    // Scale `probabilities` only when the sum is too far from 1.0.  Scaling
    // unconditionally will alter the probabilities slightly.
    for (double& item : *probabilities) {
      item = item / sum;
    }
  }

  // Step 2. At this point `probabilities` is set to the conditional
  // probabilities of each element which sum to 1.0, to within reasonable error.
  // These values are used to construct the proportional probability tables for
  // the selection phases of Walker's Aliasing algorithm.
  //
  // To construct the table, pick an element which is under-full (i.e., an
  // element for which `(*probabilities)[i] < 1.0/n`), and pair it with an
  // element which is over-full (i.e., an element for which
  // `(*probabilities)[i] > 1.0/n`). The smaller value can always be retired.
  // The larger may still be greater than 1.0/n, or may now be less than 1.0/n,
  // and put back onto the appropriate collection.
  const size_t n = probabilities->size();
  std::vector<std::pair<double, size_t>> q;
  q.reserve(n);

  // Indices of entries whose scaled mass (probability * n) is >= 1.0 ("over")
  // or < 1.0 ("under").
  std::vector<size_t> over;
  std::vector<size_t> under;
  for (size_t i = 0; i < n; ++i) {
    const double item = (*probabilities)[i];
    // Post-normalization sanity check: e.g. an infinite weight normalizes to
    // NaN here.
    assert(item >= 0);
    const double v = item * n;
    q.emplace_back(v, 0);
    if (v < 1.0) {
      under.push_back(i);
    } else {
      over.push_back(i);
    }
  }

  while (!over.empty() && !under.empty()) {
    const size_t lo = under.back();
    under.pop_back();
    const size_t hi = over.back();
    over.pop_back();

    // Fill the rest of bucket `lo` with outcome `hi`, and charge that amount
    // against `hi`'s remaining mass.
    q[lo].second = hi;
    const double r = q[hi].first - (1.0 - q[lo].first);
    q[hi].first = r;
    if (r < 1.0) {
      under.push_back(hi);
    } else {
      over.push_back(hi);
    }
  }

  // Due to rounding errors, there may be un-paired elements in either
  // collection; these should all be values near 1.0.  For these values, set `q`
  // to 1.0 and set the alternate to the identity.
  for (const size_t i : over) {
    q[i] = {1.0, i};
  }
  for (const size_t i : under) {
    q[i] = {1.0, i};
  }
  return q;
}
95*9356374aSAndroid Build Coastguard Worker 
96*9356374aSAndroid Build Coastguard Worker }  // namespace random_internal
97*9356374aSAndroid Build Coastguard Worker ABSL_NAMESPACE_END
98*9356374aSAndroid Build Coastguard Worker }  // namespace absl
99