blob: f1c260e29450eb6b557e664138b8093d343a6f9e [file] [log] [blame]
Austin Schuh189376f2018-12-20 22:11:15 +11001// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPECIAL_FUNCTIONS_H
11#define EIGEN_SPECIAL_FUNCTIONS_H
12
13namespace Eigen {
14namespace internal {
15
16// Parts of this code are based on the Cephes Math Library.
17//
18// Cephes Math Library Release 2.8: June, 2000
19// Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier
20//
21// Permission has been kindly provided by the original author
22// to incorporate the Cephes software into the Eigen codebase:
23//
24// From: Stephen Moshier
25// To: Eugene Brevdo
26// Subject: Re: Permission to wrap several cephes functions in Eigen
27//
28// Hello Eugene,
29//
30// Thank you for writing.
31//
32// If your licensing is similar to BSD, the formal way that has been
33// handled is simply to add a statement to the effect that you are incorporating
34// the Cephes software by permission of the author.
35//
36// Good luck with your project,
37// Steve
38
Austin Schuh189376f2018-12-20 22:11:15 +110039
40/****************************************************************************
41 * Implementation of lgamma, requires C++11/C99 *
42 ****************************************************************************/
43
44template <typename Scalar>
45struct lgamma_impl {
46 EIGEN_DEVICE_FUNC
47 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
48 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
49 THIS_TYPE_IS_NOT_SUPPORTED);
50 return Scalar(0);
51 }
52};
53
54template <typename Scalar>
55struct lgamma_retval {
56 typedef Scalar type;
57};
58
#if EIGEN_HAS_C99_MATH
// Detect whether glibc declares the reentrant ::lgamma_r / ::lgammaf_r.
// The feature-test macros checked here differ around glibc 2.19:
//   glibc >= 2.19 : _DEFAULT_SOURCE, _BSD_SOURCE or _SVID_SOURCE
//   glibc <  2.19 : _BSD_SOURCE or _SVID_SOURCE
#if defined(__GLIBC__)
#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 19)
#if defined(_DEFAULT_SOURCE) || defined(_BSD_SOURCE) || defined(_SVID_SOURCE)
#define EIGEN_HAS_LGAMMA_R
#endif
#elif defined(_BSD_SOURCE) || defined(_SVID_SOURCE)
#define EIGEN_HAS_LGAMMA_R
#endif
#endif  // defined(__GLIBC__)

template <>
struct lgamma_impl<float> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE float run(float x) {
#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
    // Reentrant form: the sign of Gamma(x) comes back through the
    // out-parameter and is not needed here.
    int sign_of_gamma;
    return ::lgammaf_r(x, &sign_of_gamma);
#elif defined(SYCL_DEVICE_ONLY)
    return cl::sycl::lgamma(x);
#else
    return ::lgammaf(x);
#endif
  }
};

template <>
struct lgamma_impl<double> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE double run(double x) {
#if !defined(EIGEN_GPU_COMPILE_PHASE) && defined(EIGEN_HAS_LGAMMA_R) && !defined(__APPLE__)
    // Reentrant form: the sign of Gamma(x) comes back through the
    // out-parameter and is not needed here.
    int sign_of_gamma;
    return ::lgamma_r(x, &sign_of_gamma);
#elif defined(SYCL_DEVICE_ONLY)
    return cl::sycl::lgamma(x);
#else
    return ::lgamma(x);
#endif
  }
};

// The detection macro is local to this section.
#undef EIGEN_HAS_LGAMMA_R
#endif  // EIGEN_HAS_C99_MATH
104
105/****************************************************************************
106 * Implementation of digamma (psi), based on Cephes *
107 ****************************************************************************/
108
109template <typename Scalar>
110struct digamma_retval {
111 typedef Scalar type;
112};
113
114/*
115 *
116 * Polynomial evaluation helper for the Psi (digamma) function.
117 *
118 * digamma_impl_maybe_poly::run(s) evaluates the asymptotic Psi expansion for
119 * input Scalar s, assuming s is above 10.0.
120 *
121 * If s is above a certain threshold for the given Scalar type, zero
122 * is returned. Otherwise the polynomial is evaluated with enough
123 * coefficients for results matching Scalar machine precision.
124 *
125 *
126 */
127template <typename Scalar>
128struct digamma_impl_maybe_poly {
129 EIGEN_DEVICE_FUNC
130 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
131 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
132 THIS_TYPE_IS_NOT_SUPPORTED);
133 return Scalar(0);
134 }
135};
136
137
138template <>
139struct digamma_impl_maybe_poly<float> {
140 EIGEN_DEVICE_FUNC
141 static EIGEN_STRONG_INLINE float run(const float s) {
142 const float A[] = {
143 -4.16666666666666666667E-3f,
144 3.96825396825396825397E-3f,
145 -8.33333333333333333333E-3f,
146 8.33333333333333333333E-2f
147 };
148
149 float z;
150 if (s < 1.0e8f) {
151 z = 1.0f / (s * s);
Austin Schuhc55b0172022-02-20 17:52:35 -0800152 return z * internal::ppolevl<float, 3>::run(z, A);
Austin Schuh189376f2018-12-20 22:11:15 +1100153 } else return 0.0f;
154 }
155};
156
157template <>
158struct digamma_impl_maybe_poly<double> {
159 EIGEN_DEVICE_FUNC
160 static EIGEN_STRONG_INLINE double run(const double s) {
161 const double A[] = {
162 8.33333333333333333333E-2,
163 -2.10927960927960927961E-2,
164 7.57575757575757575758E-3,
165 -4.16666666666666666667E-3,
166 3.96825396825396825397E-3,
167 -8.33333333333333333333E-3,
168 8.33333333333333333333E-2
169 };
170
171 double z;
172 if (s < 1.0e17) {
173 z = 1.0 / (s * s);
Austin Schuhc55b0172022-02-20 17:52:35 -0800174 return z * internal::ppolevl<double, 6>::run(z, A);
Austin Schuh189376f2018-12-20 22:11:15 +1100175 }
176 else return 0.0;
177 }
178};
179
180template <typename Scalar>
181struct digamma_impl {
182 EIGEN_DEVICE_FUNC
183 static Scalar run(Scalar x) {
184 /*
185 *
186 * Psi (digamma) function (modified for Eigen)
187 *
188 *
189 * SYNOPSIS:
190 *
191 * double x, y, psi();
192 *
193 * y = psi( x );
194 *
195 *
196 * DESCRIPTION:
197 *
198 * d -
199 * psi(x) = -- ln | (x)
200 * dx
201 *
202 * is the logarithmic derivative of the gamma function.
203 * For integer x,
204 * n-1
205 * -
206 * psi(n) = -EUL + > 1/k.
207 * -
208 * k=1
209 *
210 * If x is negative, it is transformed to a positive argument by the
211 * reflection formula psi(1-x) = psi(x) + pi cot(pi x).
212 * For general positive x, the argument is made greater than 10
213 * using the recurrence psi(x+1) = psi(x) + 1/x.
214 * Then the following asymptotic expansion is applied:
215 *
216 * inf. B
217 * - 2k
218 * psi(x) = log(x) - 1/2x - > -------
219 * - 2k
220 * k=1 2k x
221 *
222 * where the B2k are Bernoulli numbers.
223 *
224 * ACCURACY (float):
225 * Relative error (except absolute when |psi| < 1):
226 * arithmetic domain # trials peak rms
227 * IEEE 0,30 30000 1.3e-15 1.4e-16
228 * IEEE -30,0 40000 1.5e-15 2.2e-16
229 *
230 * ACCURACY (double):
231 * Absolute error, relative when |psi| > 1 :
232 * arithmetic domain # trials peak rms
233 * IEEE -33,0 30000 8.2e-7 1.2e-7
234 * IEEE 0,33 100000 7.3e-7 7.7e-8
235 *
236 * ERROR MESSAGES:
237 * message condition value returned
238 * psi singularity x integer <=0 INFINITY
239 */
240
241 Scalar p, q, nz, s, w, y;
242 bool negative = false;
243
Austin Schuhc55b0172022-02-20 17:52:35 -0800244 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
Austin Schuh189376f2018-12-20 22:11:15 +1100245 const Scalar m_pi = Scalar(EIGEN_PI);
246
247 const Scalar zero = Scalar(0);
248 const Scalar one = Scalar(1);
249 const Scalar half = Scalar(0.5);
250 nz = zero;
251
252 if (x <= zero) {
253 negative = true;
254 q = x;
255 p = numext::floor(q);
256 if (p == q) {
Austin Schuhc55b0172022-02-20 17:52:35 -0800257 return nan;
Austin Schuh189376f2018-12-20 22:11:15 +1100258 }
259 /* Remove the zeros of tan(m_pi x)
260 * by subtracting the nearest integer from x
261 */
262 nz = q - p;
263 if (nz != half) {
264 if (nz > half) {
265 p += one;
266 nz = q - p;
267 }
268 nz = m_pi / numext::tan(m_pi * nz);
269 }
270 else {
271 nz = zero;
272 }
273 x = one - x;
274 }
275
276 /* use the recurrence psi(x+1) = psi(x) + 1/x. */
277 s = x;
278 w = zero;
279 while (s < Scalar(10)) {
280 w += one / s;
281 s += one;
282 }
283
284 y = digamma_impl_maybe_poly<Scalar>::run(s);
285
286 y = numext::log(s) - (half / s) - y - w;
287
288 return (negative) ? y - nz : y;
289 }
290};
291
292/****************************************************************************
293 * Implementation of erf, requires C++11/C99 *
294 ****************************************************************************/
295
Austin Schuhc55b0172022-02-20 17:52:35 -0800296/** \internal \returns the error function of \a a (coeff-wise)
297 Doesn't do anything fancy, just a 13/8-degree rational interpolant which
298 is accurate up to a couple of ulp in the range [-4, 4], outside of which
299 fl(erf(x)) = +/-1.
300
301 This implementation works on both scalars and Ts.
302*/
303template <typename T>
304EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
305 // Clamp the inputs to the range [-4, 4] since anything outside
306 // this range is +/-1.0f in single-precision.
307 const T plus_4 = pset1<T>(4.f);
308 const T minus_4 = pset1<T>(-4.f);
309 const T x = pmax(pmin(a_x, plus_4), minus_4);
310 // The monomial coefficients of the numerator polynomial (odd).
311 const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
312 const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
313 const T alpha_5 = pset1<T>(-7.34990630326855e-04f);
314 const T alpha_7 = pset1<T>(-5.69250639462346e-05f);
315 const T alpha_9 = pset1<T>(-2.10102402082508e-06f);
316 const T alpha_11 = pset1<T>(2.77068142495902e-08f);
317 const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
318
319 // The monomial coefficients of the denominator polynomial (even).
320 const T beta_0 = pset1<T>(-1.42647390514189e-02f);
321 const T beta_2 = pset1<T>(-7.37332916720468e-03f);
322 const T beta_4 = pset1<T>(-1.68282697438203e-03f);
323 const T beta_6 = pset1<T>(-2.13374055278905e-04f);
324 const T beta_8 = pset1<T>(-1.45660718464996e-05f);
325
326 // Since the polynomials are odd/even, we need x^2.
327 const T x2 = pmul(x, x);
328
329 // Evaluate the numerator polynomial p.
330 T p = pmadd(x2, alpha_13, alpha_11);
331 p = pmadd(x2, p, alpha_9);
332 p = pmadd(x2, p, alpha_7);
333 p = pmadd(x2, p, alpha_5);
334 p = pmadd(x2, p, alpha_3);
335 p = pmadd(x2, p, alpha_1);
336 p = pmul(x, p);
337
338 // Evaluate the denominator polynomial p.
339 T q = pmadd(x2, beta_8, beta_6);
340 q = pmadd(x2, q, beta_4);
341 q = pmadd(x2, q, beta_2);
342 q = pmadd(x2, q, beta_0);
343
344 // Divide the numerator by the denominator.
345 return pdiv(p, q);
346}
347
348template <typename T>
Austin Schuh189376f2018-12-20 22:11:15 +1100349struct erf_impl {
350 EIGEN_DEVICE_FUNC
Austin Schuhc55b0172022-02-20 17:52:35 -0800351 static EIGEN_STRONG_INLINE T run(const T& x) {
352 return generic_fast_erf_float(x);
Austin Schuh189376f2018-12-20 22:11:15 +1100353 }
354};
355
356template <typename Scalar>
357struct erf_retval {
358 typedef Scalar type;
359};
360
#if EIGEN_HAS_C99_MATH
// Scalar float/double erf. SYCL device code uses the SYCL builtin instead.
template <>
struct erf_impl<float> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE float run(float x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erf(x);
#else
    // Same rational approximation as the generic packet path.
    return generic_fast_erf_float<float>(x);
#endif
  }
};

