blob: 0f166e35f01f2ab4951404693366b191d0ebe71b [file] [log] [blame]
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_AUTODIFF_SCALAR_H
11#define EIGEN_AUTODIFF_SCALAR_H
12
13namespace Eigen {
14
15namespace internal {
16
/** \internal
  * Coherence policy used by make_coherent().
  * The primary template is a no-op: both operands are left untouched.
  * Specializations provide the actual resizing logic for specific operand types.
  */
template<typename A, typename B>
struct make_coherent_impl {
  static void run(A&, B&) {}
};

/** \internal
  * Resizes \a a to match \a b if a.size()==0, and conversely.
  * The work is delegated to make_coherent_impl<A,B>; both operands are passed
  * through const_cast_derived() so that the implementation may modify them
  * even though they are received as const references.
  */
template<typename A, typename B>
void make_coherent(const A& a, const B&b)
{
  make_coherent_impl<A,B>::run(a.const_cast_derived(), b.const_cast_derived());
}
28
// Base class of AutoDiffScalar; the boolean parameter selects between the two
// partial specializations defined later in this file.
template<typename DerivativeType, bool Enable> struct auto_diff_special_op;
Brian Silverman72890c22015-09-19 14:37:37 -040030
31} // end namespace internal
32
template<typename DerivativeType> class AutoDiffScalar;

/** Helper building an AutoDiffScalar whose derivative type is deduced from \a der.
  *
  * \param value the scalar value
  * \param der   the derivatives (any type exposing a nested \c Scalar typedef)
  * \returns AutoDiffScalar<NewDerType>(value, der)
  */
template<typename NewDerType>
inline AutoDiffScalar<NewDerType> MakeAutoDiffScalar(const typename NewDerType::Scalar& value, const NewDerType &der) {
  return AutoDiffScalar<NewDerType>(value,der);
}
39
Brian Silverman72890c22015-09-19 14:37:37 -040040/** \class AutoDiffScalar
Austin Schuhc55b0172022-02-20 17:52:35 -080041 * \brief A scalar type replacement with automatic differentiation capability
Brian Silverman72890c22015-09-19 14:37:37 -040042 *
Austin Schuhc55b0172022-02-20 17:52:35 -080043 * \param DerivativeType the vector type used to store/represent the derivatives. The base scalar type
Brian Silverman72890c22015-09-19 14:37:37 -040044 * as well as the number of derivatives to compute are determined from this type.
45 * Typical choices include, e.g., \c Vector4f for 4 derivatives, or \c VectorXf
46 * if the number of derivatives is not known at compile time, and/or, the number
47 * of derivatives is large.
Austin Schuhc55b0172022-02-20 17:52:35 -080048 * Note that DerivativeType can also be a reference (e.g., \c VectorXf&) to wrap a
Brian Silverman72890c22015-09-19 14:37:37 -040049 * existing vector into an AutoDiffScalar.
Austin Schuhc55b0172022-02-20 17:52:35 -080050 * Finally, DerivativeType can also be any Eigen compatible expression.
Brian Silverman72890c22015-09-19 14:37:37 -040051 *
52 * This class represents a scalar value while tracking its respective derivatives using Eigen's expression
53 * template mechanism.
54 *
55 * It supports the following list of global math function:
56 * - std::abs, std::sqrt, std::pow, std::exp, std::log, std::sin, std::cos,
57 * - internal::abs, internal::sqrt, numext::pow, internal::exp, internal::log, internal::sin, internal::cos,
58 * - internal::conj, internal::real, internal::imag, numext::abs2.
59 *
60 * AutoDiffScalar can be used as the scalar type of an Eigen::Matrix object. However,
61 * in that case, the expression template mechanism only occurs at the top Matrix level,
62 * while derivatives are computed right away.
63 *
64 */
65
Austin Schuhc55b0172022-02-20 17:52:35 -080066template<typename DerivativeType>
Brian Silverman72890c22015-09-19 14:37:37 -040067class AutoDiffScalar
68 : public internal::auto_diff_special_op
Austin Schuhc55b0172022-02-20 17:52:35 -080069 <DerivativeType, !internal::is_same<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar,
70 typename NumTraits<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar>::Real>::value>
Brian Silverman72890c22015-09-19 14:37:37 -040071{
72 public:
73 typedef internal::auto_diff_special_op
Austin Schuhc55b0172022-02-20 17:52:35 -080074 <DerivativeType, !internal::is_same<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar,
75 typename NumTraits<typename internal::traits<typename internal::remove_all<DerivativeType>::type>::Scalar>::Real>::value> Base;
76 typedef typename internal::remove_all<DerivativeType>::type DerType;
Brian Silverman72890c22015-09-19 14:37:37 -040077 typedef typename internal::traits<DerType>::Scalar Scalar;
78 typedef typename NumTraits<Scalar>::Real Real;
79
80 using Base::operator+;
81 using Base::operator*;
82
83 /** Default constructor without any initialization. */
84 AutoDiffScalar() {}
85
86 /** Constructs an active scalar from its \a value,
87 and initializes the \a nbDer derivatives such that it corresponds to the \a derNumber -th variable */
88 AutoDiffScalar(const Scalar& value, int nbDer, int derNumber)
89 : m_value(value), m_derivatives(DerType::Zero(nbDer))
90 {
91 m_derivatives.coeffRef(derNumber) = Scalar(1);
92 }
93
94 /** Conversion from a scalar constant to an active scalar.
95 * The derivatives are set to zero. */
96 /*explicit*/ AutoDiffScalar(const Real& value)
97 : m_value(value)
98 {
99 if(m_derivatives.size()>0)
100 m_derivatives.setZero();
101 }
102
103 /** Constructs an active scalar from its \a value and derivatives \a der */
104 AutoDiffScalar(const Scalar& value, const DerType& der)
105 : m_value(value), m_derivatives(der)
106 {}
107
108 template<typename OtherDerType>
Austin Schuh189376f2018-12-20 22:11:15 +1100109 AutoDiffScalar(const AutoDiffScalar<OtherDerType>& other
110#ifndef EIGEN_PARSED_BY_DOXYGEN
111 , typename internal::enable_if<
112 internal::is_same<Scalar, typename internal::traits<typename internal::remove_all<OtherDerType>::type>::Scalar>::value
113 && internal::is_convertible<OtherDerType,DerType>::value , void*>::type = 0
114#endif
115 )
Brian Silverman72890c22015-09-19 14:37:37 -0400116 : m_value(other.value()), m_derivatives(other.derivatives())
117 {}
118
119 friend std::ostream & operator << (std::ostream & s, const AutoDiffScalar& a)
120 {
121 return s << a.value();
122 }
123
124 AutoDiffScalar(const AutoDiffScalar& other)
125 : m_value(other.value()), m_derivatives(other.derivatives())
126 {}
127
128 template<typename OtherDerType>
129 inline AutoDiffScalar& operator=(const AutoDiffScalar<OtherDerType>& other)
130 {
131 m_value = other.value();
132 m_derivatives = other.derivatives();
133 return *this;
134 }
135
136 inline AutoDiffScalar& operator=(const AutoDiffScalar& other)
137 {
138 m_value = other.value();
139 m_derivatives = other.derivatives();
140 return *this;
141 }
142
Austin Schuh189376f2018-12-20 22:11:15 +1100143 inline AutoDiffScalar& operator=(const Scalar& other)
144 {
145 m_value = other;
146 if(m_derivatives.size()>0)
147 m_derivatives.setZero();
148 return *this;
149 }
150
Brian Silverman72890c22015-09-19 14:37:37 -0400151// inline operator const Scalar& () const { return m_value; }
152// inline operator Scalar& () { return m_value; }
153
154 inline const Scalar& value() const { return m_value; }
155 inline Scalar& value() { return m_value; }
156
157 inline const DerType& derivatives() const { return m_derivatives; }
158 inline DerType& derivatives() { return m_derivatives; }
159
160 inline bool operator< (const Scalar& other) const { return m_value < other; }
161 inline bool operator<=(const Scalar& other) const { return m_value <= other; }
162 inline bool operator> (const Scalar& other) const { return m_value > other; }
163 inline bool operator>=(const Scalar& other) const { return m_value >= other; }
164 inline bool operator==(const Scalar& other) const { return m_value == other; }
165 inline bool operator!=(const Scalar& other) const { return m_value != other; }
166
167 friend inline bool operator< (const Scalar& a, const AutoDiffScalar& b) { return a < b.value(); }
168 friend inline bool operator<=(const Scalar& a, const AutoDiffScalar& b) { return a <= b.value(); }
169 friend inline bool operator> (const Scalar& a, const AutoDiffScalar& b) { return a > b.value(); }
170 friend inline bool operator>=(const Scalar& a, const AutoDiffScalar& b) { return a >= b.value(); }
171 friend inline bool operator==(const Scalar& a, const AutoDiffScalar& b) { return a == b.value(); }
172 friend inline bool operator!=(const Scalar& a, const AutoDiffScalar& b) { return a != b.value(); }
173
174 template<typename OtherDerType> inline bool operator< (const AutoDiffScalar<OtherDerType>& b) const { return m_value < b.value(); }
175 template<typename OtherDerType> inline bool operator<=(const AutoDiffScalar<OtherDerType>& b) const { return m_value <= b.value(); }
176 template<typename OtherDerType> inline bool operator> (const AutoDiffScalar<OtherDerType>& b) const { return m_value > b.value(); }
177 template<typename OtherDerType> inline bool operator>=(const AutoDiffScalar<OtherDerType>& b) const { return m_value >= b.value(); }
178 template<typename OtherDerType> inline bool operator==(const AutoDiffScalar<OtherDerType>& b) const { return m_value == b.value(); }
179 template<typename OtherDerType> inline bool operator!=(const AutoDiffScalar<OtherDerType>& b) const { return m_value != b.value(); }
180
181 inline const AutoDiffScalar<DerType&> operator+(const Scalar& other) const
182 {
183 return AutoDiffScalar<DerType&>(m_value + other, m_derivatives);
184 }
185
186 friend inline const AutoDiffScalar<DerType&> operator+(const Scalar& a, const AutoDiffScalar& b)
187 {
188 return AutoDiffScalar<DerType&>(a + b.value(), b.derivatives());
189 }
190
191// inline const AutoDiffScalar<DerType&> operator+(const Real& other) const
192// {
193// return AutoDiffScalar<DerType&>(m_value + other, m_derivatives);
194// }
195
196// friend inline const AutoDiffScalar<DerType&> operator+(const Real& a, const AutoDiffScalar& b)
197// {
198// return AutoDiffScalar<DerType&>(a + b.value(), b.derivatives());
199// }
200
201 inline AutoDiffScalar& operator+=(const Scalar& other)
202 {
203 value() += other;
204 return *this;
205 }
206
207 template<typename OtherDerType>
208 inline const AutoDiffScalar<CwiseBinaryOp<internal::scalar_sum_op<Scalar>,const DerType,const typename internal::remove_all<OtherDerType>::type> >
209 operator+(const AutoDiffScalar<OtherDerType>& other) const
210 {
211 internal::make_coherent(m_derivatives, other.derivatives());
212 return AutoDiffScalar<CwiseBinaryOp<internal::scalar_sum_op<Scalar>,const DerType,const typename internal::remove_all<OtherDerType>::type> >(
213 m_value + other.value(),
214 m_derivatives + other.derivatives());
215 }
216
217 template<typename OtherDerType>
218 inline AutoDiffScalar&
219 operator+=(const AutoDiffScalar<OtherDerType>& other)
220 {
221 (*this) = (*this) + other;
222 return *this;
223 }
224
225 inline const AutoDiffScalar<DerType&> operator-(const Scalar& b) const
226 {
227 return AutoDiffScalar<DerType&>(m_value - b, m_derivatives);
228 }
229
230 friend inline const AutoDiffScalar<CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const DerType> >
231 operator-(const Scalar& a, const AutoDiffScalar& b)
232 {
233 return AutoDiffScalar<CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const DerType> >
234 (a - b.value(), -b.derivatives());
235 }
236
237 inline AutoDiffScalar& operator-=(const Scalar& other)
238 {
239 value() -= other;
240 return *this;
241 }
242
243 template<typename OtherDerType>
244 inline const AutoDiffScalar<CwiseBinaryOp<internal::scalar_difference_op<Scalar>, const DerType,const typename internal::remove_all<OtherDerType>::type> >
245 operator-(const AutoDiffScalar<OtherDerType>& other) const
246 {
247 internal::make_coherent(m_derivatives, other.derivatives());
248 return AutoDiffScalar<CwiseBinaryOp<internal::scalar_difference_op<Scalar>, const DerType,const typename internal::remove_all<OtherDerType>::type> >(
249 m_value - other.value(),
250 m_derivatives - other.derivatives());
251 }
252
253 template<typename OtherDerType>
254 inline AutoDiffScalar&
255 operator-=(const AutoDiffScalar<OtherDerType>& other)
256 {
257 *this = *this - other;
258 return *this;
259 }
260
261 inline const AutoDiffScalar<CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const DerType> >
262 operator-() const
263 {
264 return AutoDiffScalar<CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const DerType> >(
265 -m_value,
266 -m_derivatives);
267 }
268
Austin Schuh189376f2018-12-20 22:11:15 +1100269 inline const AutoDiffScalar<EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product) >
Brian Silverman72890c22015-09-19 14:37:37 -0400270 operator*(const Scalar& other) const
271 {
Austin Schuh189376f2018-12-20 22:11:15 +1100272 return MakeAutoDiffScalar(m_value * other, m_derivatives * other);
Brian Silverman72890c22015-09-19 14:37:37 -0400273 }
274
Austin Schuh189376f2018-12-20 22:11:15 +1100275 friend inline const AutoDiffScalar<EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product) >
Brian Silverman72890c22015-09-19 14:37:37 -0400276 operator*(const Scalar& other, const AutoDiffScalar& a)
277 {
Austin Schuh189376f2018-12-20 22:11:15 +1100278 return MakeAutoDiffScalar(a.value() * other, a.derivatives() * other);
Brian Silverman72890c22015-09-19 14:37:37 -0400279 }
280
281// inline const AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >
282// operator*(const Real& other) const
283// {
284// return AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >(
285// m_value * other,
286// (m_derivatives * other));
287// }
288//
289// friend inline const AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >
290// operator*(const Real& other, const AutoDiffScalar& a)
291// {
292// return AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >(
293// a.value() * other,
294// a.derivatives() * other);
295// }
296
Austin Schuh189376f2018-12-20 22:11:15 +1100297 inline const AutoDiffScalar<EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product) >
Brian Silverman72890c22015-09-19 14:37:37 -0400298 operator/(const Scalar& other) const
299 {
Austin Schuh189376f2018-12-20 22:11:15 +1100300 return MakeAutoDiffScalar(m_value / other, (m_derivatives * (Scalar(1)/other)));
Brian Silverman72890c22015-09-19 14:37:37 -0400301 }
302
Austin Schuh189376f2018-12-20 22:11:15 +1100303 friend inline const AutoDiffScalar<EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product) >
Brian Silverman72890c22015-09-19 14:37:37 -0400304 operator/(const Scalar& other, const AutoDiffScalar& a)
305 {
Austin Schuh189376f2018-12-20 22:11:15 +1100306 return MakeAutoDiffScalar(other / a.value(), a.derivatives() * (Scalar(-other) / (a.value()*a.value())));
Brian Silverman72890c22015-09-19 14:37:37 -0400307 }
308
309// inline const AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >
310// operator/(const Real& other) const
311// {
312// return AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >(
313// m_value / other,
314// (m_derivatives * (Real(1)/other)));
315// }
316//
317// friend inline const AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >
318// operator/(const Real& other, const AutoDiffScalar& a)
319// {
320// return AutoDiffScalar<typename CwiseUnaryOp<internal::scalar_multiple_op<Real>, DerType>::Type >(
321// other / a.value(),
322// a.derivatives() * (-Real(1)/other));
323// }
324
325 template<typename OtherDerType>
Austin Schuh189376f2018-12-20 22:11:15 +1100326 inline const AutoDiffScalar<EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(
327 CwiseBinaryOp<internal::scalar_difference_op<Scalar> EIGEN_COMMA
328 const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product) EIGEN_COMMA
329 const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(typename internal::remove_all<OtherDerType>::type,Scalar,product) >,Scalar,product) >
Brian Silverman72890c22015-09-19 14:37:37 -0400330 operator/(const AutoDiffScalar<OtherDerType>& other) const
331 {
332 internal::make_coherent(m_derivatives, other.derivatives());
Austin Schuh189376f2018-12-20 22:11:15 +1100333 return MakeAutoDiffScalar(
Brian Silverman72890c22015-09-19 14:37:37 -0400334 m_value / other.value(),
Austin Schuh189376f2018-12-20 22:11:15 +1100335 ((m_derivatives * other.value()) - (other.derivatives() * m_value))
Brian Silverman72890c22015-09-19 14:37:37 -0400336 * (Scalar(1)/(other.value()*other.value())));
337 }
338
339 template<typename OtherDerType>
340 inline const AutoDiffScalar<CwiseBinaryOp<internal::scalar_sum_op<Scalar>,
Austin Schuh189376f2018-12-20 22:11:15 +1100341 const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(DerType,Scalar,product),
342 const EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(typename internal::remove_all<OtherDerType>::type,Scalar,product) > >
Brian Silverman72890c22015-09-19 14:37:37 -0400343 operator*(const AutoDiffScalar<OtherDerType>& other) const
344 {
345 internal::make_coherent(m_derivatives, other.derivatives());
Austin Schuh189376f2018-12-20 22:11:15 +1100346 return MakeAutoDiffScalar(
Brian Silverman72890c22015-09-19 14:37:37 -0400347 m_value * other.value(),
Austin Schuh189376f2018-12-20 22:11:15 +1100348 (m_derivatives * other.value()) + (other.derivatives() * m_value));
Brian Silverman72890c22015-09-19 14:37:37 -0400349 }
350
351 inline AutoDiffScalar& operator*=(const Scalar& other)
352 {
353 *this = *this * other;
354 return *this;
355 }
356
357 template<typename OtherDerType>
358 inline AutoDiffScalar& operator*=(const AutoDiffScalar<OtherDerType>& other)
359 {
360 *this = *this * other;
361 return *this;
362 }
363
364 inline AutoDiffScalar& operator/=(const Scalar& other)
365 {
366 *this = *this / other;
367 return *this;
368 }
369
370 template<typename OtherDerType>
371 inline AutoDiffScalar& operator/=(const AutoDiffScalar<OtherDerType>& other)
372 {
373 *this = *this / other;
374 return *this;
375 }
376
377 protected:
378 Scalar m_value;
379 DerType m_derivatives;
380
381};
382
383namespace internal {
384
Austin Schuhc55b0172022-02-20 17:52:35 -0800385template<typename DerivativeType>
386struct auto_diff_special_op<DerivativeType, true>
387// : auto_diff_scalar_op<DerivativeType, typename NumTraits<Scalar>::Real,
Brian Silverman72890c22015-09-19 14:37:37 -0400388// is_same<Scalar,typename NumTraits<Scalar>::Real>::value>
389{
Austin Schuhc55b0172022-02-20 17:52:35 -0800390 typedef typename remove_all<DerivativeType>::type DerType;
Brian Silverman72890c22015-09-19 14:37:37 -0400391 typedef typename traits<DerType>::Scalar Scalar;
392 typedef typename NumTraits<Scalar>::Real Real;
393
Austin Schuhc55b0172022-02-20 17:52:35 -0800394// typedef auto_diff_scalar_op<DerivativeType, typename NumTraits<Scalar>::Real,
Brian Silverman72890c22015-09-19 14:37:37 -0400395// is_same<Scalar,typename NumTraits<Scalar>::Real>::value> Base;
396
397// using Base::operator+;
398// using Base::operator+=;
399// using Base::operator-;
400// using Base::operator-=;
401// using Base::operator*;
402// using Base::operator*=;
403
Austin Schuhc55b0172022-02-20 17:52:35 -0800404 const AutoDiffScalar<DerivativeType>& derived() const { return *static_cast<const AutoDiffScalar<DerivativeType>*>(this); }
405 AutoDiffScalar<DerivativeType>& derived() { return *static_cast<AutoDiffScalar<DerivativeType>*>(this); }
Brian Silverman72890c22015-09-19 14:37:37 -0400406
407
408 inline const AutoDiffScalar<DerType&> operator+(const Real& other) const
409 {
410 return AutoDiffScalar<DerType&>(derived().value() + other, derived().derivatives());
411 }
412
Austin Schuhc55b0172022-02-20 17:52:35 -0800413 friend inline const AutoDiffScalar<DerType&> operator+(const Real& a, const AutoDiffScalar<DerivativeType>& b)
Brian Silverman72890c22015-09-19 14:37:37 -0400414 {
415 return AutoDiffScalar<DerType&>(a + b.value(), b.derivatives());
416 }
417
Austin Schuhc55b0172022-02-20 17:52:35 -0800418 inline AutoDiffScalar<DerivativeType>& operator+=(const Real& other)
Brian Silverman72890c22015-09-19 14:37:37 -0400419 {
420 derived().value() += other;
421 return derived();
422 }
423
424
Austin Schuh189376f2018-12-20 22:11:15 +1100425 inline const AutoDiffScalar<typename CwiseUnaryOp<bind2nd_op<scalar_product_op<Scalar,Real> >, DerType>::Type >
Brian Silverman72890c22015-09-19 14:37:37 -0400426 operator*(const Real& other) const
427 {
Austin Schuh189376f2018-12-20 22:11:15 +1100428 return AutoDiffScalar<typename CwiseUnaryOp<bind2nd_op<scalar_product_op<Scalar,Real> >, DerType>::Type >(
Brian Silverman72890c22015-09-19 14:37:37 -0400429 derived().value() * other,
430 derived().derivatives() * other);
431 }
432
Austin Schuh189376f2018-12-20 22:11:15 +1100433 friend inline const AutoDiffScalar<typename CwiseUnaryOp<bind1st_op<scalar_product_op<Real,Scalar> >, DerType>::Type >
Austin Schuhc55b0172022-02-20 17:52:35 -0800434 operator*(const Real& other, const AutoDiffScalar<DerivativeType>& a)
Brian Silverman72890c22015-09-19 14:37:37 -0400435 {
Austin Schuh189376f2018-12-20 22:11:15 +1100436 return AutoDiffScalar<typename CwiseUnaryOp<bind1st_op<scalar_product_op<Real,Scalar> >, DerType>::Type >(
Brian Silverman72890c22015-09-19 14:37:37 -0400437 a.value() * other,
438 a.derivatives() * other);
439 }
440
Austin Schuhc55b0172022-02-20 17:52:35 -0800441 inline AutoDiffScalar<DerivativeType>& operator*=(const Scalar& other)
Brian Silverman72890c22015-09-19 14:37:37 -0400442 {
443 *this = *this * other;
444 return derived();
445 }
446};
447
Austin Schuhc55b0172022-02-20 17:52:35 -0800448template<typename DerivativeType>
449struct auto_diff_special_op<DerivativeType, false>
Brian Silverman72890c22015-09-19 14:37:37 -0400450{
451 void operator*() const;
452 void operator-() const;
453 void operator+() const;
454};
455
Austin Schuhc55b0172022-02-20 17:52:35 -0800456template<typename BinOp, typename A, typename B, typename RefType>
457void make_coherent_expression(CwiseBinaryOp<BinOp,A,B> xpr, const RefType &ref)
458{
459 make_coherent(xpr.const_cast_derived().lhs(), ref);
460 make_coherent(xpr.const_cast_derived().rhs(), ref);
461}
462
463template<typename UnaryOp, typename A, typename RefType>
464void make_coherent_expression(const CwiseUnaryOp<UnaryOp,A> &xpr, const RefType &ref)
465{
466 make_coherent(xpr.nestedExpression().const_cast_derived(), ref);
467}
468
469// needed for compilation only
470template<typename UnaryOp, typename A, typename RefType>
471void make_coherent_expression(const CwiseNullaryOp<UnaryOp,A> &, const RefType &)
472{}
473
Brian Silverman72890c22015-09-19 14:37:37 -0400474template<typename A_Scalar, int A_Rows, int A_Cols, int A_Options, int A_MaxRows, int A_MaxCols, typename B>
475struct make_coherent_impl<Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols>, B> {
476 typedef Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols> A;
477 static void run(A& a, B& b) {
478 if((A_Rows==Dynamic || A_Cols==Dynamic) && (a.size()==0))
479 {
480 a.resize(b.size());
481 a.setZero();
482 }
Austin Schuhc55b0172022-02-20 17:52:35 -0800483 else if (B::SizeAtCompileTime==Dynamic && a.size()!=0 && b.size()==0)
484 {
485 make_coherent_expression(b,a);
486 }
Brian Silverman72890c22015-09-19 14:37:37 -0400487 }
488};
489
490template<typename A, typename B_Scalar, int B_Rows, int B_Cols, int B_Options, int B_MaxRows, int B_MaxCols>
491struct make_coherent_impl<A, Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> > {
492 typedef Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> B;
493 static void run(A& a, B& b) {
494 if((B_Rows==Dynamic || B_Cols==Dynamic) && (b.size()==0))
495 {
496 b.resize(a.size());
497 b.setZero();
498 }
Austin Schuhc55b0172022-02-20 17:52:35 -0800499 else if (A::SizeAtCompileTime==Dynamic && b.size()!=0 && a.size()==0)
500 {
501 make_coherent_expression(a,b);
502 }
Brian Silverman72890c22015-09-19 14:37:37 -0400503 }
504};
505
506template<typename A_Scalar, int A_Rows, int A_Cols, int A_Options, int A_MaxRows, int A_MaxCols,
507 typename B_Scalar, int B_Rows, int B_Cols, int B_Options, int B_MaxRows, int B_MaxCols>
508struct make_coherent_impl<Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols>,
Austin Schuhc55b0172022-02-20 17:52:35 -0800509 Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> > {
Brian Silverman72890c22015-09-19 14:37:37 -0400510 typedef Matrix<A_Scalar, A_Rows, A_Cols, A_Options, A_MaxRows, A_MaxCols> A;
511 typedef Matrix<B_Scalar, B_Rows, B_Cols, B_Options, B_MaxRows, B_MaxCols> B;
512 static void run(A& a, B& b) {
513 if((A_Rows==Dynamic || A_Cols==Dynamic) && (a.size()==0))
514 {
515 a.resize(b.size());
516 a.setZero();
517 }
518 else if((B_Rows==Dynamic || B_Cols==Dynamic) && (b.size()==0))
519 {
520 b.resize(a.size());
521 b.setZero();
522 }
523 }
524};
525
Brian Silverman72890c22015-09-19 14:37:37 -0400526} // end namespace internal
527
Austin Schuh189376f2018-12-20 22:11:15 +1100528template<typename DerType, typename BinOp>
529struct ScalarBinaryOpTraits<AutoDiffScalar<DerType>,typename DerType::Scalar,BinOp>
530{
531 typedef AutoDiffScalar<DerType> ReturnType;
532};
533
534template<typename DerType, typename BinOp>
535struct ScalarBinaryOpTraits<typename DerType::Scalar,AutoDiffScalar<DerType>, BinOp>
536{
537 typedef AutoDiffScalar<DerType> ReturnType;
538};
539
540
// The following is an attempt to let Eigen know about expression templates, but that's more tricky!

