Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
| 4 | // Copyright (C) 2017 Gael Guennebaud <gael.guennebaud@inria.fr> |
| 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #include "main.h" |
| 11 | |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 12 | template<typename T, typename U> |
| 13 | bool check_if_equal_or_nans(const T& actual, const U& expected) { |
| 14 | return ((actual == expected) || ((numext::isnan)(actual) && (numext::isnan)(expected))); |
| 15 | } |
| 16 | |
| 17 | template<typename T, typename U> |
| 18 | bool check_if_equal_or_nans(const std::complex<T>& actual, const std::complex<U>& expected) { |
| 19 | return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) |
| 20 | && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected)); |
| 21 | } |
| 22 | |
// Verification helper built on check_if_equal_or_nans(): on success returns
// true; on mismatch prints both values to std::cerr and returns false so the
// surrounding VERIFY reports the failure.
template<typename T, typename U>
bool test_is_equal_or_nans(const T& actual, const U& expected)
{
  const bool matches = check_if_equal_or_nans(actual, expected);
  if (!matches) {
    std::cerr
      << "\n actual = " << actual
      << "\n expected = " << expected << "\n\n";
  }
  return matches;
}

// VERIFY variant that treats two NaNs (or two complex values whose mismatching
// components are both NaN) as equal.
#define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b))
| 38 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 39 | template<typename T> |
| 40 | void check_abs() { |
| 41 | typedef typename NumTraits<T>::Real Real; |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 42 | Real zero(0); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 43 | |
| 44 | if(NumTraits<T>::IsSigned) |
| 45 | VERIFY_IS_EQUAL(numext::abs(-T(1)), T(1)); |
| 46 | VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); |
| 47 | VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); |
| 48 | |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 49 | for(int k=0; k<100; ++k) |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 50 | { |
| 51 | T x = internal::random<T>(); |
| 52 | if(!internal::is_same<T,bool>::value) |
| 53 | x = x/Real(2); |
| 54 | if(NumTraits<T>::IsSigned) |
| 55 | { |
| 56 | VERIFY_IS_EQUAL(numext::abs(x), numext::abs(-x)); |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 57 | VERIFY( numext::abs(-x) >= zero ); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 58 | } |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 59 | VERIFY( numext::abs(x) >= zero ); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 60 | VERIFY_IS_APPROX( numext::abs2(x), numext::abs2(numext::abs(x)) ); |
| 61 | } |
| 62 | } |
| 63 | |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 64 | template<typename T> |
| 65 | void check_arg() { |
| 66 | typedef typename NumTraits<T>::Real Real; |
| 67 | VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); |
| 68 | VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 69 | |
Austin Schuh | c55b017 | 2022-02-20 17:52:35 -0800 | [diff] [blame] | 70 | for(int k=0; k<100; ++k) |
| 71 | { |
| 72 | T x = internal::random<T>(); |
| 73 | Real y = numext::arg(x); |
| 74 | VERIFY_IS_APPROX( y, std::arg(x) ); |
| 75 | } |
| 76 | } |
| 77 | |
| 78 | template<typename T> |
| 79 | struct check_sqrt_impl { |
| 80 | static void run() { |
| 81 | for (int i=0; i<1000; ++i) { |
| 82 | const T x = numext::abs(internal::random<T>()); |
| 83 | const T sqrtx = numext::sqrt(x); |
| 84 | VERIFY_IS_APPROX(sqrtx*sqrtx, x); |
| 85 | } |
| 86 | |
| 87 | // Corner cases. |
| 88 | const T zero = T(0); |
| 89 | const T one = T(1); |
| 90 | const T inf = std::numeric_limits<T>::infinity(); |
| 91 | const T nan = std::numeric_limits<T>::quiet_NaN(); |
| 92 | VERIFY_IS_EQUAL(numext::sqrt(zero), zero); |
| 93 | VERIFY_IS_EQUAL(numext::sqrt(inf), inf); |
| 94 | VERIFY((numext::isnan)(numext::sqrt(nan))); |
| 95 | VERIFY((numext::isnan)(numext::sqrt(-one))); |
| 96 | } |
| 97 | }; |
| 98 | |
| 99 | template<typename T> |
| 100 | struct check_sqrt_impl<std::complex<T> > { |
| 101 | static void run() { |
| 102 | typedef typename std::complex<T> ComplexT; |
| 103 | |
| 104 | for (int i=0; i<1000; ++i) { |
| 105 | const ComplexT x = internal::random<ComplexT>(); |
| 106 | const ComplexT sqrtx = numext::sqrt(x); |
| 107 | VERIFY_IS_APPROX(sqrtx*sqrtx, x); |
| 108 | } |
| 109 | |
| 110 | // Corner cases. |
| 111 | const T zero = T(0); |
| 112 | const T one = T(1); |
| 113 | const T inf = std::numeric_limits<T>::infinity(); |
| 114 | const T nan = std::numeric_limits<T>::quiet_NaN(); |
| 115 | |
| 116 | // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt |
| 117 | const int kNumCorners = 20; |
| 118 | const ComplexT corners[kNumCorners][2] = { |
| 119 | {ComplexT(zero, zero), ComplexT(zero, zero)}, |
| 120 | {ComplexT(-zero, zero), ComplexT(zero, zero)}, |
| 121 | {ComplexT(zero, -zero), ComplexT(zero, zero)}, |
| 122 | {ComplexT(-zero, -zero), ComplexT(zero, zero)}, |
| 123 | {ComplexT(one, inf), ComplexT(inf, inf)}, |
| 124 | {ComplexT(nan, inf), ComplexT(inf, inf)}, |
| 125 | {ComplexT(one, -inf), ComplexT(inf, -inf)}, |
| 126 | {ComplexT(nan, -inf), ComplexT(inf, -inf)}, |
| 127 | {ComplexT(-inf, one), ComplexT(zero, inf)}, |
| 128 | {ComplexT(inf, one), ComplexT(inf, zero)}, |
| 129 | {ComplexT(-inf, -one), ComplexT(zero, -inf)}, |
| 130 | {ComplexT(inf, -one), ComplexT(inf, -zero)}, |
| 131 | {ComplexT(-inf, nan), ComplexT(nan, inf)}, |
| 132 | {ComplexT(inf, nan), ComplexT(inf, nan)}, |
| 133 | {ComplexT(zero, nan), ComplexT(nan, nan)}, |
| 134 | {ComplexT(one, nan), ComplexT(nan, nan)}, |
| 135 | {ComplexT(nan, zero), ComplexT(nan, nan)}, |
| 136 | {ComplexT(nan, one), ComplexT(nan, nan)}, |
| 137 | {ComplexT(nan, -one), ComplexT(nan, nan)}, |
| 138 | {ComplexT(nan, nan), ComplexT(nan, nan)}, |
| 139 | }; |
| 140 | |
| 141 | for (int i=0; i<kNumCorners; ++i) { |
| 142 | const ComplexT& x = corners[i][0]; |
| 143 | const ComplexT sqrtx = corners[i][1]; |
| 144 | VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx); |
| 145 | } |
| 146 | } |
| 147 | }; |
| 148 | |
| 149 | template<typename T> |
| 150 | void check_sqrt() { |
| 151 | check_sqrt_impl<T>::run(); |
| 152 | } |
| 153 | |
| 154 | template<typename T> |
| 155 | struct check_rsqrt_impl { |
| 156 | static void run() { |
| 157 | const T zero = T(0); |
| 158 | const T one = T(1); |
| 159 | const T inf = std::numeric_limits<T>::infinity(); |
| 160 | const T nan = std::numeric_limits<T>::quiet_NaN(); |
| 161 | |
| 162 | for (int i=0; i<1000; ++i) { |
| 163 | const T x = numext::abs(internal::random<T>()); |
| 164 | const T rsqrtx = numext::rsqrt(x); |
| 165 | const T invx = one / x; |
| 166 | VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); |
| 167 | } |
| 168 | |
| 169 | // Corner cases. |
| 170 | VERIFY_IS_EQUAL(numext::rsqrt(zero), inf); |
| 171 | VERIFY_IS_EQUAL(numext::rsqrt(inf), zero); |
| 172 | VERIFY((numext::isnan)(numext::rsqrt(nan))); |
| 173 | VERIFY((numext::isnan)(numext::rsqrt(-one))); |
| 174 | } |
| 175 | }; |
| 176 | |
| 177 | template<typename T> |
| 178 | struct check_rsqrt_impl<std::complex<T> > { |
| 179 | static void run() { |
| 180 | typedef typename std::complex<T> ComplexT; |
| 181 | const T zero = T(0); |
| 182 | const T one = T(1); |
| 183 | const T inf = std::numeric_limits<T>::infinity(); |
| 184 | const T nan = std::numeric_limits<T>::quiet_NaN(); |
| 185 | |
| 186 | for (int i=0; i<1000; ++i) { |
| 187 | const ComplexT x = internal::random<ComplexT>(); |
| 188 | const ComplexT invx = ComplexT(one, zero) / x; |
| 189 | const ComplexT rsqrtx = numext::rsqrt(x); |
| 190 | VERIFY_IS_APPROX(rsqrtx*rsqrtx, invx); |
| 191 | } |
| 192 | |
| 193 | // GCC and MSVC differ in their treatment of 1/(0 + 0i) |
| 194 | // GCC/clang = (inf, nan) |
| 195 | // MSVC = (nan, nan) |
| 196 | // and 1 / (x + inf i) |
| 197 | // GCC/clang = (0, 0) |
| 198 | // MSVC = (nan, nan) |
| 199 | #if (EIGEN_COMP_GNUC) |
| 200 | { |
| 201 | const int kNumCorners = 20; |
| 202 | const ComplexT corners[kNumCorners][2] = { |
| 203 | // Only consistent across GCC, clang |
| 204 | {ComplexT(zero, zero), ComplexT(zero, zero)}, |
| 205 | {ComplexT(-zero, zero), ComplexT(zero, zero)}, |
| 206 | {ComplexT(zero, -zero), ComplexT(zero, zero)}, |
| 207 | {ComplexT(-zero, -zero), ComplexT(zero, zero)}, |
| 208 | {ComplexT(one, inf), ComplexT(inf, inf)}, |
| 209 | {ComplexT(nan, inf), ComplexT(inf, inf)}, |
| 210 | {ComplexT(one, -inf), ComplexT(inf, -inf)}, |
| 211 | {ComplexT(nan, -inf), ComplexT(inf, -inf)}, |
| 212 | // Consistent across GCC, clang, MSVC |
| 213 | {ComplexT(-inf, one), ComplexT(zero, inf)}, |
| 214 | {ComplexT(inf, one), ComplexT(inf, zero)}, |
| 215 | {ComplexT(-inf, -one), ComplexT(zero, -inf)}, |
| 216 | {ComplexT(inf, -one), ComplexT(inf, -zero)}, |
| 217 | {ComplexT(-inf, nan), ComplexT(nan, inf)}, |
| 218 | {ComplexT(inf, nan), ComplexT(inf, nan)}, |
| 219 | {ComplexT(zero, nan), ComplexT(nan, nan)}, |
| 220 | {ComplexT(one, nan), ComplexT(nan, nan)}, |
| 221 | {ComplexT(nan, zero), ComplexT(nan, nan)}, |
| 222 | {ComplexT(nan, one), ComplexT(nan, nan)}, |
| 223 | {ComplexT(nan, -one), ComplexT(nan, nan)}, |
| 224 | {ComplexT(nan, nan), ComplexT(nan, nan)}, |
| 225 | }; |
| 226 | |
| 227 | for (int i=0; i<kNumCorners; ++i) { |
| 228 | const ComplexT& x = corners[i][0]; |
| 229 | const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1]; |
| 230 | VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx); |
| 231 | } |
| 232 | } |
| 233 | #endif |
| 234 | } |
| 235 | }; |
| 236 | |
| 237 | template<typename T> |
| 238 | void check_rsqrt() { |
| 239 | check_rsqrt_impl<T>::run(); |
| 240 | } |
| 241 | |
| 242 | EIGEN_DECLARE_TEST(numext) { |
| 243 | for(int k=0; k<g_repeat; ++k) |
| 244 | { |
| 245 | CALL_SUBTEST( check_abs<bool>() ); |
| 246 | CALL_SUBTEST( check_abs<signed char>() ); |
| 247 | CALL_SUBTEST( check_abs<unsigned char>() ); |
| 248 | CALL_SUBTEST( check_abs<short>() ); |
| 249 | CALL_SUBTEST( check_abs<unsigned short>() ); |
| 250 | CALL_SUBTEST( check_abs<int>() ); |
| 251 | CALL_SUBTEST( check_abs<unsigned int>() ); |
| 252 | CALL_SUBTEST( check_abs<long>() ); |
| 253 | CALL_SUBTEST( check_abs<unsigned long>() ); |
| 254 | CALL_SUBTEST( check_abs<half>() ); |
| 255 | CALL_SUBTEST( check_abs<bfloat16>() ); |
| 256 | CALL_SUBTEST( check_abs<float>() ); |
| 257 | CALL_SUBTEST( check_abs<double>() ); |
| 258 | CALL_SUBTEST( check_abs<long double>() ); |
| 259 | CALL_SUBTEST( check_abs<std::complex<float> >() ); |
| 260 | CALL_SUBTEST( check_abs<std::complex<double> >() ); |
| 261 | |
| 262 | CALL_SUBTEST( check_arg<std::complex<float> >() ); |
| 263 | CALL_SUBTEST( check_arg<std::complex<double> >() ); |
| 264 | |
| 265 | CALL_SUBTEST( check_sqrt<float>() ); |
| 266 | CALL_SUBTEST( check_sqrt<double>() ); |
| 267 | CALL_SUBTEST( check_sqrt<std::complex<float> >() ); |
| 268 | CALL_SUBTEST( check_sqrt<std::complex<double> >() ); |
| 269 | |
| 270 | CALL_SUBTEST( check_rsqrt<float>() ); |
| 271 | CALL_SUBTEST( check_rsqrt<double>() ); |
| 272 | CALL_SUBTEST( check_rsqrt<std::complex<float> >() ); |
| 273 | CALL_SUBTEST( check_rsqrt<std::complex<double> >() ); |
| 274 | } |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame] | 275 | } |