template <>
struct erf_impl<double> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE double run(double x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erf(x);
#else
    return ::erf(x);
#endif
  }
};
#endif  // EIGEN_HAS_C99_MATH
386
387/***************************************************************************
388* Implementation of erfc, requires C++11/C99 *
389****************************************************************************/
390
391template <typename Scalar>
392struct erfc_impl {
393 EIGEN_DEVICE_FUNC
394 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
395 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
396 THIS_TYPE_IS_NOT_SUPPORTED);
397 return Scalar(0);
398 }
399};
400
401template <typename Scalar>
402struct erfc_retval {
403 typedef Scalar type;
404};
405
#if EIGEN_HAS_C99_MATH
// Scalar float/double erfc backed by the C library (SYCL builtins on device).
template <>
struct erfc_impl<float> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE float run(const float x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erfc(x);
#else
    return ::erfcf(x);
#endif
  }
};

template <>
struct erfc_impl<double> {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE double run(const double x) {
#ifdef SYCL_DEVICE_ONLY
    return cl::sycl::erfc(x);
#else
    return ::erfc(x);
#endif
  }
};
#endif  // EIGEN_HAS_C99_MATH
431
Austin Schuhc55b0172022-02-20 17:52:35 -0800432
433/***************************************************************************
434* Implementation of ndtri. *
435****************************************************************************/
436
437/* Inverse of Normal distribution function (modified for Eigen).
438 *
439 *
440 * SYNOPSIS:
441 *
442 * double x, y, ndtri();
443 *
444 * x = ndtri( y );
445 *
446 *
447 *
448 * DESCRIPTION:
449 *
450 * Returns the argument, x, for which the area under the
451 * Gaussian probability density function (integrated from
452 * minus infinity to x) is equal to y.
453 *
454 *
455 * For small arguments 0 < y < exp(-2), the program computes
456 * z = sqrt( -2.0 * log(y) ); then the approximation is
457 * x = z - log(z)/z - (1/z) P(1/z) / Q(1/z).
458 * There are two rational functions P/Q, one for 0 < y < exp(-32)
459 * and the other for y up to exp(-2). For larger arguments,
460 * w = y - 0.5, and x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2)).
461 *
462 *
463 * ACCURACY:
464 *
465 * Relative error:
466 * arithmetic domain # trials peak rms
467 * DEC 0.125, 1 5500 9.5e-17 2.1e-17
468 * DEC 6e-39, 0.135 3500 5.7e-17 1.3e-17
469 * IEEE 0.125, 1 20000 7.2e-16 1.3e-16
470 * IEEE 3e-308, 0.135 50000 4.6e-16 9.8e-17
471 *
472 *
473 * ERROR MESSAGES:
474 *
475 * message condition value returned
476 * ndtri domain x <= 0 -MAXNUM
477 * ndtri domain x >= 1 MAXNUM
478 *
479 */
480 /*
481 Cephes Math Library Release 2.2: June, 1992
482 Copyright 1985, 1987, 1992 by Stephen L. Moshier
483 Direct inquiries to 30 Frost Street, Cambridge, MA 02140
484 */
485
486
487// TODO: Add a cheaper approximation for float.
488
489
490template<typename T>
491EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T flipsign(
492 const T& should_flipsign, const T& x) {
493 typedef typename unpacket_traits<T>::type Scalar;
494 const T sign_mask = pset1<T>(Scalar(-0.0));
495 T sign_bit = pand<T>(should_flipsign, sign_mask);
496 return pxor<T>(sign_bit, x);
497}
498
499template<>
500EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double flipsign<double>(
501 const double& should_flipsign, const double& x) {
502 return should_flipsign == 0 ? x : -x;
503}
504
505template<>
506EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float flipsign<float>(
507 const float& should_flipsign, const float& x) {
508 return should_flipsign == 0 ? x : -x;
509}
510
// We split this computation into two so that in the scalar path
512// only one branch is evaluated (due to our template specialization of pselect
513// being an if statement.)
514
515template <typename T, typename ScalarType>
516EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) {
517 const ScalarType p0[] = {
518 ScalarType(-5.99633501014107895267e1),
519 ScalarType(9.80010754185999661536e1),
520 ScalarType(-5.66762857469070293439e1),
521 ScalarType(1.39312609387279679503e1),
522 ScalarType(-1.23916583867381258016e0)
523 };
524 const ScalarType q0[] = {
525 ScalarType(1.0),
526 ScalarType(1.95448858338141759834e0),
527 ScalarType(4.67627912898881538453e0),
528 ScalarType(8.63602421390890590575e1),
529 ScalarType(-2.25462687854119370527e2),
530 ScalarType(2.00260212380060660359e2),
531 ScalarType(-8.20372256168333339912e1),
532 ScalarType(1.59056225126211695515e1),
533 ScalarType(-1.18331621121330003142e0)
534 };
535 const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
536 const T half = pset1<T>(ScalarType(0.5));
537 T c, c2, ndtri_gt_exp_neg_two;
538
539 c = psub(b, half);
540 c2 = pmul(c, c);
541 ndtri_gt_exp_neg_two = pmadd(c, pmul(
542 c2, pdiv(
543 internal::ppolevl<T, 4>::run(c2, p0),
544 internal::ppolevl<T, 8>::run(c2, q0))), c);
545 return pmul(ndtri_gt_exp_neg_two, sqrt2pi);
546}
547
548template <typename T, typename ScalarType>
549EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_lt_exp_neg_two(
550 const T& b, const T& should_flipsign) {
551 /* Approximation for interval z = sqrt(-2 log a ) between 2 and 8
552 * i.e., a between exp(-2) = .135 and exp(-32) = 1.27e-14.
553 */
554 const ScalarType p1[] = {
555 ScalarType(4.05544892305962419923e0),
556 ScalarType(3.15251094599893866154e1),
557 ScalarType(5.71628192246421288162e1),
558 ScalarType(4.40805073893200834700e1),
559 ScalarType(1.46849561928858024014e1),
560 ScalarType(2.18663306850790267539e0),
561 ScalarType(-1.40256079171354495875e-1),
562 ScalarType(-3.50424626827848203418e-2),
563 ScalarType(-8.57456785154685413611e-4)
564 };
565 const ScalarType q1[] = {
566 ScalarType(1.0),
567 ScalarType(1.57799883256466749731e1),
568 ScalarType(4.53907635128879210584e1),
569 ScalarType(4.13172038254672030440e1),
570 ScalarType(1.50425385692907503408e1),
571 ScalarType(2.50464946208309415979e0),
572 ScalarType(-1.42182922854787788574e-1),
573 ScalarType(-3.80806407691578277194e-2),
574 ScalarType(-9.33259480895457427372e-4)
575 };
576 /* Approximation for interval z = sqrt(-2 log a ) between 8 and 64
577 * i.e., a between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
578 */
579 const ScalarType p2[] = {
580 ScalarType(3.23774891776946035970e0),
581 ScalarType(6.91522889068984211695e0),
582 ScalarType(3.93881025292474443415e0),
583 ScalarType(1.33303460815807542389e0),
584 ScalarType(2.01485389549179081538e-1),
585 ScalarType(1.23716634817820021358e-2),
586 ScalarType(3.01581553508235416007e-4),
587 ScalarType(2.65806974686737550832e-6),
588 ScalarType(6.23974539184983293730e-9)
589 };
590 const ScalarType q2[] = {
591 ScalarType(1.0),
592 ScalarType(6.02427039364742014255e0),
593 ScalarType(3.67983563856160859403e0),
594 ScalarType(1.37702099489081330271e0),
595 ScalarType(2.16236993594496635890e-1),
596 ScalarType(1.34204006088543189037e-2),
597 ScalarType(3.28014464682127739104e-4),
598 ScalarType(2.89247864745380683936e-6),
599 ScalarType(6.79019408009981274425e-9)
600 };
601 const T eight = pset1<T>(ScalarType(8.0));
602 const T one = pset1<T>(ScalarType(1));
603 const T neg_two = pset1<T>(ScalarType(-2));
604 T x, x0, x1, z;
605
606 x = psqrt(pmul(neg_two, plog(b)));
607 x0 = psub(x, pdiv(plog(x), x));
608 z = pdiv(one, x);
609 x1 = pmul(
610 z, pselect(
611 pcmp_lt(x, eight),
612 pdiv(internal::ppolevl<T, 8>::run(z, p1),
613 internal::ppolevl<T, 8>::run(z, q1)),
614 pdiv(internal::ppolevl<T, 8>::run(z, p2),
615 internal::ppolevl<T, 8>::run(z, q2))));
616 return flipsign(should_flipsign, psub(x0, x1));
617}
618
619template <typename T, typename ScalarType>
620EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
621T generic_ndtri(const T& a) {
622 const T maxnum = pset1<T>(NumTraits<ScalarType>::infinity());
623 const T neg_maxnum = pset1<T>(-NumTraits<ScalarType>::infinity());
624
625 const T zero = pset1<T>(ScalarType(0));
626 const T one = pset1<T>(ScalarType(1));
627 // exp(-2)
628 const T exp_neg_two = pset1<T>(ScalarType(0.13533528323661269189));
629 T b, ndtri, should_flipsign;
630
631 should_flipsign = pcmp_le(a, psub(one, exp_neg_two));
632 b = pselect(should_flipsign, a, psub(one, a));
633
634 ndtri = pselect(
635 pcmp_lt(exp_neg_two, b),
636 generic_ndtri_gt_exp_neg_two<T, ScalarType>(b),
637 generic_ndtri_lt_exp_neg_two<T, ScalarType>(b, should_flipsign));
638
639 return pselect(
640 pcmp_le(a, zero), neg_maxnum,
641 pselect(pcmp_le(one, a), maxnum, ndtri));
642}
643
644template <typename Scalar>
645struct ndtri_retval {
646 typedef Scalar type;
647};
648
649#if !EIGEN_HAS_C99_MATH
650
651template <typename Scalar>
652struct ndtri_impl {
653 EIGEN_DEVICE_FUNC
654 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
655 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
656 THIS_TYPE_IS_NOT_SUPPORTED);
657 return Scalar(0);
658 }
659};
660
661# else
662
663template <typename Scalar>
664struct ndtri_impl {
665 EIGEN_DEVICE_FUNC
666 static EIGEN_STRONG_INLINE Scalar run(const Scalar x) {
667 return generic_ndtri<Scalar, Scalar>(x);
668 }
669};
670
671#endif // EIGEN_HAS_C99_MATH
672
673
Austin Schuh189376f2018-12-20 22:11:15 +1100674/**************************************************************************************************************
675 * Implementation of igammac (complemented incomplete gamma integral), based on Cephes but requires C++11/C99 *
676 **************************************************************************************************************/
677
678template <typename Scalar>
679struct igammac_retval {
680 typedef Scalar type;
681};
682
683// NOTE: cephes_helper is also used to implement zeta
684template <typename Scalar>
685struct cephes_helper {
686 EIGEN_DEVICE_FUNC
687 static EIGEN_STRONG_INLINE Scalar machep() { assert(false && "machep not supported for this type"); return 0.0; }
688 EIGEN_DEVICE_FUNC
689 static EIGEN_STRONG_INLINE Scalar big() { assert(false && "big not supported for this type"); return 0.0; }
690 EIGEN_DEVICE_FUNC
691 static EIGEN_STRONG_INLINE Scalar biginv() { assert(false && "biginv not supported for this type"); return 0.0; }
692};
693
694template <>
695struct cephes_helper<float> {
696 EIGEN_DEVICE_FUNC
697 static EIGEN_STRONG_INLINE float machep() {
698 return NumTraits<float>::epsilon() / 2; // 1.0 - machep == 1.0
699 }
700 EIGEN_DEVICE_FUNC
701 static EIGEN_STRONG_INLINE float big() {
702 // use epsneg (1.0 - epsneg == 1.0)
703 return 1.0f / (NumTraits<float>::epsilon() / 2);
704 }
705 EIGEN_DEVICE_FUNC
706 static EIGEN_STRONG_INLINE float biginv() {
707 // epsneg
708 return machep();
709 }
710};
711
712template <>
713struct cephes_helper<double> {
714 EIGEN_DEVICE_FUNC
715 static EIGEN_STRONG_INLINE double machep() {
716 return NumTraits<double>::epsilon() / 2; // 1.0 - machep == 1.0
717 }
718 EIGEN_DEVICE_FUNC
719 static EIGEN_STRONG_INLINE double big() {
720 return 1.0 / NumTraits<double>::epsilon();
721 }
722 EIGEN_DEVICE_FUNC
723 static EIGEN_STRONG_INLINE double biginv() {
724 // inverse of eps
725 return NumTraits<double>::epsilon();
726 }
727};
728
Austin Schuhc55b0172022-02-20 17:52:35 -0800729enum IgammaComputationMode { VALUE, DERIVATIVE, SAMPLE_DERIVATIVE };
730
731template <typename Scalar>
732EIGEN_DEVICE_FUNC
733static EIGEN_STRONG_INLINE Scalar main_igamma_term(Scalar a, Scalar x) {
734 /* Compute x**a * exp(-x) / gamma(a) */
735 Scalar logax = a * numext::log(x) - x - lgamma_impl<Scalar>::run(a);
736 if (logax < -numext::log(NumTraits<Scalar>::highest()) ||
737 // Assuming x and a aren't Nan.
738 (numext::isnan)(logax)) {
739 return Scalar(0);
740 }
741 return numext::exp(logax);
742}
743
744template <typename Scalar, IgammaComputationMode mode>
745EIGEN_DEVICE_FUNC
746int igamma_num_iterations() {
747 /* Returns the maximum number of internal iterations for igamma computation.
748 */
749 if (mode == VALUE) {
750 return 2000;
751 }
752
753 if (internal::is_same<Scalar, float>::value) {
754 return 200;
755 } else if (internal::is_same<Scalar, double>::value) {
756 return 500;
757 } else {
758 return 2000;
759 }
760}
761
762template <typename Scalar, IgammaComputationMode mode>
763struct igammac_cf_impl {
764 /* Computes igamc(a, x) or derivative (depending on the mode)
765 * using the continued fraction expansion of the complementary
766 * incomplete Gamma function.
767 *
768 * Preconditions:
769 * a > 0
770 * x >= 1
771 * x >= a
772 */
773 EIGEN_DEVICE_FUNC
774 static Scalar run(Scalar a, Scalar x) {
775 const Scalar zero = 0;
776 const Scalar one = 1;
777 const Scalar two = 2;
778 const Scalar machep = cephes_helper<Scalar>::machep();
779 const Scalar big = cephes_helper<Scalar>::big();
780 const Scalar biginv = cephes_helper<Scalar>::biginv();
781
782 if ((numext::isinf)(x)) {
783 return zero;
784 }
785
786 Scalar ax = main_igamma_term<Scalar>(a, x);
787 // This is independent of mode. If this value is zero,
788 // then the function value is zero. If the function value is zero,
789 // then we are in a neighborhood where the function value evalutes to zero,
790 // so the derivative is zero.
791 if (ax == zero) {
792 return zero;
793 }
794
795 // continued fraction
796 Scalar y = one - a;
797 Scalar z = x + y + one;
798 Scalar c = zero;
799 Scalar pkm2 = one;
800 Scalar qkm2 = x;
801 Scalar pkm1 = x + one;
802 Scalar qkm1 = z * x;
803 Scalar ans = pkm1 / qkm1;
804
805 Scalar dpkm2_da = zero;
806 Scalar dqkm2_da = zero;
807 Scalar dpkm1_da = zero;
808 Scalar dqkm1_da = -x;
809 Scalar dans_da = (dpkm1_da - ans * dqkm1_da) / qkm1;
810
811 for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
812 c += one;
813 y += one;
814 z += two;
815
816 Scalar yc = y * c;
817 Scalar pk = pkm1 * z - pkm2 * yc;
818 Scalar qk = qkm1 * z - qkm2 * yc;
819
820 Scalar dpk_da = dpkm1_da * z - pkm1 - dpkm2_da * yc + pkm2 * c;
821 Scalar dqk_da = dqkm1_da * z - qkm1 - dqkm2_da * yc + qkm2 * c;
822
823 if (qk != zero) {
824 Scalar ans_prev = ans;
825 ans = pk / qk;
826
827 Scalar dans_da_prev = dans_da;
828 dans_da = (dpk_da - ans * dqk_da) / qk;
829
830 if (mode == VALUE) {
831 if (numext::abs(ans_prev - ans) <= machep * numext::abs(ans)) {
832 break;
833 }
834 } else {
835 if (numext::abs(dans_da - dans_da_prev) <= machep) {
836 break;
837 }
838 }
839 }
840
841 pkm2 = pkm1;
842 pkm1 = pk;
843 qkm2 = qkm1;
844 qkm1 = qk;
845
846 dpkm2_da = dpkm1_da;
847 dpkm1_da = dpk_da;
848 dqkm2_da = dqkm1_da;
849 dqkm1_da = dqk_da;
850
851 if (numext::abs(pk) > big) {
852 pkm2 *= biginv;
853 pkm1 *= biginv;
854 qkm2 *= biginv;
855 qkm1 *= biginv;
856
857 dpkm2_da *= biginv;
858 dpkm1_da *= biginv;
859 dqkm2_da *= biginv;
860 dqkm1_da *= biginv;
861 }
862 }
863
864 /* Compute x**a * exp(-x) / gamma(a) */
865 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a);
866 Scalar dax_da = ax * dlogax_da;
867
868 switch (mode) {
869 case VALUE:
870 return ans * ax;
871 case DERIVATIVE:
872 return ans * dax_da + dans_da * ax;
873 case SAMPLE_DERIVATIVE:
874 default: // this is needed to suppress clang warning
875 return -(dans_da + ans * dlogax_da) * x;
876 }
877 }
878};
879
880template <typename Scalar, IgammaComputationMode mode>
881struct igamma_series_impl {
882 /* Computes igam(a, x) or its derivative (depending on the mode)
883 * using the series expansion of the incomplete Gamma function.
884 *
885 * Preconditions:
886 * x > 0
887 * a > 0
888 * !(x > 1 && x > a)
889 */
890 EIGEN_DEVICE_FUNC
891 static Scalar run(Scalar a, Scalar x) {
892 const Scalar zero = 0;
893 const Scalar one = 1;
894 const Scalar machep = cephes_helper<Scalar>::machep();
895
896 Scalar ax = main_igamma_term<Scalar>(a, x);
897
898 // This is independent of mode. If this value is zero,
899 // then the function value is zero. If the function value is zero,
900 // then we are in a neighborhood where the function value evalutes to zero,
901 // so the derivative is zero.
902 if (ax == zero) {
903 return zero;
904 }
905
906 ax /= a;
907
908 /* power series */
909 Scalar r = a;
910 Scalar c = one;
911 Scalar ans = one;
912
913 Scalar dc_da = zero;
914 Scalar dans_da = zero;
915
916 for (int i = 0; i < igamma_num_iterations<Scalar, mode>(); i++) {
917 r += one;
918 Scalar term = x / r;
919 Scalar dterm_da = -x / (r * r);
920 dc_da = term * dc_da + dterm_da * c;
921 dans_da += dc_da;
922 c *= term;
923 ans += c;
924
925 if (mode == VALUE) {
926 if (c <= machep * ans) {
927 break;
928 }
929 } else {
930 if (numext::abs(dc_da) <= machep * numext::abs(dans_da)) {
931 break;
932 }
933 }
934 }
935
936 Scalar dlogax_da = numext::log(x) - digamma_impl<Scalar>::run(a + one);
937 Scalar dax_da = ax * dlogax_da;
938
939 switch (mode) {
940 case VALUE:
941 return ans * ax;
942 case DERIVATIVE:
943 return ans * dax_da + dans_da * ax;
944 case SAMPLE_DERIVATIVE:
945 default: // this is needed to suppress clang warning
946 return -(dans_da + ans * dlogax_da) * x / a;
947 }
948 }
949};
950
Austin Schuh189376f2018-12-20 22:11:15 +1100951#if !EIGEN_HAS_C99_MATH
952
953template <typename Scalar>
954struct igammac_impl {
955 EIGEN_DEVICE_FUNC
956 static Scalar run(Scalar a, Scalar x) {
957 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
958 THIS_TYPE_IS_NOT_SUPPORTED);
959 return Scalar(0);
960 }
961};
962
963#else
964
Austin Schuh189376f2018-12-20 22:11:15 +1100965template <typename Scalar>
966struct igammac_impl {
967 EIGEN_DEVICE_FUNC
968 static Scalar run(Scalar a, Scalar x) {
969 /* igamc()
970 *
971 * Incomplete gamma integral (modified for Eigen)
972 *
973 *
974 *
975 * SYNOPSIS:
976 *
977 * double a, x, y, igamc();
978 *
979 * y = igamc( a, x );
980 *
981 * DESCRIPTION:
982 *
983 * The function is defined by
984 *
985 *
986 * igamc(a,x) = 1 - igam(a,x)
987 *
988 * inf.
989 * -
990 * 1 | | -t a-1
991 * = ----- | e t dt.
992 * - | |
993 * | (a) -
994 * x
995 *
996 *
997 * In this implementation both arguments must be positive.
998 * The integral is evaluated by either a power series or
999 * continued fraction expansion, depending on the relative
1000 * values of a and x.
1001 *
1002 * ACCURACY (float):
1003 *
1004 * Relative error:
1005 * arithmetic domain # trials peak rms
1006 * IEEE 0,30 30000 7.8e-6 5.9e-7
1007 *
1008 *
1009 * ACCURACY (double):
1010 *
1011 * Tested at random a, x.
1012 * a x Relative error:
1013 * arithmetic domain domain # trials peak rms
1014 * IEEE 0.5,100 0,100 200000 1.9e-14 1.7e-15
1015 * IEEE 0.01,0.5 0,100 200000 1.4e-13 1.6e-15
1016 *
1017 */
1018 /*
1019 Cephes Math Library Release 2.2: June, 1992
1020 Copyright 1985, 1987, 1992 by Stephen L. Moshier
1021 Direct inquiries to 30 Frost Street, Cambridge, MA 02140
1022 */
1023 const Scalar zero = 0;
1024 const Scalar one = 1;
1025 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1026
1027 if ((x < zero) || (a <= zero)) {
1028 // domain error
1029 return nan;
1030 }
1031
Austin Schuhc55b0172022-02-20 17:52:35 -08001032 if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
1033 return nan;
1034 }
1035
Austin Schuh189376f2018-12-20 22:11:15 +11001036 if ((x < one) || (x < a)) {
Austin Schuhc55b0172022-02-20 17:52:35 -08001037 return (one - igamma_series_impl<Scalar, VALUE>::run(a, x));
Austin Schuh189376f2018-12-20 22:11:15 +11001038 }
1039
Austin Schuhc55b0172022-02-20 17:52:35 -08001040 return igammac_cf_impl<Scalar, VALUE>::run(a, x);
Austin Schuh189376f2018-12-20 22:11:15 +11001041 }
1042};
1043
1044#endif // EIGEN_HAS_C99_MATH
1045
1046/************************************************************************************************
1047 * Implementation of igamma (incomplete gamma integral), based on Cephes but requires C++11/C99 *
1048 ************************************************************************************************/
1049
Austin Schuh189376f2018-12-20 22:11:15 +11001050#if !EIGEN_HAS_C99_MATH
1051
Austin Schuhc55b0172022-02-20 17:52:35 -08001052template <typename Scalar, IgammaComputationMode mode>
1053struct igamma_generic_impl {
Austin Schuh189376f2018-12-20 22:11:15 +11001054 EIGEN_DEVICE_FUNC
1055 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar x) {
1056 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1057 THIS_TYPE_IS_NOT_SUPPORTED);
1058 return Scalar(0);
1059 }
1060};
1061
1062#else
1063
Austin Schuhc55b0172022-02-20 17:52:35 -08001064template <typename Scalar, IgammaComputationMode mode>
1065struct igamma_generic_impl {
Austin Schuh189376f2018-12-20 22:11:15 +11001066 EIGEN_DEVICE_FUNC
1067 static Scalar run(Scalar a, Scalar x) {
Austin Schuhc55b0172022-02-20 17:52:35 -08001068 /* Depending on the mode, returns
1069 * - VALUE: incomplete Gamma function igamma(a, x)
1070 * - DERIVATIVE: derivative of incomplete Gamma function d/da igamma(a, x)
1071 * - SAMPLE_DERIVATIVE: implicit derivative of a Gamma random variable
1072 * x ~ Gamma(x | a, 1), dx/da = -1 / Gamma(x | a, 1) * d igamma(a, x) / dx
Austin Schuh189376f2018-12-20 22:11:15 +11001073 *
Austin Schuhc55b0172022-02-20 17:52:35 -08001074 * Derivatives are implemented by forward-mode differentiation.
Austin Schuh189376f2018-12-20 22:11:15 +11001075 */
1076 const Scalar zero = 0;
1077 const Scalar one = 1;
1078 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1079
1080 if (x == zero) return zero;
1081
1082 if ((x < zero) || (a <= zero)) { // domain error
1083 return nan;
1084 }
1085
Austin Schuhc55b0172022-02-20 17:52:35 -08001086 if ((numext::isnan)(a) || (numext::isnan)(x)) { // propagate nans
1087 return nan;
1088 }
1089
Austin Schuh189376f2018-12-20 22:11:15 +11001090 if ((x > one) && (x > a)) {
Austin Schuhc55b0172022-02-20 17:52:35 -08001091 Scalar ret = igammac_cf_impl<Scalar, mode>::run(a, x);
1092 if (mode == VALUE) {
1093 return one - ret;
1094 } else {
1095 return -ret;
Austin Schuh189376f2018-12-20 22:11:15 +11001096 }
1097 }
1098
Austin Schuhc55b0172022-02-20 17:52:35 -08001099 return igamma_series_impl<Scalar, mode>::run(a, x);
Austin Schuh189376f2018-12-20 22:11:15 +11001100 }
1101};
1102
1103#endif // EIGEN_HAS_C99_MATH
1104
/* Result-type trait for igamma: the value has the same scalar type as the
 * arguments. */
