Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 1 | // Copyright 2019 The Abseil Authors. |
| 2 | // |
| 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | // you may not use this file except in compliance with the License. |
| 5 | // You may obtain a copy of the License at |
| 6 | // |
| 7 | // https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | // |
| 9 | // Unless required by applicable law or agreed to in writing, software |
| 10 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | // See the License for the specific language governing permissions and |
| 13 | // limitations under the License. |
| 14 | // |
| 15 | #ifndef ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_ |
| 16 | #define ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_ |
| 17 | |
| 18 | #include <cmath> |
| 19 | #include <limits> |
| 20 | #include <type_traits> |
| 21 | |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 22 | #include "absl/base/config.h" |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 23 | #include "absl/meta/type_traits.h" |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 24 | #include "absl/random/internal/traits.h" |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 25 | |
| 26 | namespace absl { |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 27 | ABSL_NAMESPACE_BEGIN |
| 28 | |
// Forward declarations of the two Abseil uniform distributions; the
// UniformDistribution alias later in this header picks between them.
template <class IntType>
class uniform_int_distribution;

template <class RealType>
class uniform_real_distribution;
| 34 | |
namespace random_internal {

// CRTP base shared by the interval tag types below. A tag carries no state
// (it is a mono-state), so two values of the same tag type always compare
// equal.
template <typename T>
struct TagTypeCompare {};

template <typename T>
constexpr bool operator==(TagTypeCompare<T>, TagTypeCompare<T>) {
  return true;
}

template <typename T>
constexpr bool operator!=(TagTypeCompare<T> lhs, TagTypeCompare<T> rhs) {
  return !(lhs == rhs);
}

}  // namespace random_internal

// Interval tag types which specify whether the interval is open or closed
// on each boundary; e.g. IntervalClosedOpenTag denotes [lo, hi).
struct IntervalClosedClosedTag
    : random_internal::TagTypeCompare<IntervalClosedClosedTag> {};
struct IntervalClosedOpenTag
    : random_internal::TagTypeCompare<IntervalClosedOpenTag> {};
struct IntervalOpenClosedTag
    : random_internal::TagTypeCompare<IntervalOpenClosedTag> {};
struct IntervalOpenOpenTag
    : random_internal::TagTypeCompare<IntervalOpenOpenTag> {};
| 62 | |
| 63 | namespace random_internal { |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 64 | |
| 65 | // In the absence of an explicitly provided return-type, the template |
| 66 | // "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on |
| 67 | // the data-types of the endpoint-arguments {A lo, B hi}. |
| 68 | // |
| 69 | // Given endpoints {A lo, B hi}, one of {A, B} will be chosen as the |
| 70 | // return-type, if one type can be implicitly converted into the other, in a |
| 71 | // lossless way. The template "is_widening_convertible" implements the |
| 72 | // compile-time logic for deciding if such a conversion is possible. |
| 73 | // |
| 74 | // If no such conversion between {A, B} exists, then the overload for |
| 75 | // absl::Uniform() will be discarded, and the call will be ill-formed. |
| 76 | // Return-type for absl::Uniform() when the return-type is inferred. |
| 77 | template <typename A, typename B> |
| 78 | using uniform_inferred_return_t = |
| 79 | absl::enable_if_t<absl::disjunction<is_widening_convertible<A, B>, |
| 80 | is_widening_convertible<B, A>>::value, |
| 81 | typename std::conditional< |
| 82 | is_widening_convertible<A, B>::value, B, A>::type>; |
| 83 | |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 84 | // The functions |
| 85 | // uniform_lower_bound(tag, a, b) |
| 86 | // and |
| 87 | // uniform_upper_bound(tag, a, b) |
| 88 | // are used as implementation-details for absl::Uniform(). |
| 89 | // |
| 90 | // Conceptually, |
| 91 | // [a, b] == [uniform_lower_bound(IntervalClosedClosed, a, b), |
| 92 | // uniform_upper_bound(IntervalClosedClosed, a, b)] |
| 93 | // (a, b) == [uniform_lower_bound(IntervalOpenOpen, a, b), |
| 94 | // uniform_upper_bound(IntervalOpenOpen, a, b)] |
| 95 | // [a, b) == [uniform_lower_bound(IntervalClosedOpen, a, b), |
| 96 | // uniform_upper_bound(IntervalClosedOpen, a, b)] |
| 97 | // (a, b] == [uniform_lower_bound(IntervalOpenClosed, a, b), |
| 98 | // uniform_upper_bound(IntervalOpenClosed, a, b)] |
| 99 | // |
| 100 | template <typename IntType, typename Tag> |
| 101 | typename absl::enable_if_t< |
| 102 | absl::conjunction< |
| 103 | std::is_integral<IntType>, |
| 104 | absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>, |
| 105 | std::is_same<Tag, IntervalOpenOpenTag>>>::value, |
| 106 | IntType> |
| 107 | uniform_lower_bound(Tag, IntType a, IntType) { |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 108 | return a < (std::numeric_limits<IntType>::max)() ? (a + 1) : a; |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 109 | } |
| 110 | |
| 111 | template <typename FloatType, typename Tag> |
| 112 | typename absl::enable_if_t< |
| 113 | absl::conjunction< |
| 114 | std::is_floating_point<FloatType>, |
| 115 | absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>, |
| 116 | std::is_same<Tag, IntervalOpenOpenTag>>>::value, |
| 117 | FloatType> |
| 118 | uniform_lower_bound(Tag, FloatType a, FloatType b) { |
| 119 | return std::nextafter(a, b); |
| 120 | } |
| 121 | |
| 122 | template <typename NumType, typename Tag> |
| 123 | typename absl::enable_if_t< |
| 124 | absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>, |
| 125 | std::is_same<Tag, IntervalClosedOpenTag>>::value, |
| 126 | NumType> |
| 127 | uniform_lower_bound(Tag, NumType a, NumType) { |
| 128 | return a; |
| 129 | } |
| 130 | |
| 131 | template <typename IntType, typename Tag> |
| 132 | typename absl::enable_if_t< |
| 133 | absl::conjunction< |
| 134 | std::is_integral<IntType>, |
| 135 | absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>, |
| 136 | std::is_same<Tag, IntervalOpenOpenTag>>>::value, |
| 137 | IntType> |
| 138 | uniform_upper_bound(Tag, IntType, IntType b) { |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 139 | return b > (std::numeric_limits<IntType>::min)() ? (b - 1) : b; |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 140 | } |
| 141 | |
| 142 | template <typename FloatType, typename Tag> |
| 143 | typename absl::enable_if_t< |
| 144 | absl::conjunction< |
| 145 | std::is_floating_point<FloatType>, |
| 146 | absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>, |
| 147 | std::is_same<Tag, IntervalOpenOpenTag>>>::value, |
| 148 | FloatType> |
| 149 | uniform_upper_bound(Tag, FloatType, FloatType b) { |
| 150 | return b; |
| 151 | } |
| 152 | |
| 153 | template <typename IntType, typename Tag> |
| 154 | typename absl::enable_if_t< |
| 155 | absl::conjunction< |
| 156 | std::is_integral<IntType>, |
| 157 | absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>, |
| 158 | std::is_same<Tag, IntervalOpenClosedTag>>>::value, |
| 159 | IntType> |
| 160 | uniform_upper_bound(Tag, IntType, IntType b) { |
| 161 | return b; |
| 162 | } |
| 163 | |
| 164 | template <typename FloatType, typename Tag> |
| 165 | typename absl::enable_if_t< |
| 166 | absl::conjunction< |
| 167 | std::is_floating_point<FloatType>, |
| 168 | absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>, |
| 169 | std::is_same<Tag, IntervalOpenClosedTag>>>::value, |
| 170 | FloatType> |
| 171 | uniform_upper_bound(Tag, FloatType, FloatType b) { |
| 172 | return std::nextafter(b, (std::numeric_limits<FloatType>::max)()); |
| 173 | } |
| 174 | |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 175 | // Returns whether the bounds are valid for the underlying distribution. |
| 176 | // Inputs must have already been resolved via uniform_*_bound calls. |
| 177 | // |
| 178 | // The c++ standard constraints in [rand.dist.uni.int] are listed as: |
| 179 | // requires: lo <= hi. |
| 180 | // |
// In the uniform_int_distribution, {lo, hi} are closed, closed. Thus:
| 182 | // [0, 0] is legal. |
| 183 | // [0, 0) is not legal, but [0, 1) is, which translates to [0, 0]. |
| 184 | // (0, 1) is not legal, but (0, 2) is, which translates to [1, 1]. |
| 185 | // (0, 0] is not legal, but (0, 1] is, which translates to [1, 1]. |
| 186 | // |
| 187 | // The c++ standard constraints in [rand.dist.uni.real] are listed as: |
| 188 | // requires: lo <= hi. |
| 189 | // requires: (hi - lo) <= numeric_limits<T>::max() |
| 190 | // |
// In the uniform_real_distribution, {lo, hi} are closed, open. Thus:
| 192 | // [0, 0] is legal, which is [0, 0+epsilon). |
| 193 | // [0, 0) is legal. |
| 194 | // (0, 0) is not legal, but (0-epsilon, 0+epsilon) is. |
| 195 | // (0, 0] is not legal, but (0, 0+epsilon] is. |
| 196 | // |
| 197 | template <typename FloatType> |
| 198 | absl::enable_if_t<std::is_floating_point<FloatType>::value, bool> |
| 199 | is_uniform_range_valid(FloatType a, FloatType b) { |
| 200 | return a <= b && std::isfinite(b - a); |
| 201 | } |
| 202 | |
| 203 | template <typename IntType> |
| 204 | absl::enable_if_t<std::is_integral<IntType>::value, bool> |
| 205 | is_uniform_range_valid(IntType a, IntType b) { |
| 206 | return a <= b; |
| 207 | } |
| 208 | |
| 209 | // UniformDistribution selects either absl::uniform_int_distribution |
| 210 | // or absl::uniform_real_distribution depending on the NumType parameter. |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 211 | template <typename NumType> |
| 212 | using UniformDistribution = |
| 213 | typename std::conditional<std::is_integral<NumType>::value, |
| 214 | absl::uniform_int_distribution<NumType>, |
| 215 | absl::uniform_real_distribution<NumType>>::type; |
| 216 | |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 217 | // UniformDistributionWrapper is used as the underlying distribution type |
| 218 | // by the absl::Uniform template function. It selects the proper Abseil |
| 219 | // uniform distribution and provides constructor overloads that match the |
| 220 | // expected parameter order as well as adjusting distribtuion bounds based |
| 221 | // on the tag. |
| 222 | template <typename NumType> |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 223 | struct UniformDistributionWrapper : public UniformDistribution<NumType> { |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 224 | template <typename TagType> |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 225 | explicit UniformDistributionWrapper(TagType, NumType lo, NumType hi) |
| 226 | : UniformDistribution<NumType>( |
| 227 | uniform_lower_bound<NumType>(TagType{}, lo, hi), |
| 228 | uniform_upper_bound<NumType>(TagType{}, lo, hi)) {} |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 229 | |
| 230 | explicit UniformDistributionWrapper(NumType lo, NumType hi) |
| 231 | : UniformDistribution<NumType>( |
| 232 | uniform_lower_bound<NumType>(IntervalClosedOpenTag(), lo, hi), |
| 233 | uniform_upper_bound<NumType>(IntervalClosedOpenTag(), lo, hi)) {} |
| 234 | |
| 235 | explicit UniformDistributionWrapper() |
| 236 | : UniformDistribution<NumType>(std::numeric_limits<NumType>::lowest(), |
| 237 | (std::numeric_limits<NumType>::max)()) {} |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 238 | }; |
| 239 | |
| 240 | } // namespace random_internal |
Austin Schuh | b4691e9 | 2020-12-31 12:37:18 -0800 | [diff] [blame^] | 241 | ABSL_NAMESPACE_END |
Austin Schuh | 36244a1 | 2019-09-21 17:52:38 -0700 | [diff] [blame] | 242 | } // namespace absl |
| 243 | |
| 244 | #endif // ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_ |