Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 1 | // This file is part of Eigen, a lightweight C++ template library |
| 2 | // for linear algebra. |
| 3 | // |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 4 | // Copyright (C) 2011, 2013 Jitse Niesen <jitse@maths.leeds.ac.uk> |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 5 | // |
| 6 | // This Source Code Form is subject to the terms of the Mozilla |
| 7 | // Public License v. 2.0. If a copy of the MPL was not distributed |
| 8 | // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. |
| 9 | |
| 10 | #ifndef EIGEN_MATRIX_SQUARE_ROOT |
| 11 | #define EIGEN_MATRIX_SQUARE_ROOT |
| 12 | |
| 13 | namespace Eigen { |
| 14 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 15 | namespace internal { |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 16 | |
| 17 | // pre: T.block(i,i,2,2) has complex conjugate eigenvalues |
| 18 | // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 19 | template <typename MatrixType, typename ResultType> |
| 20 | void matrix_sqrt_quasi_triangular_2x2_diagonal_block(const MatrixType& T, typename MatrixType::Index i, ResultType& sqrtT) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 21 | { |
| 22 | // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere |
| 23 | // in EigenSolver. If we expose it, we could call it directly from here. |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 24 | typedef typename traits<MatrixType>::Scalar Scalar; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 25 | Matrix<Scalar,2,2> block = T.template block<2,2>(i,i); |
| 26 | EigenSolver<Matrix<Scalar,2,2> > es(block); |
| 27 | sqrtT.template block<2,2>(i,i) |
| 28 | = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real(); |
| 29 | } |
| 30 | |
| 31 | // pre: block structure of T is such that (i,j) is a 1x1 block, |
| 32 | // all blocks of sqrtT to left of and below (i,j) are correct |
| 33 | // post: sqrtT(i,j) has the correct value |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 34 | template <typename MatrixType, typename ResultType> |
| 35 | void matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 36 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 37 | typedef typename traits<MatrixType>::Scalar Scalar; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 38 | Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); |
| 39 | sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); |
| 40 | } |
| 41 | |
| 42 | // similar to compute1x1offDiagonalBlock() |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 43 | template <typename MatrixType, typename ResultType> |
| 44 | void matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 45 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 46 | typedef typename traits<MatrixType>::Scalar Scalar; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 47 | Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j); |
| 48 | if (j-i > 1) |
| 49 | rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2); |
| 50 | Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity(); |
| 51 | A += sqrtT.template block<2,2>(j,j).transpose(); |
| 52 | sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose()); |
| 53 | } |
| 54 | |
| 55 | // similar to compute1x1offDiagonalBlock() |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 56 | template <typename MatrixType, typename ResultType> |
| 57 | void matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 58 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 59 | typedef typename traits<MatrixType>::Scalar Scalar; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 60 | Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j); |
| 61 | if (j-i > 2) |
| 62 | rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1); |
| 63 | Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity(); |
| 64 | A += sqrtT.template block<2,2>(i,i); |
| 65 | sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs); |
| 66 | } |
| 67 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 68 | // solves the equation A X + X B = C where all matrices are 2-by-2 |
| 69 | template <typename MatrixType> |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 70 | void matrix_sqrt_quasi_triangular_solve_auxiliary_equation(MatrixType& X, const MatrixType& A, const MatrixType& B, const MatrixType& C) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 71 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 72 | typedef typename traits<MatrixType>::Scalar Scalar; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 73 | Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero(); |
| 74 | coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0); |
| 75 | coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1); |
| 76 | coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0); |
| 77 | coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1); |
| 78 | coeffMatrix.coeffRef(0,1) = B.coeff(1,0); |
| 79 | coeffMatrix.coeffRef(0,2) = A.coeff(0,1); |
| 80 | coeffMatrix.coeffRef(1,0) = B.coeff(0,1); |
| 81 | coeffMatrix.coeffRef(1,3) = A.coeff(0,1); |
| 82 | coeffMatrix.coeffRef(2,0) = A.coeff(1,0); |
| 83 | coeffMatrix.coeffRef(2,3) = B.coeff(1,0); |
| 84 | coeffMatrix.coeffRef(3,1) = A.coeff(1,0); |
| 85 | coeffMatrix.coeffRef(3,2) = B.coeff(0,1); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 86 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 87 | Matrix<Scalar,4,1> rhs; |
| 88 | rhs.coeffRef(0) = C.coeff(0,0); |
| 89 | rhs.coeffRef(1) = C.coeff(0,1); |
| 90 | rhs.coeffRef(2) = C.coeff(1,0); |
| 91 | rhs.coeffRef(3) = C.coeff(1,1); |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 92 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 93 | Matrix<Scalar,4,1> result; |
| 94 | result = coeffMatrix.fullPivLu().solve(rhs); |
| 95 | |
| 96 | X.coeffRef(0,0) = result.coeff(0); |
| 97 | X.coeffRef(0,1) = result.coeff(1); |
| 98 | X.coeffRef(1,0) = result.coeff(2); |
| 99 | X.coeffRef(1,1) = result.coeff(3); |
| 100 | } |
| 101 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 102 | // similar to compute1x1offDiagonalBlock() |
| 103 | template <typename MatrixType, typename ResultType> |
| 104 | void matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(const MatrixType& T, typename MatrixType::Index i, typename MatrixType::Index j, ResultType& sqrtT) |
| 105 | { |
| 106 | typedef typename traits<MatrixType>::Scalar Scalar; |
| 107 | Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i); |
| 108 | Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j); |
| 109 | Matrix<Scalar,2,2> C = T.template block<2,2>(i,j); |
| 110 | if (j-i > 2) |
| 111 | C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2); |
| 112 | Matrix<Scalar,2,2> X; |
| 113 | matrix_sqrt_quasi_triangular_solve_auxiliary_equation(X, A, B, C); |
| 114 | sqrtT.template block<2,2>(i,j) = X; |
| 115 | } |
| 116 | |
| 117 | // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size |
| 118 | // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T |
| 119 | template <typename MatrixType, typename ResultType> |
| 120 | void matrix_sqrt_quasi_triangular_diagonal(const MatrixType& T, ResultType& sqrtT) |
| 121 | { |
| 122 | using std::sqrt; |
| 123 | const Index size = T.rows(); |
| 124 | for (Index i = 0; i < size; i++) { |
| 125 | if (i == size - 1 || T.coeff(i+1, i) == 0) { |
| 126 | eigen_assert(T(i,i) >= 0); |
| 127 | sqrtT.coeffRef(i,i) = sqrt(T.coeff(i,i)); |
| 128 | } |
| 129 | else { |
| 130 | matrix_sqrt_quasi_triangular_2x2_diagonal_block(T, i, sqrtT); |
| 131 | ++i; |
| 132 | } |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. |
| 137 | // post: sqrtT is the square root of T. |
| 138 | template <typename MatrixType, typename ResultType> |
| 139 | void matrix_sqrt_quasi_triangular_off_diagonal(const MatrixType& T, ResultType& sqrtT) |
| 140 | { |
| 141 | const Index size = T.rows(); |
| 142 | for (Index j = 1; j < size; j++) { |
| 143 | if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block |
| 144 | continue; |
| 145 | for (Index i = j-1; i >= 0; i--) { |
| 146 | if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block |
| 147 | continue; |
| 148 | bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0); |
| 149 | bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0); |
| 150 | if (iBlockIs2x2 && jBlockIs2x2) |
| 151 | matrix_sqrt_quasi_triangular_2x2_off_diagonal_block(T, i, j, sqrtT); |
| 152 | else if (iBlockIs2x2 && !jBlockIs2x2) |
| 153 | matrix_sqrt_quasi_triangular_2x1_off_diagonal_block(T, i, j, sqrtT); |
| 154 | else if (!iBlockIs2x2 && jBlockIs2x2) |
| 155 | matrix_sqrt_quasi_triangular_1x2_off_diagonal_block(T, i, j, sqrtT); |
| 156 | else if (!iBlockIs2x2 && !jBlockIs2x2) |
| 157 | matrix_sqrt_quasi_triangular_1x1_off_diagonal_block(T, i, j, sqrtT); |
| 158 | } |
| 159 | } |
| 160 | } |
| 161 | |
| 162 | } // end of namespace internal |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 163 | |
| 164 | /** \ingroup MatrixFunctions_Module |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 165 | * \brief Compute matrix square root of quasi-triangular matrix. |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 166 | * |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 167 | * \tparam MatrixType type of \p arg, the argument of matrix square root, |
| 168 | * expected to be an instantiation of the Matrix class template. |
| 169 | * \tparam ResultType type of \p result, where result is to be stored. |
| 170 | * \param[in] arg argument of matrix square root. |
| 171 | * \param[out] result matrix square root of upper Hessenberg part of \p arg. |
| 172 | * |
| 173 | * This function computes the square root of the upper quasi-triangular matrix stored in the upper |
| 174 | * Hessenberg part of \p arg. Only the upper Hessenberg part of \p result is updated, the rest is |
| 175 | * not touched. See MatrixBase::sqrt() for details on how this computation is implemented. |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 176 | * |
| 177 | * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular |
| 178 | */ |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 179 | template <typename MatrixType, typename ResultType> |
| 180 | void matrix_sqrt_quasi_triangular(const MatrixType &arg, ResultType &result) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 181 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 182 | eigen_assert(arg.rows() == arg.cols()); |
| 183 | result.resize(arg.rows(), arg.cols()); |
| 184 | internal::matrix_sqrt_quasi_triangular_diagonal(arg, result); |
| 185 | internal::matrix_sqrt_quasi_triangular_off_diagonal(arg, result); |
| 186 | } |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 187 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 188 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 189 | /** \ingroup MatrixFunctions_Module |
| 190 | * \brief Compute matrix square root of triangular matrix. |
| 191 | * |
| 192 | * \tparam MatrixType type of \p arg, the argument of matrix square root, |
| 193 | * expected to be an instantiation of the Matrix class template. |
| 194 | * \tparam ResultType type of \p result, where result is to be stored. |
| 195 | * \param[in] arg argument of matrix square root. |
| 196 | * \param[out] result matrix square root of upper triangular part of \p arg. |
| 197 | * |
| 198 | * Only the upper triangular part (including the diagonal) of \p result is updated, the rest is not |
| 199 | * touched. See MatrixBase::sqrt() for details on how this computation is implemented. |
| 200 | * |
| 201 | * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular |
| 202 | */ |
| 203 | template <typename MatrixType, typename ResultType> |
| 204 | void matrix_sqrt_triangular(const MatrixType &arg, ResultType &result) |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 205 | { |
| 206 | using std::sqrt; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 207 | typedef typename MatrixType::Scalar Scalar; |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 208 | |
| 209 | eigen_assert(arg.rows() == arg.cols()); |
| 210 | |
| 211 | // Compute square root of arg and store it in upper triangular part of result |
| 212 | // This uses that the square root of triangular matrices can be computed directly. |
| 213 | result.resize(arg.rows(), arg.cols()); |
| 214 | for (Index i = 0; i < arg.rows(); i++) { |
| 215 | result.coeffRef(i,i) = sqrt(arg.coeff(i,i)); |
| 216 | } |
| 217 | for (Index j = 1; j < arg.cols(); j++) { |
| 218 | for (Index i = j-1; i >= 0; i--) { |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 219 | // if i = j-1, then segment has length 0 so tmp = 0 |
| 220 | Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); |
| 221 | // denominator may be zero if original matrix is singular |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 222 | result.coeffRef(i,j) = (arg.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 223 | } |
| 224 | } |
| 225 | } |
| 226 | |
| 227 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 228 | namespace internal { |
| 229 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 230 | /** \ingroup MatrixFunctions_Module |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 231 | * \brief Helper struct for computing matrix square roots of general matrices. |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 232 | * \tparam MatrixType type of the argument of the matrix square root, |
| 233 | * expected to be an instantiation of the Matrix class template. |
| 234 | * |
| 235 | * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() |
| 236 | */ |
| 237 | template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex> |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 238 | struct matrix_sqrt_compute |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 239 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 240 | /** \brief Compute the matrix square root |
| 241 | * |
| 242 | * \param[in] arg matrix whose square root is to be computed. |
| 243 | * \param[out] result square root of \p arg. |
| 244 | * |
| 245 | * See MatrixBase::sqrt() for details on how this computation is implemented. |
| 246 | */ |
| 247 | template <typename ResultType> static void run(const MatrixType &arg, ResultType &result); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 248 | }; |
| 249 | |
| 250 | |
| 251 | // ********** Partial specialization for real matrices ********** |
| 252 | |
| 253 | template <typename MatrixType> |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 254 | struct matrix_sqrt_compute<MatrixType, 0> |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 255 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 256 | template <typename ResultType> |
| 257 | static void run(const MatrixType &arg, ResultType &result) |
| 258 | { |
| 259 | eigen_assert(arg.rows() == arg.cols()); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 260 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 261 | // Compute Schur decomposition of arg |
| 262 | const RealSchur<MatrixType> schurOfA(arg); |
| 263 | const MatrixType& T = schurOfA.matrixT(); |
| 264 | const MatrixType& U = schurOfA.matrixU(); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 265 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 266 | // Compute square root of T |
| 267 | MatrixType sqrtT = MatrixType::Zero(arg.rows(), arg.cols()); |
| 268 | matrix_sqrt_quasi_triangular(T, sqrtT); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 269 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 270 | // Compute square root of arg |
| 271 | result = U * sqrtT * U.adjoint(); |
| 272 | } |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 273 | }; |
| 274 | |
| 275 | |
| 276 | // ********** Partial specialization for complex matrices ********** |
| 277 | |
| 278 | template <typename MatrixType> |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 279 | struct matrix_sqrt_compute<MatrixType, 1> |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 280 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 281 | template <typename ResultType> |
| 282 | static void run(const MatrixType &arg, ResultType &result) |
| 283 | { |
| 284 | eigen_assert(arg.rows() == arg.cols()); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 285 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 286 | // Compute Schur decomposition of arg |
| 287 | const ComplexSchur<MatrixType> schurOfA(arg); |
| 288 | const MatrixType& T = schurOfA.matrixT(); |
| 289 | const MatrixType& U = schurOfA.matrixU(); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 290 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 291 | // Compute square root of T |
| 292 | MatrixType sqrtT; |
| 293 | matrix_sqrt_triangular(T, sqrtT); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 294 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 295 | // Compute square root of arg |
| 296 | result = U * (sqrtT.template triangularView<Upper>() * U.adjoint()); |
| 297 | } |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 298 | }; |
| 299 | |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 300 | } // end namespace internal |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 301 | |
| 302 | /** \ingroup MatrixFunctions_Module |
| 303 | * |
| 304 | * \brief Proxy for the matrix square root of some matrix (expression). |
| 305 | * |
| 306 | * \tparam Derived Type of the argument to the matrix square root. |
| 307 | * |
| 308 | * This class holds the argument to the matrix square root until it |
| 309 | * is assigned or evaluated for some other reason (so the argument |
| 310 | * should not be changed in the meantime). It is the return type of |
| 311 | * MatrixBase::sqrt() and most of the time this is the only way it is |
| 312 | * used. |
| 313 | */ |
| 314 | template<typename Derived> class MatrixSquareRootReturnValue |
| 315 | : public ReturnByValue<MatrixSquareRootReturnValue<Derived> > |
| 316 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 317 | protected: |
| 318 | typedef typename internal::ref_selector<Derived>::type DerivedNested; |
| 319 | |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 320 | public: |
| 321 | /** \brief Constructor. |
| 322 | * |
| 323 | * \param[in] src %Matrix (expression) forming the argument of the |
| 324 | * matrix square root. |
| 325 | */ |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 326 | explicit MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { } |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 327 | |
| 328 | /** \brief Compute the matrix square root. |
| 329 | * |
| 330 | * \param[out] result the matrix square root of \p src in the |
| 331 | * constructor. |
| 332 | */ |
| 333 | template <typename ResultType> |
| 334 | inline void evalTo(ResultType& result) const |
| 335 | { |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 336 | typedef typename internal::nested_eval<Derived, 10>::type DerivedEvalType; |
| 337 | typedef typename internal::remove_all<DerivedEvalType>::type DerivedEvalTypeClean; |
| 338 | DerivedEvalType tmp(m_src); |
| 339 | internal::matrix_sqrt_compute<DerivedEvalTypeClean>::run(tmp, result); |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 340 | } |
| 341 | |
| 342 | Index rows() const { return m_src.rows(); } |
| 343 | Index cols() const { return m_src.cols(); } |
| 344 | |
| 345 | protected: |
Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 346 | const DerivedNested m_src; |
Brian Silverman | 72890c2 | 2015-09-19 14:37:37 -0400 | [diff] [blame] | 347 | }; |
| 348 | |
| 349 | namespace internal { |
| 350 | template<typename Derived> |
| 351 | struct traits<MatrixSquareRootReturnValue<Derived> > |
| 352 | { |
| 353 | typedef typename Derived::PlainObject ReturnType; |
| 354 | }; |
| 355 | } |
| 356 | |
| 357 | template <typename Derived> |
| 358 | const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const |
| 359 | { |
| 360 | eigen_assert(rows() == cols()); |
| 361 | return MatrixSquareRootReturnValue<Derived>(derived()); |
| 362 | } |
| 363 | |
| 364 | } // end namespace Eigen |
| 365 | |
#endif // EIGEN_MATRIX_SQUARE_ROOT