template <typename Scalar>
struct igamma_retval
{
  typedef Scalar type;
};
1109
1110template <typename Scalar>
1111struct igamma_impl : igamma_generic_impl<Scalar, VALUE> {
1112 /* igam()
1113 * Incomplete gamma integral.
1114 *
1115 * The CDF of Gamma(a, 1) random variable at the point x.
1116 *
1117 * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
1118 * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
1119 * The ground truth is computed by mpmath. Mean absolute error:
1120 * float: 1.26713e-05
1121 * double: 2.33606e-12
1122 *
1123 * Cephes documentation below.
1124 *
1125 * SYNOPSIS:
1126 *
1127 * double a, x, y, igam();
1128 *
1129 * y = igam( a, x );
1130 *
1131 * DESCRIPTION:
1132 *
1133 * The function is defined by
1134 *
1135 * x
1136 * -
1137 * 1 | | -t a-1
1138 * igam(a,x) = ----- | e t dt.
1139 * - | |
1140 * | (a) -
1141 * 0
1142 *
1143 *
1144 * In this implementation both arguments must be positive.
1145 * The integral is evaluated by either a power series or
1146 * continued fraction expansion, depending on the relative
1147 * values of a and x.
1148 *
1149 * ACCURACY (double):
1150 *
1151 * Relative error:
1152 * arithmetic domain # trials peak rms
1153 * IEEE 0,30 200000 3.6e-14 2.9e-15
1154 * IEEE 0,100 300000 9.9e-14 1.5e-14
1155 *
1156 *
1157 * ACCURACY (float):
1158 *
1159 * Relative error:
1160 * arithmetic domain # trials peak rms
1161 * IEEE 0,30 20000 7.8e-6 5.9e-7
1162 *
1163 */
1164 /*
1165 Cephes Math Library Release 2.2: June, 1992
1166 Copyright 1985, 1987, 1992 by Stephen L. Moshier
1167 Direct inquiries to 30 Frost Street, Cambridge, MA 02140
1168 */
1169
1170 /* left tail of incomplete gamma function:
1171 *
1172 * inf. k
1173 * a -x - x
1174 * x e > ----------
1175 * - -
1176 * k=0 | (a+k+1)
1177 *
1178 */
1179};
1180
1181template <typename Scalar>
1182struct igamma_der_a_retval : igamma_retval<Scalar> {};
1183
1184template <typename Scalar>
1185struct igamma_der_a_impl : igamma_generic_impl<Scalar, DERIVATIVE> {
1186 /* Derivative of the incomplete Gamma function with respect to a.
1187 *
1188 * Computes d/da igamma(a, x) by forward differentiation of the igamma code.
1189 *
1190 * Accuracy estimation. For each a in [10^-2, 10^-1...10^3] we sample
1191 * 50 Gamma random variables x ~ Gamma(x | a, 1), a total of 300 points.
1192 * The ground truth is computed by mpmath. Mean absolute error:
1193 * float: 6.17992e-07
1194 * double: 4.60453e-12
1195 *
1196 * Reference:
1197 * R. Moore. "Algorithm AS 187: Derivatives of the incomplete gamma
1198 * integral". Journal of the Royal Statistical Society. 1982
1199 */
1200};
1201
1202template <typename Scalar>
1203struct gamma_sample_der_alpha_retval : igamma_retval<Scalar> {};
1204
1205template <typename Scalar>
1206struct gamma_sample_der_alpha_impl
1207 : igamma_generic_impl<Scalar, SAMPLE_DERIVATIVE> {
1208 /* Derivative of a Gamma random variable sample with respect to alpha.
1209 *
1210 * Consider a sample of a Gamma random variable with the concentration
1211 * parameter alpha: sample ~ Gamma(alpha, 1). The reparameterization
1212 * derivative that we want to compute is dsample / dalpha =
1213 * d igammainv(alpha, u) / dalpha, where u = igamma(alpha, sample).
1214 * However, this formula is numerically unstable and expensive, so instead
1215 * we use implicit differentiation:
1216 *
1217 * igamma(alpha, sample) = u, where u ~ Uniform(0, 1).
1218 * Apply d / dalpha to both sides:
1219 * d igamma(alpha, sample) / dalpha
1220 * + d igamma(alpha, sample) / dsample * dsample/dalpha = 0
1221 * d igamma(alpha, sample) / dalpha
1222 * + Gamma(sample | alpha, 1) dsample / dalpha = 0
1223 * dsample/dalpha = - (d igamma(alpha, sample) / dalpha)
1224 * / Gamma(sample | alpha, 1)
1225 *
1226 * Here Gamma(sample | alpha, 1) is the PDF of the Gamma distribution
1227 * (note that the derivative of the CDF w.r.t. sample is the PDF).
1228 * See the reference below for more details.
1229 *
1230 * The derivative of igamma(alpha, sample) is computed by forward
1231 * differentiation of the igamma code. Division by the Gamma PDF is performed
1232 * in the same code, increasing the accuracy and speed due to cancellation
1233 * of some terms.
1234 *
1235 * Accuracy estimation. For each alpha in [10^-2, 10^-1...10^3] we sample
1236 * 50 Gamma random variables sample ~ Gamma(sample | alpha, 1), a total of 300
1237 * points. The ground truth is computed by mpmath. Mean absolute error:
1238 * float: 2.1686e-06
1239 * double: 1.4774e-12
1240 *
1241 * Reference:
1242 * M. Figurnov, S. Mohamed, A. Mnih "Implicit Reparameterization Gradients".
1243 * 2018
1244 */
1245};
1246
Austin Schuh189376f2018-12-20 22:11:15 +11001247/*****************************************************************************
1248 * Implementation of Riemann zeta function of two arguments, based on Cephes *
1249 *****************************************************************************/
1250
/* Result-type trait for zeta(x, q): same scalar type as the arguments. */
template <typename Scalar>
struct zeta_retval
{
  typedef Scalar type;
};
1255
1256template <typename Scalar>
1257struct zeta_impl_series {
1258 EIGEN_DEVICE_FUNC
1259 static EIGEN_STRONG_INLINE Scalar run(const Scalar) {
1260 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1261 THIS_TYPE_IS_NOT_SUPPORTED);
1262 return Scalar(0);
1263 }
1264};
1265
1266template <>
1267struct zeta_impl_series<float> {
1268 EIGEN_DEVICE_FUNC
1269 static EIGEN_STRONG_INLINE bool run(float& a, float& b, float& s, const float x, const float machep) {
1270 int i = 0;
1271 while(i < 9)
1272 {
1273 i += 1;
1274 a += 1.0f;
1275 b = numext::pow( a, -x );
1276 s += b;
1277 if( numext::abs(b/s) < machep )
1278 return true;
1279 }
1280
1281 //Return whether we are done
1282 return false;
1283 }
1284};
1285
1286template <>
1287struct zeta_impl_series<double> {
1288 EIGEN_DEVICE_FUNC
1289 static EIGEN_STRONG_INLINE bool run(double& a, double& b, double& s, const double x, const double machep) {
1290 int i = 0;
1291 while( (i < 9) || (a <= 9.0) )
1292 {
1293 i += 1;
1294 a += 1.0;
1295 b = numext::pow( a, -x );
1296 s += b;
1297 if( numext::abs(b/s) < machep )
1298 return true;
1299 }
1300
1301 //Return whether we are done
1302 return false;
1303 }
1304};
1305
1306template <typename Scalar>
1307struct zeta_impl {
1308 EIGEN_DEVICE_FUNC
1309 static Scalar run(Scalar x, Scalar q) {
1310 /* zeta.c
1311 *
1312 * Riemann zeta function of two arguments
1313 *
1314 *
1315 *
1316 * SYNOPSIS:
1317 *
1318 * double x, q, y, zeta();
1319 *
1320 * y = zeta( x, q );
1321 *
1322 *
1323 *
1324 * DESCRIPTION:
1325 *
1326 *
1327 *
1328 * inf.
1329 * - -x
1330 * zeta(x,q) = > (k+q)
1331 * -
1332 * k=0
1333 *
1334 * where x > 1 and q is not a negative integer or zero.
1335 * The Euler-Maclaurin summation formula is used to obtain
1336 * the expansion
1337 *
1338 * n
1339 * - -x
1340 * zeta(x,q) = > (k+q)
1341 * -
1342 * k=1
1343 *
1344 * 1-x inf. B x(x+1)...(x+2j)
1345 * (n+q) 1 - 2j
1346 * + --------- - ------- + > --------------------
1347 * x-1 x - x+2j+1
1348 * 2(n+q) j=1 (2j)! (n+q)
1349 *
1350 * where the B2j are Bernoulli numbers. Note that (see zetac.c)
1351 * zeta(x,1) = zetac(x) + 1.
1352 *
1353 *
1354 *
1355 * ACCURACY:
1356 *
1357 * Relative error for single precision:
1358 * arithmetic domain # trials peak rms
1359 * IEEE 0,25 10000 6.9e-7 1.0e-7
1360 *
1361 * Large arguments may produce underflow in powf(), in which
1362 * case the results are inaccurate.
1363 *
1364 * REFERENCE:
1365 *
1366 * Gradshteyn, I. S., and I. M. Ryzhik, Tables of Integrals,
1367 * Series, and Products, p. 1073; Academic Press, 1980.
1368 *
1369 */
1370
1371 int i;
1372 Scalar p, r, a, b, k, s, t, w;
1373
1374 const Scalar A[] = {
1375 Scalar(12.0),
1376 Scalar(-720.0),
1377 Scalar(30240.0),
1378 Scalar(-1209600.0),
1379 Scalar(47900160.0),
1380 Scalar(-1.8924375803183791606e9), /*1.307674368e12/691*/
1381 Scalar(7.47242496e10),
1382 Scalar(-2.950130727918164224e12), /*1.067062284288e16/3617*/
1383 Scalar(1.1646782814350067249e14), /*5.109094217170944e18/43867*/
1384 Scalar(-4.5979787224074726105e15), /*8.028576626982912e20/174611*/
1385 Scalar(1.8152105401943546773e17), /*1.5511210043330985984e23/854513*/
1386 Scalar(-7.1661652561756670113e18) /*1.6938241367317436694528e27/236364091*/
1387 };
1388
1389 const Scalar maxnum = NumTraits<Scalar>::infinity();
1390 const Scalar zero = 0.0, half = 0.5, one = 1.0;
1391 const Scalar machep = cephes_helper<Scalar>::machep();
1392 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1393
1394 if( x == one )
1395 return maxnum;
1396
1397 if( x < one )
1398 {
1399 return nan;
1400 }
1401
1402 if( q <= zero )
1403 {
1404 if(q == numext::floor(q))
1405 {
Austin Schuhc55b0172022-02-20 17:52:35 -08001406 if (x == numext::floor(x) && long(x) % 2 == 0) {
1407 return maxnum;
1408 }
1409 else {
1410 return nan;
1411 }
Austin Schuh189376f2018-12-20 22:11:15 +11001412 }
1413 p = x;
1414 r = numext::floor(p);
1415 if (p != r)
1416 return nan;
1417 }
1418
1419 /* Permit negative q but continue sum until n+q > +9 .
1420 * This case should be handled by a reflection formula.
1421 * If q<0 and x is an integer, there is a relation to
1422 * the polygamma function.
1423 */
1424 s = numext::pow( q, -x );
1425 a = q;
1426 b = zero;
1427 // Run the summation in a helper function that is specific to the floating precision
1428 if (zeta_impl_series<Scalar>::run(a, b, s, x, machep)) {
1429 return s;
1430 }
1431
1432 w = a;
1433 s += b*w/(x-one);
1434 s -= half * b;
1435 a = one;
1436 k = zero;
1437 for( i=0; i<12; i++ )
1438 {
1439 a *= x + k;
1440 b /= w;
1441 t = a*b/A[i];
1442 s = s + t;
1443 t = numext::abs(t/s);
1444 if( t < machep ) {
1445 break;
1446 }
1447 k += one;
1448 a *= x + k;
1449 b /= w;
1450 k += one;
1451 }
1452 return s;
1453 }
1454};
1455
1456/****************************************************************************
1457 * Implementation of polygamma function, requires C++11/C99 *
1458 ****************************************************************************/
1459
/* Result-type trait for polygamma(n, x): same scalar type as the arguments. */
template <typename Scalar>
struct polygamma_retval
{
  typedef Scalar type;
};
1464
1465#if !EIGEN_HAS_C99_MATH
1466
1467template <typename Scalar>
1468struct polygamma_impl {
1469 EIGEN_DEVICE_FUNC
1470 static EIGEN_STRONG_INLINE Scalar run(Scalar n, Scalar x) {
1471 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1472 THIS_TYPE_IS_NOT_SUPPORTED);
1473 return Scalar(0);
1474 }
1475};
1476
1477#else
1478
1479template <typename Scalar>
1480struct polygamma_impl {
1481 EIGEN_DEVICE_FUNC
1482 static Scalar run(Scalar n, Scalar x) {
1483 Scalar zero = 0.0, one = 1.0;
1484 Scalar nplus = n + one;
1485 const Scalar nan = NumTraits<Scalar>::quiet_NaN();
1486
Austin Schuhc55b0172022-02-20 17:52:35 -08001487 // Check that n is a non-negative integer
1488 if (numext::floor(n) != n || n < zero) {
Austin Schuh189376f2018-12-20 22:11:15 +11001489 return nan;
1490 }
Austin Schuhc55b0172022-02-20 17:52:35 -08001491 // Just return the digamma function for n = 0
Austin Schuh189376f2018-12-20 22:11:15 +11001492 else if (n == zero) {
1493 return digamma_impl<Scalar>::run(x);
1494 }
1495 // Use the same implementation as scipy
1496 else {
1497 Scalar factorial = numext::exp(lgamma_impl<Scalar>::run(nplus));
1498 return numext::pow(-one, nplus) * factorial * zeta_impl<Scalar>::run(nplus, x);
1499 }
1500 }
1501};
1502
1503#endif // EIGEN_HAS_C99_MATH
1504
1505/************************************************************************************************
1506 * Implementation of betainc (incomplete beta integral), based on Cephes but requires C++11/C99 *
1507 ************************************************************************************************/
1508
/* Result-type trait for betainc(a, b, x): same scalar type as the arguments. */
template <typename Scalar>
struct betainc_retval
{
  typedef Scalar type;
};
1513
1514#if !EIGEN_HAS_C99_MATH
1515
1516template <typename Scalar>
1517struct betainc_impl {
1518 EIGEN_DEVICE_FUNC
1519 static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x) {
1520 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1521 THIS_TYPE_IS_NOT_SUPPORTED);
1522 return Scalar(0);
1523 }
1524};
1525
1526#else
1527
1528template <typename Scalar>
1529struct betainc_impl {
1530 EIGEN_DEVICE_FUNC
1531 static EIGEN_STRONG_INLINE Scalar run(Scalar, Scalar, Scalar) {
1532 /* betaincf.c
1533 *
1534 * Incomplete beta integral
1535 *
1536 *
1537 * SYNOPSIS:
1538 *
1539 * float a, b, x, y, betaincf();
1540 *
1541 * y = betaincf( a, b, x );
1542 *
1543 *
1544 * DESCRIPTION:
1545 *
1546 * Returns incomplete beta integral of the arguments, evaluated
1547 * from zero to x. The function is defined as
1548 *
1549 * x
1550 * - -
1551 * | (a+b) | | a-1 b-1
1552 * ----------- | t (1-t) dt.
1553 * - - | |
1554 * | (a) | (b) -
1555 * 0
1556 *
1557 * The domain of definition is 0 <= x <= 1. In this
1558 * implementation a and b are restricted to positive values.
1559 * The integral from x to 1 may be obtained by the symmetry
1560 * relation
1561 *
1562 * 1 - betainc( a, b, x ) = betainc( b, a, 1-x ).
1563 *
1564 * The integral is evaluated by a continued fraction expansion.
1565 * If a < 1, the function calls itself recursively after a
1566 * transformation to increase a to a+1.
1567 *
1568 * ACCURACY (float):
1569 *
1570 * Tested at random points (a,b,x) with a and b in the indicated
1571 * interval and x between 0 and 1.
1572 *
1573 * arithmetic domain # trials peak rms
1574 * Relative error:
1575 * IEEE 0,30 10000 3.7e-5 5.1e-6
1576 * IEEE 0,100 10000 1.7e-4 2.5e-5
1577 * The useful domain for relative error is limited by underflow
1578 * of the single precision exponential function.
1579 * Absolute error:
1580 * IEEE 0,30 100000 2.2e-5 9.6e-7
1581 * IEEE 0,100 10000 6.5e-5 3.7e-6
1582 *
1583 * Larger errors may occur for extreme ratios of a and b.
1584 *
1585 * ACCURACY (double):
1586 * arithmetic domain # trials peak rms
1587 * IEEE 0,5 10000 6.9e-15 4.5e-16
1588 * IEEE 0,85 250000 2.2e-13 1.7e-14
1589 * IEEE 0,1000 30000 5.3e-12 6.3e-13
1590 * IEEE 0,10000 250000 9.3e-11 7.1e-12
1591 * IEEE 0,100000 10000 8.7e-10 4.8e-11
1592 * Outputs smaller than the IEEE gradual underflow threshold
1593 * were excluded from these statistics.
1594 *
1595 * ERROR MESSAGES:
1596 * message condition value returned
1597 * incbet domain x<0, x>1 nan
1598 * incbet underflow nan
1599 */
1600
1601 EIGEN_STATIC_ASSERT((internal::is_same<Scalar, Scalar>::value == false),
1602 THIS_TYPE_IS_NOT_SUPPORTED);
1603 return Scalar(0);
1604 }
1605};
1606
/* Continued-fraction evaluation for the incomplete beta integral, as used by
 * Cephes incbet:
 *   small_branch == true:  expansion #1, evaluated at x;
 *   small_branch == false: expansion #2, evaluated at z = x / (1 - x).
 *
 * The two expansions share one recurrence and differ only in the initial
 * values of k2 and k6 and in the sign of their per-iteration update
 * (k26update). Each loop iteration applies two partial numerators to the
 * three-term recurrences for numerators (pk) and denominators (qk).
 * Iteration stops when successive convergents pk/qk agree to a relative
 * tolerance `thresh` (machep for float, 3 * machep for double). Otherwise it
 * stops after num_iters iterations (100 for float, 300 for double) and
 * returns the last convergent.
 *
 * Only float and double are supported (static assertion below).
 */