// template<typename DerType, typename BinOp>
// struct ScalarBinaryOpTraits<AutoDiffScalar<DerType>,AutoDiffScalar<DerType>, BinOp>
// {
//   enum { Defined = 1 };
//   typedef AutoDiffScalar<typename DerType::PlainObject> ReturnType;
// };
//
// template<typename DerType1,typename DerType2, typename BinOp>
// struct ScalarBinaryOpTraits<AutoDiffScalar<DerType1>,AutoDiffScalar<DerType2>, BinOp>
// {
//   enum { Defined = 1 };//internal::is_same<typename DerType1::Scalar,typename DerType2::Scalar>::value };
//   typedef AutoDiffScalar<typename DerType1::PlainObject> ReturnType;
// };
556
// Declares a unary function FUNC taking an AutoDiffScalar<DerType> and returning
// an AutoDiffScalar whose derivatives are a "derivatives * scalar" expression.
// CODE is the function body; inside it, `x` names the argument and `Scalar`
// the underlying scalar type. The sizeof(Scalar) line only silences
// unused-typedef warnings when CODE does not mention Scalar.
#define EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(FUNC,CODE) \
  template<typename DerType> \
  inline const Eigen::AutoDiffScalar< \
  EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(typename Eigen::internal::remove_all<DerType>::type, typename Eigen::internal::traits<typename Eigen::internal::remove_all<DerType>::type>::Scalar, product) > \
  FUNC(const Eigen::AutoDiffScalar<DerType>& x) { \
    using namespace Eigen; \
    typedef typename Eigen::internal::traits<typename Eigen::internal::remove_all<DerType>::type>::Scalar Scalar; \
    EIGEN_UNUSED_VARIABLE(sizeof(Scalar)); \
    CODE; \
  }
