Austin Schuh | 189376f | 2018-12-20 22:11:15 +1100 | [diff] [blame^] | 1 | #include <Eigen/Core> |
| 2 | #include <iostream> |
| 3 | |
| 4 | using namespace Eigen; |
| 5 | |
| 6 | // [functor] |
| 7 | template<class ArgType, class RowIndexType, class ColIndexType> |
| 8 | class indexing_functor { |
| 9 | const ArgType &m_arg; |
| 10 | const RowIndexType &m_rowIndices; |
| 11 | const ColIndexType &m_colIndices; |
| 12 | public: |
| 13 | typedef Matrix<typename ArgType::Scalar, |
| 14 | RowIndexType::SizeAtCompileTime, |
| 15 | ColIndexType::SizeAtCompileTime, |
| 16 | ArgType::Flags&RowMajorBit?RowMajor:ColMajor, |
| 17 | RowIndexType::MaxSizeAtCompileTime, |
| 18 | ColIndexType::MaxSizeAtCompileTime> MatrixType; |
| 19 | |
| 20 | indexing_functor(const ArgType& arg, const RowIndexType& row_indices, const ColIndexType& col_indices) |
| 21 | : m_arg(arg), m_rowIndices(row_indices), m_colIndices(col_indices) |
| 22 | {} |
| 23 | |
| 24 | const typename ArgType::Scalar& operator() (Index row, Index col) const { |
| 25 | return m_arg(m_rowIndices[row], m_colIndices[col]); |
| 26 | } |
| 27 | }; |
| 28 | // [functor] |
| 29 | |
| 30 | // [function] |
| 31 | template <class ArgType, class RowIndexType, class ColIndexType> |
| 32 | CwiseNullaryOp<indexing_functor<ArgType,RowIndexType,ColIndexType>, typename indexing_functor<ArgType,RowIndexType,ColIndexType>::MatrixType> |
| 33 | indexing(const Eigen::MatrixBase<ArgType>& arg, const RowIndexType& row_indices, const ColIndexType& col_indices) |
| 34 | { |
| 35 | typedef indexing_functor<ArgType,RowIndexType,ColIndexType> Func; |
| 36 | typedef typename Func::MatrixType MatrixType; |
| 37 | return MatrixType::NullaryExpr(row_indices.size(), col_indices.size(), Func(arg.derived(), row_indices, col_indices)); |
| 38 | } |
| 39 | // [function] |
| 40 | |
| 41 | |
| 42 | int main() |
| 43 | { |
| 44 | std::cout << "[main1]\n"; |
| 45 | Eigen::MatrixXi A = Eigen::MatrixXi::Random(4,4); |
| 46 | Array3i ri(1,2,1); |
| 47 | ArrayXi ci(6); ci << 3,2,1,0,0,2; |
| 48 | Eigen::MatrixXi B = indexing(A, ri, ci); |
| 49 | std::cout << "A =" << std::endl; |
| 50 | std::cout << A << std::endl << std::endl; |
| 51 | std::cout << "A([" << ri.transpose() << "], [" << ci.transpose() << "]) =" << std::endl; |
| 52 | std::cout << B << std::endl; |
| 53 | std::cout << "[main1]\n"; |
| 54 | |
| 55 | std::cout << "[main2]\n"; |
| 56 | B = indexing(A, ri+1, ci); |
| 57 | std::cout << "A(ri+1,ci) =" << std::endl; |
| 58 | std::cout << B << std::endl << std::endl; |
| 59 | #if __cplusplus >= 201103L |
| 60 | B = indexing(A, ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)); |
| 61 | std::cout << "A(ArrayXi::LinSpaced(13,0,12).unaryExpr([](int x){return x%4;}), ArrayXi::LinSpaced(4,0,3)) =" << std::endl; |
| 62 | std::cout << B << std::endl << std::endl; |
| 63 | #endif |
| 64 | std::cout << "[main2]\n"; |
| 65 | } |
| 66 | |