Add a convenient vector wrapper
Change-Id: I9b620d065f69909e5be74cd0df802bf9e57eb4d0
diff --git a/aos/vision/math/BUILD b/aos/vision/math/BUILD
new file mode 100644
index 0000000..c52e963
--- /dev/null
+++ b/aos/vision/math/BUILD
@@ -0,0 +1,21 @@
+# Header-only n-dimensional vector wrapper built on Eigen.
+cc_library(
+    name = 'vector',
+    hdrs = ['vector.h'],
+    deps = [
+        # For the "Eigen/Dense" include in vector.h.
+        '//third_party/eigen',
+    ],
+)
+
+# Unit tests for the Vector wrapper.
+cc_test(
+    name = 'vector_test',
+    size = 'small',
+    srcs = ['vector_test.cc'],
+    deps = [
+        # The library under test.
+        ':vector',
+        '//aos/testing:googletest',
+    ],
+)
diff --git a/aos/vision/math/vector.h b/aos/vision/math/vector.h
new file mode 100644
index 0000000..0d689e6
--- /dev/null
+++ b/aos/vision/math/vector.h
@@ -0,0 +1,202 @@
+#ifndef AOS_VISION_MATH_VECTOR_H_
+#define AOS_VISION_MATH_VECTOR_H_
+
+#include <cmath>
+
+#include "Eigen/Dense"
+
+namespace aos {
+namespace vision {
+
+// Represents an n-dimensional vector of doubles with various convenient
+// shortcuts for common operations.
+//
+// This includes overloads of various arithmetic operators for convenience.
+// Multiplication by doubles is scalar multiplication and multiplication by
+// other vectors is element-wise.
+//
+// Using a named accessor that doesn't exist for a given size vector, mixing
+// vectors of different sizes, etc. are compile-time errors.
+template <int size>
+class Vector {
+ public:
+  // Fixed-size vectorizable Eigen members need aligned heap allocation.
+  EIGEN_MAKE_ALIGNED_OPERATOR_NEW
+
+  // Constructs the zero vector.
+  Vector() { data_.setZero(); }
+
+  // Constructors for 2, 3, and 4 dimensional vectors respectively.
+  Vector(double x, double y) { Set(x, y); }
+  Vector(double x, double y, double z) { Set(x, y, z); }
+  Vector(double x, double y, double z, double w) { Set(x, y, z, w); }
+
+  // Index-based access. index must be in [0, size); not checked statically.
+  double Get(int index) const { return data_(index); }
+  void Set(int index, double value) { data_(index) = value; }
+
+  void Set(double x, double y) {
+    static_assert(size == 2, "illegal size");
+    data_(0) = x;
+    data_(1) = y;
+  }
+  void Set(double x, double y, double z) {
+    static_assert(size == 3, "illegal size");
+    data_(0) = x;
+    data_(1) = y;
+    data_(2) = z;
+  }
+  void Set(double x, double y, double z, double w) {
+    static_assert(size == 4, "illegal size");
+    data_(0) = x;
+    data_(1) = y;
+    data_(2) = z;
+    data_(3) = w;
+  }
+
+  double x() const {
+    static_assert(size >= 1, "illegal size");
+    return data_(0);
+  }
+  void x(double value) {
+    static_assert(size >= 1, "illegal size");
+    data_(0) = value;
+  }
+
+  double y() const {
+    static_assert(size >= 2, "illegal size");
+    return data_(1);
+  }
+  void y(double value) {
+    static_assert(size >= 2, "illegal size");
+    data_(1) = value;
+  }
+
+  double z() const {
+    static_assert(size >= 3, "illegal size");
+    return data_(2);
+  }
+  void z(double value) {
+    static_assert(size >= 3, "illegal size");
+    data_(2) = value;
+  }
+
+  double w() const {
+    static_assert(size >= 4, "illegal size");
+    return data_(3);
+  }
+  void w(double value) {
+    static_assert(size >= 4, "illegal size");
+    data_(3) = value;
+  }
+
+  // Squared length; cheaper than Mag() because it skips the square root.
+  double MagSqr() const { return data_.squaredNorm(); }
+  // Length (Euclidean norm) of the vector.
+  double Mag() const { return data_.norm(); }
+
+  // Returns a copy of the underlying Eigen row vector.
+  ::Eigen::Matrix<double, 1, size> GetData() const { return data_; }
+  // Replaces the underlying Eigen row vector.
+  void SetData(const ::Eigen::Matrix<double, 1, size> &other) { data_ = other; }
+
+  Vector<size> operator+(const Vector<size> &other) const {
+    Vector<size> nv = *this;
+    nv += other;
+    return nv;
+  }
+  Vector<size> &operator+=(const Vector<size> &other) {
+    data_ += other.data_;
+    return *this;
+  }
+
+  Vector<size> operator-(const Vector<size> &other) const {
+    Vector<size> nv = *this;
+    nv -= other;
+    return nv;
+  }
+  Vector<size> &operator-=(const Vector<size> &other) {
+    data_ -= other.data_;
+    return *this;
+  }
+
+  // Scalar multiplication.
+  Vector<size> operator*(double other) const {
+    Vector<size> nv = *this;
+    nv *= other;
+    return nv;
+  }
+  Vector<size> &operator*=(double other) {
+    data_ *= other;
+    return *this;
+  }
+
+  // Element-wise multiplication.
+  Vector<size> operator*(const Vector<size> &other) const {
+    Vector<size> nv = *this;
+    nv *= other;
+    return nv;
+  }
+  Vector<size> &operator*=(const Vector<size> &other) {
+    data_ = data_.cwiseProduct(other.data_);
+    return *this;
+  }
+
+  // Exact element-wise comparison (no floating-point tolerance).
+  bool operator==(const Vector<size> &other) const {
+    return data_ == other.data_;
+  }
+  bool operator!=(const Vector<size> &other) const { return !(*this == other); }
+
+  double dot(const Vector<size> &other) const { return data_.dot(other.data_); }
+
+  // Cross product; only valid for 3-dimensional vectors.
+  Vector<size> cross(const Vector<size> &other) const {
+    Vector<size> nv;
+    nv.SetData(data_.cross(other.data_));
+    return nv;
+  }
+
+  // Returns a vector in the same direction as this one with a magnitude of 1.
+  // All components are NaN if this is the zero vector.
+  Vector<size> Normalized() const {
+    Vector<size> nv;
+    nv.SetData(data_ / Mag());
+    return nv;
+  }
+
+  // Returns the angle of this vector from the +x axis, in radians.
+  // Only valid for 2-dimensional vectors.
+  double AngleToZero() const {
+    static_assert(size == 2, "illegal size");
+    return ::std::atan2(y(), x());
+  }
+
+  // Returns the angle between this and other in [0, pi] radians, or NaN if
+  // either has zero length. The cosine is clamped against rounding error.
+  double AngleTo(const Vector<size> &other) const {
+    double cos_theta = dot(other) / (Mag() * other.Mag());
+    if (cos_theta > 1.0) cos_theta = 1.0;
+    if (cos_theta < -1.0) cos_theta = -1.0;
+    return ::std::acos(cos_theta);
+  }
+
+  // Returns the distance between this and other squared.
+  double SquaredDistanceTo(const Vector<size> &other) const {
+    return (*this - other).MagSqr();
+  }
+
+ private:
+  // The actual data.
+  ::Eigen::Matrix<double, 1, size> data_;
+};
+
+// Scalar 2D cross product a.x*b.y - a.y*b.x; > 0 if b is CCW of a (y-up).
+inline double PointsCrossProduct(const Vector<2> &a, const Vector<2> &b) {
+  return a.x() * b.y() - b.x() * a.y();
+}
+
+} // namespace vision
+} // namespace aos
+
+#endif // AOS_VISION_MATH_VECTOR_H_
diff --git a/aos/vision/math/vector_test.cc b/aos/vision/math/vector_test.cc
new file mode 100644
index 0000000..8c5244d
--- /dev/null
+++ b/aos/vision/math/vector_test.cc
@@ -0,0 +1,76 @@
+#include "aos/vision/math/vector.h"
+
+#include "gtest/gtest.h"
+
+namespace aos {
+namespace vision {
+namespace testing {
+
+class VectorTest : public ::testing::Test {
+ protected:
+  const Vector<3> vec1_ = Vector<3>(1.0, 1.0, 1.0);
+  const Vector<3> vec2_ = Vector<3>(2.0, 4.0, 6.0);
+  const Vector<3> vec5_ = Vector<3>(2.0, 2.0, 1.0);
+  const Vector<3> vec6_ = Vector<3>(1.0, 1.0, 1.0);  // Same value as vec1_.
+};
+
+TEST_F(VectorTest, Equality) {
+  EXPECT_TRUE(vec2_ == Vector<3>(2.0, 4.0, 6.0));
+  EXPECT_FALSE(vec2_ == vec1_);
+}
+
+TEST_F(VectorTest, Addition) {
+  const Vector<3> sum = vec2_ + vec1_;
+  EXPECT_EQ(3.0, sum.x());
+  EXPECT_EQ(5.0, sum.y());
+  EXPECT_EQ(7.0, sum.z());
+}
+
+TEST_F(VectorTest, Multiplication) {
+  Vector<3> doubled = vec1_;
+  doubled *= 2.0;
+  EXPECT_EQ(2.0, doubled.x());
+  EXPECT_EQ(2.0, doubled.y());
+  EXPECT_EQ(2.0, doubled.z());
+
+  const Vector<3> quadrupled = doubled * 2;
+  EXPECT_EQ(4.0, quadrupled.x());
+  EXPECT_EQ(4.0, quadrupled.y());
+  EXPECT_EQ(4.0, quadrupled.z());
+}
+
+TEST_F(VectorTest, Magnitude) {
+  EXPECT_NEAR(1.732, Vector<3>(1.0, 1.0, 1.0).Mag(), 0.001);
+
+  const Vector<3> diff = vec5_ - vec6_;
+  EXPECT_NEAR(1.414, diff.Mag(), 0.001);
+  EXPECT_NEAR(1.414, (vec6_ - vec5_).Mag(), 0.001);
+}
+
+TEST_F(VectorTest, Sign) {
+  const Vector<3> forward = vec5_ - vec6_;
+  const Vector<3> backward = vec6_ - vec5_;
+  EXPECT_EQ(1.0, forward.x());
+  EXPECT_EQ(-1.0, backward.x());
+  EXPECT_EQ(1.0, forward.y());
+  EXPECT_EQ(-1.0, backward.y());
+  EXPECT_EQ(0.0, forward.z());
+  EXPECT_EQ(0.0, backward.z());
+}
+
+TEST_F(VectorTest, Angle) {
+  const Vector<3> x_axis(1.0, 0.0, 0.0);
+  const Vector<3> y_axis(0.0, 1.0, 0.0);
+  EXPECT_NEAR(M_PI / 2, x_axis.AngleTo(y_axis), 0.0001);
+  const Vector<3> diagonal(1.0, 1.0, 0.0);
+  EXPECT_NEAR(M_PI / 4, x_axis.AngleTo(diagonal), 0.0001);
+
+  const Vector<2> deg45(1, 1);
+  EXPECT_NEAR(M_PI / 4, deg45.AngleToZero(), 0.0001);
+  const Vector<2> deg60(0.5, 0.8660254037844386);
+  EXPECT_NEAR(M_PI / 3, deg60.AngleToZero(), 0.0001);
+}
+
+} // namespace testing
+} // namespace vision
+} // namespace aos