567
568template<typename DerType>
Austin Schuhc55b0172022-02-20 17:52:35 -0800569struct CleanedUpDerType {
570 typedef AutoDiffScalar<typename Eigen::internal::remove_all<DerType>::type::PlainObject> type;
571};
572
573template<typename DerType>
Brian Silverman72890c22015-09-19 14:37:37 -0400574inline const AutoDiffScalar<DerType>& conj(const AutoDiffScalar<DerType>& x) { return x; }
575template<typename DerType>
576inline const AutoDiffScalar<DerType>& real(const AutoDiffScalar<DerType>& x) { return x; }
577template<typename DerType>
578inline typename DerType::Scalar imag(const AutoDiffScalar<DerType>&) { return 0.; }
579template<typename DerType, typename T>
Austin Schuhc55b0172022-02-20 17:52:35 -0800580inline typename CleanedUpDerType<DerType>::type (min)(const AutoDiffScalar<DerType>& x, const T& y) {
581 typedef typename CleanedUpDerType<DerType>::type ADS;
Austin Schuh189376f2018-12-20 22:11:15 +1100582 return (x <= y ? ADS(x) : ADS(y));
583}
Brian Silverman72890c22015-09-19 14:37:37 -0400584template<typename DerType, typename T>
Austin Schuhc55b0172022-02-20 17:52:35 -0800585inline typename CleanedUpDerType<DerType>::type (max)(const AutoDiffScalar<DerType>& x, const T& y) {
586 typedef typename CleanedUpDerType<DerType>::type ADS;
Austin Schuh189376f2018-12-20 22:11:15 +1100587 return (x >= y ? ADS(x) : ADS(y));
588}
Brian Silverman72890c22015-09-19 14:37:37 -0400589template<typename DerType, typename T>
Austin Schuhc55b0172022-02-20 17:52:35 -0800590inline typename CleanedUpDerType<DerType>::type (min)(const T& x, const AutoDiffScalar<DerType>& y) {
591 typedef typename CleanedUpDerType<DerType>::type ADS;
Austin Schuh189376f2018-12-20 22:11:15 +1100592 return (x < y ? ADS(x) : ADS(y));
593}
Brian Silverman72890c22015-09-19 14:37:37 -0400594template<typename DerType, typename T>
Austin Schuhc55b0172022-02-20 17:52:35 -0800595inline typename CleanedUpDerType<DerType>::type (max)(const T& x, const AutoDiffScalar<DerType>& y) {
596 typedef typename CleanedUpDerType<DerType>::type ADS;
Austin Schuh189376f2018-12-20 22:11:15 +1100597 return (x > y ? ADS(x) : ADS(y));
598}
599template<typename DerType>
Austin Schuhc55b0172022-02-20 17:52:35 -0800600inline typename CleanedUpDerType<DerType>::type (min)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
Austin Schuh189376f2018-12-20 22:11:15 +1100601 return (x.value() < y.value() ? x : y);
602}
603template<typename DerType>
Austin Schuhc55b0172022-02-20 17:52:35 -0800604inline typename CleanedUpDerType<DerType>::type (max)(const AutoDiffScalar<DerType>& x, const AutoDiffScalar<DerType>& y) {
Austin Schuh189376f2018-12-20 22:11:15 +1100605 return (x.value() >= y.value() ? x : y);
606}
607
Brian Silverman72890c22015-09-19 14:37:37 -0400608
609EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(abs,
610 using std::abs;
Austin Schuh189376f2018-12-20 22:11:15 +1100611 return Eigen::MakeAutoDiffScalar(abs(x.value()), x.derivatives() * (x.value()<0 ? -1 : 1) );)
Brian Silverman72890c22015-09-19 14:37:37 -0400612
613EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(abs2,
614 using numext::abs2;
Austin Schuh189376f2018-12-20 22:11:15 +1100615 return Eigen::MakeAutoDiffScalar(abs2(x.value()), x.derivatives() * (Scalar(2)*x.value()));)
Brian Silverman72890c22015-09-19 14:37:37 -0400616
617EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(sqrt,
618 using std::sqrt;
619 Scalar sqrtx = sqrt(x.value());
Austin Schuh189376f2018-12-20 22:11:15 +1100620 return Eigen::MakeAutoDiffScalar(sqrtx,x.derivatives() * (Scalar(0.5) / sqrtx));)
Brian Silverman72890c22015-09-19 14:37:37 -0400621
622EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(cos,
623 using std::cos;
624 using std::sin;
Austin Schuh189376f2018-12-20 22:11:15 +1100625 return Eigen::MakeAutoDiffScalar(cos(x.value()), x.derivatives() * (-sin(x.value())));)
Brian Silverman72890c22015-09-19 14:37:37 -0400626
627EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(sin,
628 using std::sin;
629 using std::cos;
Austin Schuh189376f2018-12-20 22:11:15 +1100630 return Eigen::MakeAutoDiffScalar(sin(x.value()),x.derivatives() * cos(x.value()));)
Brian Silverman72890c22015-09-19 14:37:37 -0400631
632EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(exp,
633 using std::exp;
634 Scalar expx = exp(x.value());
Austin Schuh189376f2018-12-20 22:11:15 +1100635 return Eigen::MakeAutoDiffScalar(expx,x.derivatives() * expx);)
Brian Silverman72890c22015-09-19 14:37:37 -0400636
637EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(log,
638 using std::log;
Austin Schuh189376f2018-12-20 22:11:15 +1100639 return Eigen::MakeAutoDiffScalar(log(x.value()),x.derivatives() * (Scalar(1)/x.value()));)
Brian Silverman72890c22015-09-19 14:37:37 -0400640
641template<typename DerType>
Austin Schuh189376f2018-12-20 22:11:15 +1100642inline const Eigen::AutoDiffScalar<
643EIGEN_EXPR_BINARYOP_SCALAR_RETURN_TYPE(typename internal::remove_all<DerType>::type,typename internal::traits<typename internal::remove_all<DerType>::type>::Scalar,product) >
644pow(const Eigen::AutoDiffScalar<DerType> &x, const typename internal::traits<typename internal::remove_all<DerType>::type>::Scalar &y)
Brian Silverman72890c22015-09-19 14:37:37 -0400645{
646 using namespace Eigen;
Austin Schuh189376f2018-12-20 22:11:15 +1100647 using std::pow;
648 return Eigen::MakeAutoDiffScalar(pow(x.value(),y), x.derivatives() * (y * pow(x.value(),y-1)));
Brian Silverman72890c22015-09-19 14:37:37 -0400649}
650
651
652template<typename DerTypeA,typename DerTypeB>
Austin Schuh189376f2018-12-20 22:11:15 +1100653inline const AutoDiffScalar<Matrix<typename internal::traits<typename internal::remove_all<DerTypeA>::type>::Scalar,Dynamic,1> >
Brian Silverman72890c22015-09-19 14:37:37 -0400654atan2(const AutoDiffScalar<DerTypeA>& a, const AutoDiffScalar<DerTypeB>& b)
655{
656 using std::atan2;
Austin Schuh189376f2018-12-20 22:11:15 +1100657 typedef typename internal::traits<typename internal::remove_all<DerTypeA>::type>::Scalar Scalar;
Brian Silverman72890c22015-09-19 14:37:37 -0400658 typedef AutoDiffScalar<Matrix<Scalar,Dynamic,1> > PlainADS;
659 PlainADS ret;
660 ret.value() = atan2(a.value(), b.value());
661
Austin Schuh189376f2018-12-20 22:11:15 +1100662 Scalar squared_hypot = a.value() * a.value() + b.value() * b.value();
Brian Silverman72890c22015-09-19 14:37:37 -0400663
Austin Schuh189376f2018-12-20 22:11:15 +1100664 // if (squared_hypot==0) the derivation is undefined and the following results in a NaN:
665 ret.derivatives() = (a.derivatives() * b.value() - a.value() * b.derivatives()) / squared_hypot;
Brian Silverman72890c22015-09-19 14:37:37 -0400666
667 return ret;
668}
669
670EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(tan,
671 using std::tan;
672 using std::cos;
Austin Schuh189376f2018-12-20 22:11:15 +1100673 return Eigen::MakeAutoDiffScalar(tan(x.value()),x.derivatives() * (Scalar(1)/numext::abs2(cos(x.value()))));)
Brian Silverman72890c22015-09-19 14:37:37 -0400674
675EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(asin,
676 using std::sqrt;
677 using std::asin;
Austin Schuh189376f2018-12-20 22:11:15 +1100678 return Eigen::MakeAutoDiffScalar(asin(x.value()),x.derivatives() * (Scalar(1)/sqrt(1-numext::abs2(x.value()))));)
Brian Silverman72890c22015-09-19 14:37:37 -0400679
680EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(acos,
681 using std::sqrt;
682 using std::acos;
Austin Schuh189376f2018-12-20 22:11:15 +1100683 return Eigen::MakeAutoDiffScalar(acos(x.value()),x.derivatives() * (Scalar(-1)/sqrt(1-numext::abs2(x.value()))));)
684
685EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(tanh,
686 using std::cosh;
687 using std::tanh;
688 return Eigen::MakeAutoDiffScalar(tanh(x.value()),x.derivatives() * (Scalar(1)/numext::abs2(cosh(x.value()))));)
689
690EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(sinh,
691 using std::sinh;
692 using std::cosh;
693 return Eigen::MakeAutoDiffScalar(sinh(x.value()),x.derivatives() * cosh(x.value()));)
694
695EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(cosh,
696 using std::sinh;
697 using std::cosh;
698 return Eigen::MakeAutoDiffScalar(cosh(x.value()),x.derivatives() * sinh(x.value()));)
Brian Silverman72890c22015-09-19 14:37:37 -0400699
700#undef EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY
701
702template<typename DerType> struct NumTraits<AutoDiffScalar<DerType> >
Austin Schuh189376f2018-12-20 22:11:15 +1100703 : NumTraits< typename NumTraits<typename internal::remove_all<DerType>::type::Scalar>::Real >
Brian Silverman72890c22015-09-19 14:37:37 -0400704{
Austin Schuh189376f2018-12-20 22:11:15 +1100705 typedef typename internal::remove_all<DerType>::type DerTypeCleaned;
706 typedef AutoDiffScalar<Matrix<typename NumTraits<typename DerTypeCleaned::Scalar>::Real,DerTypeCleaned::RowsAtCompileTime,DerTypeCleaned::ColsAtCompileTime,
707 0, DerTypeCleaned::MaxRowsAtCompileTime, DerTypeCleaned::MaxColsAtCompileTime> > Real;
Brian Silverman72890c22015-09-19 14:37:37 -0400708 typedef AutoDiffScalar<DerType> NonInteger;
Austin Schuh189376f2018-12-20 22:11:15 +1100709 typedef AutoDiffScalar<DerType> Nested;
710 typedef typename NumTraits<typename DerTypeCleaned::Scalar>::Literal Literal;
Brian Silverman72890c22015-09-19 14:37:37 -0400711 enum{
712 RequireInitialization = 1
713 };
714};
715
716}
717
Austin Schuh189376f2018-12-20 22:11:15 +1100718namespace std {
Austin Schuhc55b0172022-02-20 17:52:35 -0800719
Austin Schuh189376f2018-12-20 22:11:15 +1100720template <typename T>
721class numeric_limits<Eigen::AutoDiffScalar<T> >
722 : public numeric_limits<typename T::Scalar> {};
723
Austin Schuhc55b0172022-02-20 17:52:35 -0800724template <typename T>
725class numeric_limits<Eigen::AutoDiffScalar<T&> >
726 : public numeric_limits<typename T::Scalar> {};
727
Austin Schuh189376f2018-12-20 22:11:15 +1100728} // namespace std
729
Brian Silverman72890c22015-09-19 14:37:37 -0400730#endif // EIGEN_AUTODIFF_SCALAR_H