blob: 3298c2cdb64b1b10d4e6846acab91d39624ba79f [file] [log] [blame]
Austin Schuh36244a12019-09-21 17:52:38 -07001//
2// Copyright 2018 The Abseil Authors.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// https://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16#ifndef ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_
17#define ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_
18
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>

#include "absl/meta/type_traits.h"
#include "absl/random/bernoulli_distribution.h"
#include "absl/random/beta_distribution.h"
#include "absl/random/exponential_distribution.h"
#include "absl/random/gaussian_distribution.h"
#include "absl/random/log_uniform_int_distribution.h"
#include "absl/random/poisson_distribution.h"
#include "absl/random/uniform_int_distribution.h"
#include "absl/random/uniform_real_distribution.h"
#include "absl/random/zipf_distribution.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
37
38namespace absl {
39
// Interval tag types, forward-declared only. This header never needs their
// definitions: the UniformDistributionWrapper traits below compare them by
// identity with std::is_same, which works on incomplete types.
struct IntervalOpenOpenTag;
struct IntervalOpenClosedTag;
struct IntervalClosedOpenTag;
struct IntervalClosedClosedTag;
44
45namespace random_internal {
46
// ScalarTypeName returns a human-readable name for the arithmetic type T,
// evaluated at compile time for each specialization.
//
// - The floating-point types and `bool` are named directly.
// - Every other integral type is reported as the <cstdint> fixed-width name
//   matching its signedness and size. For example, `long` maps to "int64_t"
//   where it is 8 bytes wide and to "int32_t" where it is 4.
// - Integral types of any other size map to "undefined".
// - Top-level cv-qualifiers on T are ignored, so ScalarTypeName<const float>()
//   is "float".
template <typename T>
constexpr const char* ScalarTypeName() {
  static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
                "ScalarTypeName<T> requires an integral or floating-point T");
  // Strip cv-qualifiers so the std::is_same checks below also match e.g.
  // `const float`. Otherwise such a type falls through to the signedness
  // checks: std::is_signed<const float> is true, so it would be misreported
  // as "int32_t".
  using U = typename std::remove_cv<T>::type;
  // `bool` must be checked before the signedness checks because it is itself
  // an unsigned integral type. The body stays a single return statement so
  // this remains a valid C++11 constexpr function.
  // clang-format off
  return
      std::is_same<U, float>::value ? "float" :
      std::is_same<U, double>::value ? "double" :
      std::is_same<U, long double>::value ? "long double" :
      std::is_same<U, bool>::value ? "bool" :
      std::is_signed<U>::value && sizeof(U) == 1 ? "int8_t" :
      std::is_signed<U>::value && sizeof(U) == 2 ? "int16_t" :
      std::is_signed<U>::value && sizeof(U) == 4 ? "int32_t" :
      std::is_signed<U>::value && sizeof(U) == 8 ? "int64_t" :
      std::is_unsigned<U>::value && sizeof(U) == 1 ? "uint8_t" :
      std::is_unsigned<U>::value && sizeof(U) == 2 ? "uint16_t" :
      std::is_unsigned<U>::value && sizeof(U) == 4 ? "uint32_t" :
      std::is_unsigned<U>::value && sizeof(U) == 8 ? "uint64_t" :
      "undefined";
  // clang-format on

  // NOTE: It would be nice to use typeid(T).name(), but that is an
  // implementation-defined string that does not necessarily correspond to a
  // readable type name. It could potentially be demangled using, e.g.,
  // abi::__cxa_demangle.
}
75
// Distribution traits used by DistributionCaller and internal implementation
// details of the mocking framework.
//
// The primary template is declared but not defined. Only the specializations
// that follow are usable, so using its members for a distribution that has no
// specialization fails to compile. Every specialization provides:
//
//   // The distribution type and the type of the values it produces.
//   using distribution_t = DistrT;
//   using result_t = ...;
//
//   // Base name of the distribution, e.g. "Uniform".
//   static constexpr const char* Name();
//
//   // Name() with the scalar template argument appended, e.g.
//   // "Uniform<int32_t>". Returns std::string; the non-template
//   // bernoulli_distribution specialization instead returns
//   // `constexpr const char*` equal to Name().
//   static std::string FunctionName();
//
//   // Formats the parameters of `d` as a ", "-separated list.
//   static std::string FormatArgs(const distribution_t& d);
//
//   // Formats `results` as a ", "-separated list.
//   static std::string FormatResults(absl::Span<const result_t> results);
template <typename DistrT>
struct DistributionFormatTraits;
90
91template <typename R>
92struct DistributionFormatTraits<absl::uniform_int_distribution<R>> {
93 using distribution_t = absl::uniform_int_distribution<R>;
94 using result_t = typename distribution_t::result_type;
95
96 static constexpr const char* Name() { return "Uniform"; }
97
98 static std::string FunctionName() {
99 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
100 }
101 static std::string FormatArgs(const distribution_t& d) {
102 return absl::StrCat("absl::IntervalClosedClosed, ", (d.min)(), ", ",
103 (d.max)());
104 }
105 static std::string FormatResults(absl::Span<const result_t> results) {
106 return absl::StrJoin(results, ", ");
107 }
108};
109
110template <typename R>
111struct DistributionFormatTraits<absl::uniform_real_distribution<R>> {
112 using distribution_t = absl::uniform_real_distribution<R>;
113 using result_t = typename distribution_t::result_type;
114
115 static constexpr const char* Name() { return "Uniform"; }
116
117 static std::string FunctionName() {
118 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
119 }
120 static std::string FormatArgs(const distribution_t& d) {
121 return absl::StrCat((d.min)(), ", ", (d.max)());
122 }
123 static std::string FormatResults(absl::Span<const result_t> results) {
124 return absl::StrJoin(results, ", ");
125 }
126};
127
128template <typename R>
129struct DistributionFormatTraits<absl::exponential_distribution<R>> {
130 using distribution_t = absl::exponential_distribution<R>;
131 using result_t = typename distribution_t::result_type;
132
133 static constexpr const char* Name() { return "Exponential"; }
134
135 static std::string FunctionName() {
136 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
137 }
138 static std::string FormatArgs(const distribution_t& d) {
139 return absl::StrCat(d.lambda());
140 }
141 static std::string FormatResults(absl::Span<const result_t> results) {
142 return absl::StrJoin(results, ", ");
143 }
144};
145
146template <typename R>
147struct DistributionFormatTraits<absl::poisson_distribution<R>> {
148 using distribution_t = absl::poisson_distribution<R>;
149 using result_t = typename distribution_t::result_type;
150
151 static constexpr const char* Name() { return "Poisson"; }
152
153 static std::string FunctionName() {
154 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
155 }
156 static std::string FormatArgs(const distribution_t& d) {
157 return absl::StrCat(d.mean());
158 }
159 static std::string FormatResults(absl::Span<const result_t> results) {
160 return absl::StrJoin(results, ", ");
161 }
162};
163
164template <>
165struct DistributionFormatTraits<absl::bernoulli_distribution> {
166 using distribution_t = absl::bernoulli_distribution;
167 using result_t = typename distribution_t::result_type;
168
169 static constexpr const char* Name() { return "Bernoulli"; }
170
171 static constexpr const char* FunctionName() { return Name(); }
172 static std::string FormatArgs(const distribution_t& d) {
173 return absl::StrCat(d.p());
174 }
175 static std::string FormatResults(absl::Span<const result_t> results) {
176 return absl::StrJoin(results, ", ");
177 }
178};
179
180template <typename R>
181struct DistributionFormatTraits<absl::beta_distribution<R>> {
182 using distribution_t = absl::beta_distribution<R>;
183 using result_t = typename distribution_t::result_type;
184
185 static constexpr const char* Name() { return "Beta"; }
186
187 static std::string FunctionName() {
188 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
189 }
190 static std::string FormatArgs(const distribution_t& d) {
191 return absl::StrCat(d.alpha(), ", ", d.beta());
192 }
193 static std::string FormatResults(absl::Span<const result_t> results) {
194 return absl::StrJoin(results, ", ");
195 }
196};
197
198template <typename R>
199struct DistributionFormatTraits<absl::zipf_distribution<R>> {
200 using distribution_t = absl::zipf_distribution<R>;
201 using result_t = typename distribution_t::result_type;
202
203 static constexpr const char* Name() { return "Zipf"; }
204
205 static std::string FunctionName() {
206 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
207 }
208 static std::string FormatArgs(const distribution_t& d) {
209 return absl::StrCat(d.k(), ", ", d.v(), ", ", d.q());
210 }
211 static std::string FormatResults(absl::Span<const result_t> results) {
212 return absl::StrJoin(results, ", ");
213 }
214};
215
216template <typename R>
217struct DistributionFormatTraits<absl::gaussian_distribution<R>> {
218 using distribution_t = absl::gaussian_distribution<R>;
219 using result_t = typename distribution_t::result_type;
220
221 static constexpr const char* Name() { return "Gaussian"; }
222
223 static std::string FunctionName() {
224 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
225 }
226 static std::string FormatArgs(const distribution_t& d) {
227 return absl::StrJoin(std::make_tuple(d.mean(), d.stddev()), ", ");
228 }
229 static std::string FormatResults(absl::Span<const result_t> results) {
230 return absl::StrJoin(results, ", ");
231 }
232};
233
234template <typename R>
235struct DistributionFormatTraits<absl::log_uniform_int_distribution<R>> {
236 using distribution_t = absl::log_uniform_int_distribution<R>;
237 using result_t = typename distribution_t::result_type;
238
239 static constexpr const char* Name() { return "LogUniform"; }
240
241 static std::string FunctionName() {
242 return absl::StrCat(Name(), "<", ScalarTypeName<R>(), ">");
243 }
244 static std::string FormatArgs(const distribution_t& d) {
245 return absl::StrJoin(std::make_tuple((d.min)(), (d.max)(), d.base()), ", ");
246 }
247 static std::string FormatResults(absl::Span<const result_t> results) {
248 return absl::StrJoin(results, ", ");
249 }
250};
251
252template <typename TagType, typename NumType>
253struct UniformDistributionWrapper;
254
255template <typename TagType, typename NumType>
256struct DistributionFormatTraits<UniformDistributionWrapper<TagType, NumType>> {
257 using distribution_t = UniformDistributionWrapper<TagType, NumType>;
258 using result_t = NumType;
259
260 static constexpr const char* Name() { return "Uniform"; }
261
262 static std::string FunctionName() {
263 return absl::StrCat(Name(), "<", ScalarTypeName<NumType>(), ">");
264 }
265 static std::string FormatArgs(const distribution_t& d) {
266 absl::string_view tag;
267 if (std::is_same<TagType, IntervalClosedClosedTag>::value) {
268 tag = "IntervalClosedClosed";
269 } else if (std::is_same<TagType, IntervalClosedOpenTag>::value) {
270 tag = "IntervalClosedOpen";
271 } else if (std::is_same<TagType, IntervalOpenClosedTag>::value) {
272 tag = "IntervalOpenClosed";
273 } else if (std::is_same<TagType, IntervalOpenOpenTag>::value) {
274 tag = "IntervalOpenOpen";
275 } else {
276 tag = "[[unknown tag type]]";
277 }
278 return absl::StrCat(tag, ", ", (d.min)(), ", ", (d.max)());
279 }
280 static std::string FormatResults(absl::Span<const result_t> results) {
281 return absl::StrJoin(results, ", ");
282 }
283};
284
285} // namespace random_internal
286} // namespace absl
287
288#endif // ABSL_RANDOM_DISTRIBUTION_FORMAT_TRAITS_H_