blob: 1243bc1c62abe24dccb19323ca3f2c964dc403bb [file] [log] [blame]
Austin Schuh36244a12019-09-21 17:52:38 -07001// Copyright 2019 The Abseil Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15#ifndef ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_
16#define ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_
17
18#include <cmath>
19#include <limits>
20#include <type_traits>
21
Austin Schuhb4691e92020-12-31 12:37:18 -080022#include "absl/base/config.h"
Austin Schuh36244a12019-09-21 17:52:38 -070023#include "absl/meta/type_traits.h"
Austin Schuhb4691e92020-12-31 12:37:18 -080024#include "absl/random/internal/traits.h"
Austin Schuh36244a12019-09-21 17:52:38 -070025
26namespace absl {
Austin Schuhb4691e92020-12-31 12:37:18 -080027ABSL_NAMESPACE_BEGIN
28
// Forward declarations of the two distributions that UniformDistribution (in
// random_internal) chooses between. Only the names are needed in this header;
// the definitions are not provided here.
template <class IntType>
class uniform_int_distribution;

template <class RealType>
class uniform_real_distribution;
34
// Interval tag types. Each tag states whether the lower and upper endpoints of
// the interval {a, b} are included ("Closed") or excluded ("Open"); the first
// word of the name refers to the lower endpoint, the second to the upper one.

namespace random_internal {

// Common base for the interval tags. Each tag passes itself as `T`, so the
// equality operators below only accept two tags of the same type. The base
// holds no data.
template <typename T>
struct TagTypeCompare {};

// Tags are stateless, so any two tags of the same type are equal.
template <typename T>
constexpr bool operator==(TagTypeCompare<T>, TagTypeCompare<T>) {
  return true;
}

// Written as the negation of operator== so the two can never disagree.
template <typename T>
constexpr bool operator!=(TagTypeCompare<T> lhs, TagTypeCompare<T> rhs) {
  return !(lhs == rhs);
}

}  // namespace random_internal

// Both endpoints included: [a, b].
struct IntervalClosedClosedTag
    : random_internal::TagTypeCompare<IntervalClosedClosedTag> {};
// Lower endpoint included, upper excluded: [a, b).
struct IntervalClosedOpenTag
    : random_internal::TagTypeCompare<IntervalClosedOpenTag> {};
// Lower endpoint excluded, upper included: (a, b].
struct IntervalOpenClosedTag
    : random_internal::TagTypeCompare<IntervalOpenClosedTag> {};
// Both endpoints excluded: (a, b).
struct IntervalOpenOpenTag
    : random_internal::TagTypeCompare<IntervalOpenOpenTag> {};
62
63namespace random_internal {
Austin Schuhb4691e92020-12-31 12:37:18 -080064
65// In the absence of an explicitly provided return-type, the template
66// "uniform_inferred_return_t<A, B>" is used to derive a suitable type, based on
67// the data-types of the endpoint-arguments {A lo, B hi}.
68//
69// Given endpoints {A lo, B hi}, one of {A, B} will be chosen as the
70// return-type, if one type can be implicitly converted into the other, in a
71// lossless way. The template "is_widening_convertible" implements the
72// compile-time logic for deciding if such a conversion is possible.
73//
74// If no such conversion between {A, B} exists, then the overload for
75// absl::Uniform() will be discarded, and the call will be ill-formed.
76// Return-type for absl::Uniform() when the return-type is inferred.
77template <typename A, typename B>
78using uniform_inferred_return_t =
79 absl::enable_if_t<absl::disjunction<is_widening_convertible<A, B>,
80 is_widening_convertible<B, A>>::value,
81 typename std::conditional<
82 is_widening_convertible<A, B>::value, B, A>::type>;
83
// The functions
//   uniform_lower_bound(tag, a, b)
// and
//   uniform_upper_bound(tag, a, b)
// are used as implementation details for absl::Uniform().
//
// Conceptually,
//   [a, b] == [uniform_lower_bound(IntervalClosedClosedTag{}, a, b),
//              uniform_upper_bound(IntervalClosedClosedTag{}, a, b)]
//   (a, b) == [uniform_lower_bound(IntervalOpenOpenTag{}, a, b),
//              uniform_upper_bound(IntervalOpenOpenTag{}, a, b)]
//   [a, b) == [uniform_lower_bound(IntervalClosedOpenTag{}, a, b),
//              uniform_upper_bound(IntervalClosedOpenTag{}, a, b)]
//   (a, b] == [uniform_lower_bound(IntervalOpenClosedTag{}, a, b),
//              uniform_upper_bound(IntervalOpenClosedTag{}, a, b)]
//
// The right-hand side is expressed in the convention of the underlying
// distribution. uniform_int_distribution treats its bounds as closed on both
// ends. uniform_real_distribution treats its upper bound as exclusive, so for
// floating-point types the resolved pair is really [lower, upper).
//
100template <typename IntType, typename Tag>
101typename absl::enable_if_t<
102 absl::conjunction<
103 std::is_integral<IntType>,
104 absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>,
105 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
106 IntType>
107uniform_lower_bound(Tag, IntType a, IntType) {
Austin Schuhb4691e92020-12-31 12:37:18 -0800108 return a < (std::numeric_limits<IntType>::max)() ? (a + 1) : a;
Austin Schuh36244a12019-09-21 17:52:38 -0700109}
110
111template <typename FloatType, typename Tag>
112typename absl::enable_if_t<
113 absl::conjunction<
114 std::is_floating_point<FloatType>,
115 absl::disjunction<std::is_same<Tag, IntervalOpenClosedTag>,
116 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
117 FloatType>
118uniform_lower_bound(Tag, FloatType a, FloatType b) {
119 return std::nextafter(a, b);
120}
121
122template <typename NumType, typename Tag>
123typename absl::enable_if_t<
124 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
125 std::is_same<Tag, IntervalClosedOpenTag>>::value,
126 NumType>
127uniform_lower_bound(Tag, NumType a, NumType) {
128 return a;
129}
130
131template <typename IntType, typename Tag>
132typename absl::enable_if_t<
133 absl::conjunction<
134 std::is_integral<IntType>,
135 absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>,
136 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
137 IntType>
138uniform_upper_bound(Tag, IntType, IntType b) {
Austin Schuhb4691e92020-12-31 12:37:18 -0800139 return b > (std::numeric_limits<IntType>::min)() ? (b - 1) : b;
Austin Schuh36244a12019-09-21 17:52:38 -0700140}
141
142template <typename FloatType, typename Tag>
143typename absl::enable_if_t<
144 absl::conjunction<
145 std::is_floating_point<FloatType>,
146 absl::disjunction<std::is_same<Tag, IntervalClosedOpenTag>,
147 std::is_same<Tag, IntervalOpenOpenTag>>>::value,
148 FloatType>
149uniform_upper_bound(Tag, FloatType, FloatType b) {
150 return b;
151}
152
153template <typename IntType, typename Tag>
154typename absl::enable_if_t<
155 absl::conjunction<
156 std::is_integral<IntType>,
157 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
158 std::is_same<Tag, IntervalOpenClosedTag>>>::value,
159 IntType>
160uniform_upper_bound(Tag, IntType, IntType b) {
161 return b;
162}
163
164template <typename FloatType, typename Tag>
165typename absl::enable_if_t<
166 absl::conjunction<
167 std::is_floating_point<FloatType>,
168 absl::disjunction<std::is_same<Tag, IntervalClosedClosedTag>,
169 std::is_same<Tag, IntervalOpenClosedTag>>>::value,
170 FloatType>
171uniform_upper_bound(Tag, FloatType, FloatType b) {
172 return std::nextafter(b, (std::numeric_limits<FloatType>::max)());
173}
174
Austin Schuhb4691e92020-12-31 12:37:18 -0800175// Returns whether the bounds are valid for the underlying distribution.
176// Inputs must have already been resolved via uniform_*_bound calls.
177//
178// The c++ standard constraints in [rand.dist.uni.int] are listed as:
179// requires: lo <= hi.
180//
181// In the uniform_int_distrubtion, {lo, hi} are closed, closed. Thus:
182// [0, 0] is legal.
183// [0, 0) is not legal, but [0, 1) is, which translates to [0, 0].
184// (0, 1) is not legal, but (0, 2) is, which translates to [1, 1].
185// (0, 0] is not legal, but (0, 1] is, which translates to [1, 1].
186//
187// The c++ standard constraints in [rand.dist.uni.real] are listed as:
188// requires: lo <= hi.
189// requires: (hi - lo) <= numeric_limits<T>::max()
190//
191// In the uniform_real_distribution, {lo, hi} are closed, open, Thus:
192// [0, 0] is legal, which is [0, 0+epsilon).
193// [0, 0) is legal.
194// (0, 0) is not legal, but (0-epsilon, 0+epsilon) is.
195// (0, 0] is not legal, but (0, 0+epsilon] is.
196//
197template <typename FloatType>
198absl::enable_if_t<std::is_floating_point<FloatType>::value, bool>
199is_uniform_range_valid(FloatType a, FloatType b) {
200 return a <= b && std::isfinite(b - a);
201}
202
203template <typename IntType>
204absl::enable_if_t<std::is_integral<IntType>::value, bool>
205is_uniform_range_valid(IntType a, IntType b) {
206 return a <= b;
207}
208
209// UniformDistribution selects either absl::uniform_int_distribution
210// or absl::uniform_real_distribution depending on the NumType parameter.
Austin Schuh36244a12019-09-21 17:52:38 -0700211template <typename NumType>
212using UniformDistribution =
213 typename std::conditional<std::is_integral<NumType>::value,
214 absl::uniform_int_distribution<NumType>,
215 absl::uniform_real_distribution<NumType>>::type;
216
Austin Schuhb4691e92020-12-31 12:37:18 -0800217// UniformDistributionWrapper is used as the underlying distribution type
218// by the absl::Uniform template function. It selects the proper Abseil
219// uniform distribution and provides constructor overloads that match the
220// expected parameter order as well as adjusting distribtuion bounds based
221// on the tag.
222template <typename NumType>
Austin Schuh36244a12019-09-21 17:52:38 -0700223struct UniformDistributionWrapper : public UniformDistribution<NumType> {
Austin Schuhb4691e92020-12-31 12:37:18 -0800224 template <typename TagType>
Austin Schuh36244a12019-09-21 17:52:38 -0700225 explicit UniformDistributionWrapper(TagType, NumType lo, NumType hi)
226 : UniformDistribution<NumType>(
227 uniform_lower_bound<NumType>(TagType{}, lo, hi),
228 uniform_upper_bound<NumType>(TagType{}, lo, hi)) {}
Austin Schuhb4691e92020-12-31 12:37:18 -0800229
230 explicit UniformDistributionWrapper(NumType lo, NumType hi)
231 : UniformDistribution<NumType>(
232 uniform_lower_bound<NumType>(IntervalClosedOpenTag(), lo, hi),
233 uniform_upper_bound<NumType>(IntervalClosedOpenTag(), lo, hi)) {}
234
235 explicit UniformDistributionWrapper()
236 : UniformDistribution<NumType>(std::numeric_limits<NumType>::lowest(),
237 (std::numeric_limits<NumType>::max)()) {}
Austin Schuh36244a12019-09-21 17:52:38 -0700238};
239
240} // namespace random_internal
Austin Schuhb4691e92020-12-31 12:37:18 -0800241ABSL_NAMESPACE_END
Austin Schuh36244a12019-09-21 17:52:38 -0700242} // namespace absl
243
244#endif // ABSL_RANDOM_INTERNAL_UNIFORM_HELPER_H_