template <typename Scalar>
struct incbeta_cfe {
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE Scalar run(Scalar a, Scalar b, Scalar x, bool small_branch) {
    EIGEN_STATIC_ASSERT((internal::is_same<Scalar, float>::value ||
                         internal::is_same<Scalar, double>::value),
                        THIS_TYPE_IS_NOT_SUPPORTED);
    const Scalar big = cephes_helper<Scalar>::big();
    const Scalar machep = cephes_helper<Scalar>::machep();
    const Scalar biginv = cephes_helper<Scalar>::biginv();

    const Scalar zero = 0;
    const Scalar one = 1;
    const Scalar two = 2;

    Scalar xk, pk, pkm1, pkm2, qk, qkm1, qkm2;
    Scalar k1, k2, k3, k4, k5, k6, k7, k8, k26update;
    Scalar ans;
    int n;

    // Precision-dependent iteration cap and convergence tolerance.
    const int num_iters = (internal::is_same<Scalar, float>::value) ? 100 : 300;
    const Scalar thresh =
        (internal::is_same<Scalar, float>::value) ? machep : Scalar(3) * machep;
    // r is always assigned before it is read; this initial value is unused.
    Scalar r = (internal::is_same<Scalar, float>::value) ? zero : one;

    if (small_branch) {
      // Expansion #1: k2 grows and k6 shrinks by one each iteration.
      k1 = a;
      k2 = a + b;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = b - one;
      k7 = k4;
      k8 = a + two;
      k26update = one;
    } else {
      // Expansion #2: k2 shrinks and k6 grows; the argument becomes x/(1-x).
      k1 = a;
      k2 = b - one;
      k3 = a;
      k4 = a + one;
      k5 = one;
      k6 = a + b;
      k7 = a + one;
      k8 = a + two;
      k26update = -one;
      x = x / (one - x);
    }

    pkm2 = zero;
    qkm2 = one;
    pkm1 = one;
    qkm1 = one;
    ans = one;
    n = 0;

    do {
      // First partial numerator (negative term).
      xk = -(x * k1 * k2) / (k3 * k4);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Second partial numerator (positive term).
      xk = (x * k5 * k6) / (k7 * k8);
      pk = pkm1 + pkm2 * xk;
      qk = qkm1 + qkm2 * xk;
      pkm2 = pkm1;
      pkm1 = pk;
      qkm2 = qkm1;
      qkm1 = qk;

      // Relative convergence test on successive convergents; skipped when
      // the denominator is exactly zero.
      if (qk != zero) {
        r = pk / qk;
        if (numext::abs(ans - r) < numext::abs(r) * thresh) {
          return r;
        }
        ans = r;
      }

      k1 += one;
      k2 += k26update;
      k3 += two;
      k4 += two;
      k5 += one;
      k6 -= k26update;
      k7 += two;
      k8 += two;

      // Rescale the recurrence state to avoid overflow or underflow. All four
      // terms share one factor, so the ratio pk/qk is unchanged.
      if ((numext::abs(qk) + numext::abs(pk)) > big) {
        pkm2 *= biginv;
        pkm1 *= biginv;
        qkm2 *= biginv;
        qkm1 *= biginv;
      }
      if ((numext::abs(qk) < biginv) || (numext::abs(pk) < biginv)) {
        pkm2 *= big;
        pkm1 *= big;
        qkm2 *= big;
        qkm1 *= big;
      }
    } while (++n < num_iters);

    // Not converged within num_iters: return the last convergent.
    return ans;
  }
};
1716
/* Per-scalar-type helper routines (power series and expansion driver) used by
 * betainc_impl. The primary template is intentionally empty; this file
 * specializes it for float and double. */
