xref: /aosp_15_r20/external/abseil-cpp/absl/random/log_uniform_int_distribution.h (revision 9356374a3709195abf420251b3e825997ff56c0f)
1 // Copyright 2017 The Abseil Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
16 #define ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
17 
18 #include <algorithm>
19 #include <cassert>
20 #include <cmath>
21 #include <istream>
22 #include <limits>
23 #include <ostream>
24 #include <type_traits>
25 
26 #include "absl/numeric/bits.h"
27 #include "absl/random/internal/fastmath.h"
28 #include "absl/random/internal/generate_real.h"
29 #include "absl/random/internal/iostream_state_saver.h"
30 #include "absl/random/internal/traits.h"
31 #include "absl/random/uniform_int_distribution.h"
32 
33 namespace absl {
34 ABSL_NAMESPACE_BEGIN
35 
36 // log_uniform_int_distribution:
37 //
38 // Returns a random variate R in range [min, max] such that
39 // floor(log(R-min, base)) is uniformly distributed.
40 // We ensure uniformity by discretization using the
41 // boundary sets [0, 1, base, base * base, ... min(base*n, max)]
42 //
43 template <typename IntType = int>
44 class log_uniform_int_distribution {
45  private:
46   using unsigned_type =
47       typename random_internal::make_unsigned_bits<IntType>::type;
48 
49  public:
50   using result_type = IntType;
51 
52   class param_type {
53    public:
54     using distribution_type = log_uniform_int_distribution;
55 
56     explicit param_type(
57         result_type min = 0,
58         result_type max = (std::numeric_limits<result_type>::max)(),
59         result_type base = 2)
min_(min)60         : min_(min),
61           max_(max),
62           base_(base),
63           range_(static_cast<unsigned_type>(max_) -
64                  static_cast<unsigned_type>(min_)),
65           log_range_(0) {
66       assert(max_ >= min_);
67       assert(base_ > 1);
68 
69       if (base_ == 2) {
70         // Determine where the first set bit is on range(), giving a log2(range)
71         // value which can be used to construct bounds.
72         log_range_ = (std::min)(random_internal::BitWidth(range()),
73                                 std::numeric_limits<unsigned_type>::digits);
74       } else {
75         // NOTE: Computing the logN(x) introduces error from 2 sources:
76         // 1. Conversion of int to double loses precision for values >=
77         // 2^53, which may cause some log() computations to operate on
78         // different values.
79         // 2. The error introduced by the division will cause the result
80         // to differ from the expected value.
81         //
82         // Thus a result which should equal K may equal K +/- epsilon,
83         // which can eliminate some values depending on where the bounds fall.
84         const double inv_log_base = 1.0 / std::log(static_cast<double>(base_));
85         const double log_range = std::log(static_cast<double>(range()) + 0.5);
86         log_range_ = static_cast<int>(std::ceil(inv_log_base * log_range));
87       }
88     }
89 
result_type(min)90     result_type(min)() const { return min_; }
result_type(max)91     result_type(max)() const { return max_; }
base()92     result_type base() const { return base_; }
93 
94     friend bool operator==(const param_type& a, const param_type& b) {
95       return a.min_ == b.min_ && a.max_ == b.max_ && a.base_ == b.base_;
96     }
97 
98     friend bool operator!=(const param_type& a, const param_type& b) {
99       return !(a == b);
100     }
101 
102    private:
103     friend class log_uniform_int_distribution;
104 
log_range()105     int log_range() const { return log_range_; }
range()106     unsigned_type range() const { return range_; }
107 
108     result_type min_;
109     result_type max_;
110     result_type base_;
111     unsigned_type range_;  // max - min
112     int log_range_;        // ceil(logN(range_))
113 
114     static_assert(random_internal::IsIntegral<IntType>::value,
115                   "Class-template absl::log_uniform_int_distribution<> must be "
116                   "parameterized using an integral type.");
117   };
118 
log_uniform_int_distribution()119   log_uniform_int_distribution() : log_uniform_int_distribution(0) {}
120 
121   explicit log_uniform_int_distribution(
122       result_type min,
123       result_type max = (std::numeric_limits<result_type>::max)(),
124       result_type base = 2)
param_(min,max,base)125       : param_(min, max, base) {}
126 
log_uniform_int_distribution(const param_type & p)127   explicit log_uniform_int_distribution(const param_type& p) : param_(p) {}
128 
reset()129   void reset() {}
130 
131   // generating functions
132   template <typename URBG>
operator()133   result_type operator()(URBG& g) {  // NOLINT(runtime/references)
134     return (*this)(g, param_);
135   }
136 
137   template <typename URBG>
operator()138   result_type operator()(URBG& g,  // NOLINT(runtime/references)
139                          const param_type& p) {
140     return static_cast<result_type>((p.min)() + Generate(g, p));
141   }
142 
result_type(min)143   result_type(min)() const { return (param_.min)(); }
result_type(max)144   result_type(max)() const { return (param_.max)(); }
base()145   result_type base() const { return param_.base(); }
146 
param()147   param_type param() const { return param_; }
param(const param_type & p)148   void param(const param_type& p) { param_ = p; }
149 
150   friend bool operator==(const log_uniform_int_distribution& a,
151                          const log_uniform_int_distribution& b) {
152     return a.param_ == b.param_;
153   }
154   friend bool operator!=(const log_uniform_int_distribution& a,
155                          const log_uniform_int_distribution& b) {
156     return a.param_ != b.param_;
157   }
158 
159  private:
160   // Returns a log-uniform variate in the range [0, p.range()]. The caller
161   // should add min() to shift the result to the correct range.
162   template <typename URNG>
163   unsigned_type Generate(URNG& g,  // NOLINT(runtime/references)
164                          const param_type& p);
165 
166   param_type param_;
167 };
168 
169 template <typename IntType>
170 template <typename URBG>
171 typename log_uniform_int_distribution<IntType>::unsigned_type
Generate(URBG & g,const param_type & p)172 log_uniform_int_distribution<IntType>::Generate(
173     URBG& g,  // NOLINT(runtime/references)
174     const param_type& p) {
175   // sample e over [0, log_range]. Map the results of e to this:
176   // 0 => 0
177   // 1 => [1, b-1]
178   // 2 => [b, (b^2)-1]
179   // n => [b^(n-1)..(b^n)-1]
180   const int e = absl::uniform_int_distribution<int>(0, p.log_range())(g);
181   if (e == 0) {
182     return 0;
183   }
184   const int d = e - 1;
185 
186   unsigned_type base_e, top_e;
187   if (p.base() == 2) {
188     base_e = static_cast<unsigned_type>(1) << d;
189 
190     top_e = (e >= std::numeric_limits<unsigned_type>::digits)
191                 ? (std::numeric_limits<unsigned_type>::max)()
192                 : (static_cast<unsigned_type>(1) << e) - 1;
193   } else {
194     const double r = std::pow(static_cast<double>(p.base()), d);
195     const double s = (r * static_cast<double>(p.base())) - 1.0;
196 
197     base_e =
198         (r > static_cast<double>((std::numeric_limits<unsigned_type>::max)()))
199             ? (std::numeric_limits<unsigned_type>::max)()
200             : static_cast<unsigned_type>(r);
201 
202     top_e =
203         (s > static_cast<double>((std::numeric_limits<unsigned_type>::max)()))
204             ? (std::numeric_limits<unsigned_type>::max)()
205             : static_cast<unsigned_type>(s);
206   }
207 
208   const unsigned_type lo = (base_e >= p.range()) ? p.range() : base_e;
209   const unsigned_type hi = (top_e >= p.range()) ? p.range() : top_e;
210 
211   // choose uniformly over [lo, hi]
212   return absl::uniform_int_distribution<result_type>(
213       static_cast<result_type>(lo), static_cast<result_type>(hi))(g);
214 }
215 
216 template <typename CharT, typename Traits, typename IntType>
217 std::basic_ostream<CharT, Traits>& operator<<(
218     std::basic_ostream<CharT, Traits>& os,  // NOLINT(runtime/references)
219     const log_uniform_int_distribution<IntType>& x) {
220   using stream_type =
221       typename random_internal::stream_format_type<IntType>::type;
222   auto saver = random_internal::make_ostream_state_saver(os);
223   os << static_cast<stream_type>((x.min)()) << os.fill()
224      << static_cast<stream_type>((x.max)()) << os.fill()
225      << static_cast<stream_type>(x.base());
226   return os;
227 }
228 
229 template <typename CharT, typename Traits, typename IntType>
230 std::basic_istream<CharT, Traits>& operator>>(
231     std::basic_istream<CharT, Traits>& is,       // NOLINT(runtime/references)
232     log_uniform_int_distribution<IntType>& x) {  // NOLINT(runtime/references)
233   using param_type = typename log_uniform_int_distribution<IntType>::param_type;
234   using result_type =
235       typename log_uniform_int_distribution<IntType>::result_type;
236   using stream_type =
237       typename random_internal::stream_format_type<IntType>::type;
238 
239   stream_type min;
240   stream_type max;
241   stream_type base;
242 
243   auto saver = random_internal::make_istream_state_saver(is);
244   is >> min >> max >> base;
245   if (!is.fail()) {
246     x.param(param_type(static_cast<result_type>(min),
247                        static_cast<result_type>(max),
248                        static_cast<result_type>(base)));
249   }
250   return is;
251 }
252 
253 ABSL_NAMESPACE_END
254 }  // namespace absl
255 
256 #endif  // ABSL_RANDOM_LOG_UNIFORM_INT_DISTRIBUTION_H_
257