1 // Copyright 2024 The Abseil Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ 16 #define ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ 17 18 #include <type_traits> 19 20 #include "absl/base/config.h" 21 #include "absl/base/internal/raw_logging.h" 22 #include "absl/random/internal/iostream_state_saver.h" 23 #include "absl/random/internal/uniform_helper.h" 24 #include "absl/strings/str_cat.h" 25 #include "absl/strings/string_view.h" 26 27 namespace absl { 28 ABSL_NAMESPACE_BEGIN 29 namespace random_internal { 30 31 template <typename NumType> 32 class UniformDistributionValidator { 33 public: 34 // Handle absl::Uniform<NumType>(gen, absl::IntervalTag, lo, hi). 35 template <typename TagType> Validate(NumType x,TagType tag,NumType lo,NumType hi)36 static void Validate(NumType x, TagType tag, NumType lo, NumType hi) { 37 // For invalid ranges, absl::Uniform() simply returns one of the bounds. 38 if (x == lo && lo == hi) return; 39 40 ValidateImpl(std::is_floating_point<NumType>{}, x, tag, lo, hi); 41 } 42 43 // Handle absl::Uniform<NumType>(gen, lo, hi). Validate(NumType x,NumType lo,NumType hi)44 static void Validate(NumType x, NumType lo, NumType hi) { 45 Validate(x, IntervalClosedOpenTag(), lo, hi); 46 } 47 48 // Handle absl::Uniform<NumType>(gen). Validate(NumType)49 static void Validate(NumType) { 50 // absl::Uniform<NumType>(gen) spans the entire range of `NumType`, so any 51 // value is okay. This overload exists because the validation logic attempts 52 // to call it anyway rather than adding extra SFINAE. 53 } 54 55 private: TagLbBound(IntervalClosedOpenTag)56 static absl::string_view TagLbBound(IntervalClosedOpenTag) { return "["; } TagLbBound(IntervalOpenOpenTag)57 static absl::string_view TagLbBound(IntervalOpenOpenTag) { return "("; } TagLbBound(IntervalClosedClosedTag)58 static absl::string_view TagLbBound(IntervalClosedClosedTag) { return "["; } TagLbBound(IntervalOpenClosedTag)59 static absl::string_view TagLbBound(IntervalOpenClosedTag) { return "("; } TagUbBound(IntervalClosedOpenTag)60 static absl::string_view TagUbBound(IntervalClosedOpenTag) { return ")"; } TagUbBound(IntervalOpenOpenTag)61 static absl::string_view TagUbBound(IntervalOpenOpenTag) { return ")"; } TagUbBound(IntervalClosedClosedTag)62 static absl::string_view TagUbBound(IntervalClosedClosedTag) { return "]"; } TagUbBound(IntervalOpenClosedTag)63 static absl::string_view TagUbBound(IntervalOpenClosedTag) { return "]"; } 64 65 template <typename TagType> ValidateImpl(std::true_type,NumType x,TagType tag,NumType lo,NumType hi)66 static void ValidateImpl(std::true_type /* is_floating_point */, NumType x, 67 TagType tag, NumType lo, NumType hi) { 68 UniformDistributionWrapper<NumType> dist(tag, lo, hi); 69 NumType lb = dist.a(); 70 NumType ub = dist.b(); 71 // uniform_real_distribution is always closed-open, so the upper bound is 72 // always non-inclusive. 73 ABSL_INTERNAL_CHECK(lb <= x && x < ub, 74 absl::StrCat(x, " is not in ", TagLbBound(tag), lo, 75 ", ", hi, TagUbBound(tag))); 76 } 77 78 template <typename TagType> ValidateImpl(std::false_type,NumType x,TagType tag,NumType lo,NumType hi)79 static void ValidateImpl(std::false_type /* is_floating_point */, NumType x, 80 TagType tag, NumType lo, NumType hi) { 81 using stream_type = 82 typename random_internal::stream_format_type<NumType>::type; 83 84 UniformDistributionWrapper<NumType> dist(tag, lo, hi); 85 NumType lb = dist.a(); 86 NumType ub = dist.b(); 87 ABSL_INTERNAL_CHECK( 88 lb <= x && x <= ub, 89 absl::StrCat(stream_type{x}, " is not in ", TagLbBound(tag), 90 stream_type{lo}, ", ", stream_type{hi}, TagUbBound(tag))); 91 } 92 }; 93 94 } // namespace random_internal 95 ABSL_NAMESPACE_END 96 } // namespace absl 97 98 #endif // ABSL_RANDOM_INTERNAL_MOCK_VALIDATORS_H_ 99