template <typename Scalar>
struct betainc_helper
{
};
1720
template <>
struct betainc_helper<float> {
  /* Core implementation of I_x(aa, bb) for float. Assumes aa is large
   * (> 1.0): betainc_impl<float> calls it either with a > 1 or with a + 1.
   * xx is assumed to lie strictly inside (0, 1).
   */
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE float incbsa(float aa, float bb,
                                                            float xx) {
    float ans, a, b, t, x, onemx;
    bool reversed_a_b = false;

    onemx = 1.0f - xx;

    /* see if x is greater than the mean */
    // If so, evaluate I_{1-x}(b, a) and use I_x(a, b) = 1 - I_{1-x}(b, a).
    // t holds 1 - x in the evaluated orientation; its log is used below.
    if (xx > (aa / (aa + bb))) {
      reversed_a_b = true;
      a = bb;
      b = aa;
      t = xx;
      x = onemx;
    } else {
      a = aa;
      b = bb;
      t = onemx;
      x = xx;
    }

    /* Choose expansion for optimal convergence */
    // For large b and small b*x/a the power series (incbps) is used.
    if (b > 10.0f) {
      if (numext::abs(b * x / a) < 0.3f) {
        t = betainc_helper<float>::incbps(a, b, x);
        if (reversed_a_b) t = 1.0f - t;
        return t;
      }
    }

    // `ans` first serves as the selector between the two continued fractions,
    // then holds the continued-fraction value.
    // NOTE(review): after the swap above, a (= bb) may be <= 1. Then (a - 1)
    // is non-positive and this test picks the opposite expansion from the
    // sign test in betainc_impl<double> (x*(a+b-2) - (a-1) < 0); a == 1
    // divides by zero. This is inherited from the Cephes single-precision
    // code; confirm it is intended.
    ans = x * (a + b - 2.0f) / (a - 1.0f);
    if (ans < 1.0f) {
      ans = incbeta_cfe<float>::run(a, b, x, true /* small_branch */);
      t = b * numext::log(t);
    } else {
      // Expansion #2 needs an extra factor 1/(1-x); it is folded into the
      // exponent as (b - 1) * log(1 - x) instead of b * log(1 - x).
      ans = incbeta_cfe<float>::run(a, b, x, false /* small_branch */);
      t = (b - 1.0f) * numext::log(t);
    }

    // Multiply by x^a * Gamma(a+b) / (a * Gamma(a) * Gamma(b)) in log space.
    t += a * numext::log(x) + lgamma_impl<float>::run(a + b) -
         lgamma_impl<float>::run(a) - lgamma_impl<float>::run(b);
    t += numext::log(ans / a);
    t = numext::exp(t);

    if (reversed_a_b) t = 1.0f - t;
    return t;
  }

  /* Power-series evaluation of I_x(a, b) for float, used by incbsa when b is
   * large and b*x/a is small.
   *
   * Computes x^a (1-x)^(b-1) Gamma(a+b) / (a Gamma(a) Gamma(b)) * (1 + s),
   * where s = sum_k u_k with u_k = u_{k-1} * t * (b - k) / (a + k) and
   * t = x / (1 - x). The sum stops when |u| <= machep (an absolute test).
   * It also stops when b reaches exactly zero, in which case the series
   * terminates.
   */
  EIGEN_DEVICE_FUNC
  static EIGEN_STRONG_INLINE float incbps(float a, float b, float x) {
    float t, u, y, s;
    const float machep = cephes_helper<float>::machep();

    // y = log of the prefactor x^a (1-x)^(b-1) Gamma(a+b) / (a Gamma(a) Gamma(b)).
    y = a * numext::log(x) + (b - 1.0f) * numext::log1p(-x) - numext::log(a);
    y -= lgamma_impl<float>::run(a) + lgamma_impl<float>::run(b);
    y += lgamma_impl<float>::run(a + b);

    t = x / (1.0f - x);
    s = 0.0f;
    u = 1.0f;
    do {
      b -= 1.0f;
      if (b == 0.0f) {
        break;
      }
      a += 1.0f;
      u *= t * b / a;
      s += u;
    } while (numext::abs(u) > machep);

    return numext::exp(y) * (1.0f + s);
  }
};
1797
1798template <>
1799struct betainc_impl<float> {
1800 EIGEN_DEVICE_FUNC
1801 static float run(float a, float b, float x) {
1802 const float nan = NumTraits<float>::quiet_NaN();
1803 float ans, t;
1804
1805 if (a <= 0.0f) return nan;
1806 if (b <= 0.0f) return nan;
1807 if ((x <= 0.0f) || (x >= 1.0f)) {
1808 if (x == 0.0f) return 0.0f;
1809 if (x == 1.0f) return 1.0f;
1810 // mtherr("betaincf", DOMAIN);
1811 return nan;
1812 }
1813
1814 /* transformation for small aa */
1815 if (a <= 1.0f) {
1816 ans = betainc_helper<float>::incbsa(a + 1.0f, b, x);
1817 t = a * numext::log(x) + b * numext::log1p(-x) +
1818 lgamma_impl<float>::run(a + b) - lgamma_impl<float>::run(a + 1.0f) -
1819 lgamma_impl<float>::run(b);
1820 return (ans + numext::exp(t));
1821 } else {
1822 return betainc_helper<float>::incbsa(a, b, x);
1823 }
1824 }
1825};
1826
1827template <>
1828struct betainc_helper<double> {
1829 EIGEN_DEVICE_FUNC
1830 static EIGEN_STRONG_INLINE double incbps(double a, double b, double x) {
1831 const double machep = cephes_helper<double>::machep();
1832
1833 double s, t, u, v, n, t1, z, ai;
1834
1835 ai = 1.0 / a;
1836 u = (1.0 - b) * x;
1837 v = u / (a + 1.0);
1838 t1 = v;
1839 t = u;
1840 n = 2.0;
1841 s = 0.0;
1842 z = machep * ai;
1843 while (numext::abs(v) > z) {
1844 u = (n - b) * x / n;
1845 t *= u;
1846 v = t / (a + n);
1847 s += v;
1848 n += 1.0;
1849 }
1850 s += t1;
1851 s += ai;
1852
1853 u = a * numext::log(x);
1854 // TODO: gamma() is not directly implemented in Eigen.
1855 /*
1856 if ((a + b) < maxgam && numext::abs(u) < maxlog) {
1857 t = gamma(a + b) / (gamma(a) * gamma(b));
1858 s = s * t * pow(x, a);
Austin Schuhc55b0172022-02-20 17:52:35 -08001859 }
Austin Schuh189376f2018-12-20 22:11:15 +11001860 */
1861 t = lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
1862 lgamma_impl<double>::run(b) + u + numext::log(s);
1863 return s = numext::exp(t);
1864 }
1865};
1866
template <>
struct betainc_impl<double> {
  /* Regularized incomplete beta integral I_x(aa, bb) in double precision
   * (port of Cephes incbet).
   *
   * Returns NaN when aa <= 0, bb <= 0, or xx is outside [0, 1]. Returns the
   * exact values 0 and 1 at xx == 0 and xx == 1. Otherwise it uses either the
   * power series (betainc_helper<double>::incbps) or one of the two continued
   * fractions (incbeta_cfe<double>), possibly after swapping a and b.
   */
  EIGEN_DEVICE_FUNC
  static double run(double aa, double bb, double xx) {
    const double nan = NumTraits<double>::quiet_NaN();
    const double machep = cephes_helper<double>::machep();
    // Cephes uses maxgam = 171.624376956302725 for a direct gamma() path,
    // which is not available here (see below).

    double a, b, t, x, xc, w, y;
    bool reversed_a_b = false;

    // Domain error: shape parameters must be strictly positive.
    if (aa <= 0.0 || bb <= 0.0) {
      return nan;
    }

    // Exact end points; any other x outside (0, 1) is a domain error
    // (Cephes: mtherr("incbet", DOMAIN)).
    if ((xx <= 0.0) || (xx >= 1.0)) {
      if (xx == 0.0) return (0.0);
      if (xx == 1.0) return (1.0);
      return nan;
    }

    // Power series directly when it converges quickly.
    if ((bb * xx) <= 1.0 && xx <= 0.95) {
      return betainc_helper<double>::incbps(aa, bb, xx);
    }

    w = 1.0 - xx;

    /* Reverse a and b if x is greater than the mean. */
    // Then I_x(aa, bb) = 1 - I_{1-x}(bb, aa). xc is always 1 - x in the
    // evaluated orientation.
    if (xx > (aa / (aa + bb))) {
      reversed_a_b = true;
      a = bb;
      b = aa;
      xc = xx;
      x = w;
    } else {
      a = aa;
      b = bb;
      xc = w;
      x = xx;
    }

    // After swapping, the power series may apply to the reversed problem.
    // The complement is clamped so a tiny series value does not round the
    // result up to exactly 1.
    if (reversed_a_b && (b * x) <= 1.0 && x <= 0.95) {
      t = betainc_helper<double>::incbps(a, b, x);
      if (t <= machep) {
        t = 1.0 - machep;
      } else {
        t = 1.0 - t;
      }
      return t;
    }

    /* Choose expansion for better convergence. */
    // Expansion #2 carries an extra factor 1/(1-x), hence the division by xc.
    y = x * (a + b - 2.0) - (a - 1.0);
    if (y < 0.0) {
      w = incbeta_cfe<double>::run(a, b, x, true /* small_branch */);
    } else {
      w = incbeta_cfe<double>::run(a, b, x, false /* small_branch */) / xc;
    }

    /* Multiply w by the factor
         x^a (1-x)^b Gamma(a+b) / (a Gamma(a) Gamma(b)). */

    y = a * numext::log(x);
    t = b * numext::log(xc);
    // Cephes evaluates this factor directly with gamma() and pow() when
    // a + b < maxgam and the logs are within maxlog. Eigen has no direct
    // gamma(), so the logarithmic path below is always used.
    /* Resort to logarithms. */
    y += t + lgamma_impl<double>::run(a + b) - lgamma_impl<double>::run(a) -
         lgamma_impl<double>::run(b);
    y += numext::log(w / a);
    t = numext::exp(y);

    // Undo the a <-> b swap, clamping as in the series path above.
    if (reversed_a_b) {
      if (t <= machep) {
        t = 1.0 - machep;
      } else {
        t = 1.0 - t;
      }
    }
    return t;
  }
};
1963
1964#endif // EIGEN_HAS_C99_MATH
1965
1966} // end namespace internal
1967
1968namespace numext {
1969
1970template <typename Scalar>
1971EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(lgamma, Scalar)
1972 lgamma(const Scalar& x) {
1973 return EIGEN_MATHFUNC_IMPL(lgamma, Scalar)::run(x);
1974}
1975
1976template <typename Scalar>
1977EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(digamma, Scalar)
1978 digamma(const Scalar& x) {
1979 return EIGEN_MATHFUNC_IMPL(digamma, Scalar)::run(x);
1980}
1981
1982template <typename Scalar>
1983EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(zeta, Scalar)
1984zeta(const Scalar& x, const Scalar& q) {
1985 return EIGEN_MATHFUNC_IMPL(zeta, Scalar)::run(x, q);
1986}
1987
1988template <typename Scalar>
1989EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(polygamma, Scalar)
1990polygamma(const Scalar& n, const Scalar& x) {
1991 return EIGEN_MATHFUNC_IMPL(polygamma, Scalar)::run(n, x);
1992}
1993
1994template <typename Scalar>
1995EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erf, Scalar)
1996 erf(const Scalar& x) {
1997 return EIGEN_MATHFUNC_IMPL(erf, Scalar)::run(x);
1998}
1999
2000template <typename Scalar>
2001EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(erfc, Scalar)
2002 erfc(const Scalar& x) {
2003 return EIGEN_MATHFUNC_IMPL(erfc, Scalar)::run(x);
2004}
2005
2006template <typename Scalar>
Austin Schuhc55b0172022-02-20 17:52:35 -08002007EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(ndtri, Scalar)
2008 ndtri(const Scalar& x) {
2009 return EIGEN_MATHFUNC_IMPL(ndtri, Scalar)::run(x);
2010}
2011
2012template <typename Scalar>
Austin Schuh189376f2018-12-20 22:11:15 +11002013EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma, Scalar)
2014 igamma(const Scalar& a, const Scalar& x) {
2015 return EIGEN_MATHFUNC_IMPL(igamma, Scalar)::run(a, x);
2016}
2017
2018template <typename Scalar>
Austin Schuhc55b0172022-02-20 17:52:35 -08002019EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igamma_der_a, Scalar)
2020 igamma_der_a(const Scalar& a, const Scalar& x) {
2021 return EIGEN_MATHFUNC_IMPL(igamma_der_a, Scalar)::run(a, x);
2022}
2023
2024template <typename Scalar>
2025EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(gamma_sample_der_alpha, Scalar)
2026 gamma_sample_der_alpha(const Scalar& a, const Scalar& x) {
2027 return EIGEN_MATHFUNC_IMPL(gamma_sample_der_alpha, Scalar)::run(a, x);
2028}
2029
2030template <typename Scalar>
Austin Schuh189376f2018-12-20 22:11:15 +11002031EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(igammac, Scalar)
2032 igammac(const Scalar& a, const Scalar& x) {
2033 return EIGEN_MATHFUNC_IMPL(igammac, Scalar)::run(a, x);
2034}
2035
2036template <typename Scalar>
2037EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(betainc, Scalar)
2038 betainc(const Scalar& a, const Scalar& b, const Scalar& x) {
2039 return EIGEN_MATHFUNC_IMPL(betainc, Scalar)::run(a, b, x);
2040}
2041
2042} // end namespace numext
Austin Schuh189376f2018-12-20 22:11:15 +11002043} // end namespace Eigen
2044
2045#endif // EIGEN_SPECIAL_FUNCTIONS_H