Fix GPU test bugs, a slight optimization, and moving/adding tests

Change-Id: I96e129a1093f6df3466fc02263afbbde6f1a9bb8
Signed-off-by: Justin Turcotte <jjturcot@gmail.com>
diff --git a/frc971/orin/BUILD b/frc971/orin/BUILD
index a81abf1..15aaffe 100644
--- a/frc971/orin/BUILD
+++ b/frc971/orin/BUILD
@@ -24,6 +24,7 @@
         "labeling_allegretti_2019_BKE.h",
         "points.h",
         "threshold.h",
+        "transform_output_iterator.h",
     ],
     copts = [
         "-Wno-pass-failed",
@@ -65,3 +66,31 @@
         "//y2023/vision:ToGreyscaleAndDecimateHalide",
     ],
 )
+
+cc_test(
+    name = "output_iterator_test",
+    srcs = [
+        "output_iterator_test.cc",
+    ],
+    features = ["cuda"],
+    deps = [
+        ":cuda",
+        "//aos/testing:googletest",
+        "//aos/testing:random_seed",
+        "//third_party:cudart",
+    ],
+)
+
+cc_test(
+    name = "points_test",
+    srcs = [
+        "points_test.cc",
+    ],
+    features = ["cuda"],
+    deps = [
+        ":cuda",
+        "//aos/testing:googletest",
+        "//aos/testing:random_seed",
+        "//third_party:cudart",
+    ],
+)
diff --git a/frc971/orin/apriltag.cc b/frc971/orin/apriltag.cc
index 381ffae..ff8ccec 100644
--- a/frc971/orin/apriltag.cc
+++ b/frc971/orin/apriltag.cc
@@ -19,6 +19,7 @@
 #include "aos/time/time.h"
 #include "frc971/orin/labeling_allegretti_2019_BKE.h"
 #include "frc971/orin/threshold.h"
+#include "frc971/orin/transform_output_iterator.h"
 
 namespace frc971 {
 namespace apriltag {
@@ -315,23 +316,32 @@
   __host__ __device__ __forceinline__ IndexPoint
   operator()(cub::KeyValuePair<long, QuadBoundaryPoint> pt) const {
     size_t index = blob_finder_.FindBlobIndex(pt.key);
-    MinMaxExtents extents = blob_finder_.Get(index);
     IndexPoint result(index, pt.value.point_bits());
-
-    float theta =
-        (atan2f(pt.value.y() - extents.cy(), pt.value.x() - extents.cx()) +
-         M_PI) *
-        8e6;
-    long long int theta_int = llrintf(theta);
-
-    result.set_theta(std::max<long long int>(0, theta_int));
-
     return result;
   }
 
   BlobExtentsIndexFinder blob_finder_;
 };
 
+// Calculates Theta for a given IndexPoint
+class AddThetaToIndexPoint {
+ public:
+  AddThetaToIndexPoint(MinMaxExtents *extents_device, size_t num_extents)
+      : blob_finder_(extents_device, num_extents) {}
+  __host__ __device__ __forceinline__ IndexPoint operator()(IndexPoint a) {
+    MinMaxExtents extents = blob_finder_.Get(a.blob_index());
+    float theta =
+        (atan2f(a.y() - extents.cy(), a.x() - extents.cx()) + M_PI) * 8e6;
+    long long int theta_int = llrintf(theta);
+
+    a.set_theta(std::max<long long int>(0, theta_int));
+    return a;
+  }
+
+ private:
+  BlobExtentsIndexFinder blob_finder_;
+};
+
 // TODO(austin): Make something which rewrites points on the way back out to
 // memory and adds the slope.
 
@@ -756,10 +766,16 @@
     cub::ArgIndexInputIterator<QuadBoundaryPoint *> value_index_input_iterator(
         sorted_union_marker_pair_device_.get());
     RewriteToIndexPoint rewrite(extents_device_.get(), num_quads_host);
+
     cub::TransformInputIterator<IndexPoint, RewriteToIndexPoint,
                                 cub::ArgIndexInputIterator<QuadBoundaryPoint *>>
         input_iterator(value_index_input_iterator, rewrite);
 
+    AddThetaToIndexPoint add_theta(extents_device_.get(), num_quads_host);
+
+    TransformOutputIterator<IndexPoint, IndexPoint, AddThetaToIndexPoint>
+        output_iterator(selected_blobs_device_.get(), add_theta);
+
     SelectBlobs select_blobs(
         extents_device_.get(), reduced_dot_blobs_pair_device_.get(),
         min_tag_width_, reversed_border_, normal_border_,
@@ -767,12 +783,12 @@
 
     size_t temp_storage_bytes =
         temp_storage_compressed_filtered_blobs_device_.size();
+
     CHECK_CUDA(cub::DeviceSelect::If(
         temp_storage_compressed_filtered_blobs_device_.get(),
-        temp_storage_bytes, input_iterator, selected_blobs_device_.get(),
+        temp_storage_bytes, input_iterator, output_iterator,
         num_selected_blobs_device_.get(), num_compressed_union_marker_pair_host,
         select_blobs, stream_.get()));
-
     MaybeCheckAndSynchronize();
 
     num_selected_blobs_device_.MemcpyTo(&num_selected_blobs_host);
diff --git a/frc971/orin/cuda_april_tag_test.cc b/frc971/orin/cuda_april_tag_test.cc
index 1d5aa9b..776caf5 100644
--- a/frc971/orin/cuda_april_tag_test.cc
+++ b/frc971/orin/cuda_april_tag_test.cc
@@ -533,7 +533,7 @@
     };
 
     // The theta algorithm used by cuda.
-    auto ComputeTheta = [&blob_centers](QuadBoundaryPoint pair) -> float {
+    auto ComputeTheta = [&blob_centers](QuadBoundaryPoint pair) -> long double {
       const int32_t x = pair.x();
       const int32_t y = pair.y();
 
@@ -546,7 +546,8 @@
       float dx = x - cx;
       float dy = y - cy;
 
-      return atan2f(dy, dx);
+      // make atan2 more accurate than cuda to correctly sort
+      return atan2f64(dy, dx);
     };
 
     for (size_t i = 0; i < points.size();) {
@@ -1083,7 +1084,7 @@
           slope_sorted_expected_grouped_points[i];
 
       CHECK_EQ(cuda_grouped_blob.size(), slope_sorted_points.size());
-      if (VLOG_IS_ON(1) && cuda_grouped_blob[0].blob_index() == 73) {
+      if (VLOG_IS_ON(1) && cuda_grouped_blob[0].blob_index() == 160) {
         for (size_t j = 0; j < cuda_grouped_points[i].size(); ++j) {
           LOG(INFO) << "For blob " << cuda_grouped_blob[0].blob_index()
                     << ", got " << cuda_grouped_blob[j] << " ("
@@ -1094,20 +1095,86 @@
                     << slope_sorted_points[j].y() << ")";
         }
       }
+
       size_t missmatched_runs = 0;
+
+      // recalculates the theta to be used in the following check
+      std::map<uint64_t, std::pair<float, float>> blob_centers;
+      auto ComputeTheta =
+          [&blob_centers](QuadBoundaryPoint pair) -> long double {
+        const int32_t x = pair.x();
+        const int32_t y = pair.y();
+
+        auto blob_center = blob_centers.find(pair.rep01());
+        CHECK(blob_center != blob_centers.end());
+
+        const float cx = blob_center->second.first;
+        const float cy = blob_center->second.second;
+
+        float dx = x - cx;
+        float dy = y - cy;
+
+        return atan2f64(dy, dx);
+      };
+
+      // refinds blob centers
+      for (size_t i = 0; i < slope_sorted_points.size();) {
+        uint64_t first_rep01 = slope_sorted_points[i].rep01();
+
+        int min_x, min_y, max_x, max_y;
+        min_x = max_x = slope_sorted_points[i].x();
+        min_y = max_y = slope_sorted_points[i].y();
+
+        size_t j = i;
+        for (; j < slope_sorted_points.size() &&
+               slope_sorted_points[j].rep01() == first_rep01;
+             ++j) {
+          QuadBoundaryPoint pt = slope_sorted_points[j];
+
+          int x = pt.x();
+          int y = pt.y();
+          min_x = std::min(min_x, x);
+          max_x = std::max(max_x, x);
+          min_y = std::min(min_y, y);
+          max_y = std::max(max_y, y);
+        }
+
+        const float cx = (max_x + min_x) * 0.5 + 0.05118;
+        const float cy = (max_y + min_y) * 0.5 + -0.028581;
+
+        blob_centers[first_rep01] = std::make_pair(cx, cy);
+        i = j;
+      }
+
       for (size_t j = 0; j < cuda_grouped_points[i].size(); ++j) {
         if (cuda_grouped_blob[j].x() != slope_sorted_points[j].x() ||
             cuda_grouped_blob[j].y() != slope_sorted_points[j].y()) {
+          size_t allowable_swapped_indices = 1;
+          long double max_allowed_imprecision = 3e-7;
+          // search through indixes j - a to j + a to see if they're swapped
+          // also only doesn't count if the precision needed to differentiate is
+          // less than max_allowed_imprecision
+          for (size_t k = j - allowable_swapped_indices;
+               k <= j + allowable_swapped_indices; k++) {
+            if (cuda_grouped_blob[j].x() == slope_sorted_points[k].x() &&
+                cuda_grouped_blob[j].y() == slope_sorted_points[k].y() &&
+                abs(ComputeTheta(slope_sorted_points[k]) -
+                    ComputeTheta(slope_sorted_points[j])) <
+                    max_allowed_imprecision) {
+              continue;
+            }
+          }
           ++missmatched_points;
           ++missmatched_runs;
           // We shouldn't see a lot of points in a row which don't match.
-          CHECK_LE(missmatched_runs, 4u);
           VLOG(1) << "Missmatched point in blob "
                   << cuda_grouped_blob[0].blob_index() << ", point " << j
                   << " (" << cuda_grouped_blob[j].x() << ", "
                   << cuda_grouped_blob[j].y() << ") vs ("
                   << slope_sorted_points[j].x() << ", "
-                  << slope_sorted_points[j].y() << ")";
+                  << slope_sorted_points[j].y() << ")"
+                  << " in size " << cuda_grouped_points[i].size();
+          CHECK_LE(missmatched_runs, 4u);
         } else {
           missmatched_runs = 0;
         }
@@ -1456,108 +1523,15 @@
 INSTANTIATE_TEST_SUITE_P(
     CapturedImages, SingleAprilDetectionTest,
     ::testing::Values("bfbs-capture-2013-01-18_08-54-16.869057537.bfbs",
-                      "bfbs-capture-2013-01-18_08-54-09.501047728.bfbs"
-                      // TODO(austin): Figure out why these fail...
-                      //"bfbs-capture-2013-01-18_08-51-24.861065764.bfbs",
-                      //"bfbs-capture-2013-01-18_08-52-01.846912552.bfbs",
-                      //"bfbs-capture-2013-01-18_08-52-33.462848049.bfbs",
-                      //"bfbs-capture-2013-01-18_08-54-24.931661979.bfbs",
-                      //"bfbs-capture-2013-01-18_09-29-16.806073486.bfbs",
-                      //"bfbs-capture-2013-01-18_09-33-00.993756514.bfbs",
-                      //"bfbs-capture-2013-01-18_08-57-00.171120695.bfbs"
-                      //"bfbs-capture-2013-01-18_08-57-17.858752817.bfbs",
-                      //"bfbs-capture-2013-01-18_08-57-08.096597542.bfbs"
-                      ));
-
-// Tests that QuadBoundaryPoint doesn't corrupt data.
-TEST(QuadBoundaryPoint, MasksWork) {
-  std::mt19937 generator(aos::testing::RandomSeed());
-  std::uniform_int_distribution<uint32_t> random_rep_scalar(0, 0xfffff);
-  std::uniform_int_distribution<uint32_t> random_point_scalar(0, 0x3ff);
-  std::uniform_int_distribution<uint32_t> random_dxy_scalar(0, 3);
-  std::uniform_int_distribution<uint32_t> random_bool(0, 1);
-
-  QuadBoundaryPoint point;
-
-  EXPECT_EQ(point.key, 0);
-
-  for (int i = 0; i < 25; ++i) {
-    const uint32_t rep0 = random_rep_scalar(generator);
-    for (int j = 0; j < 25; ++j) {
-      const uint32_t rep1 = random_rep_scalar(generator);
-      for (int k = 0; k < 25; ++k) {
-        const uint32_t x = random_point_scalar(generator);
-        const uint32_t y = random_point_scalar(generator);
-        for (int k = 0; k < 25; ++k) {
-          const uint32_t dxy = random_dxy_scalar(generator);
-          for (int m = 0; m < 2; ++m) {
-            const bool black_to_white = random_bool(generator) == 1;
-            if (point.rep0() != rep0) {
-              point.set_rep0(rep0);
-            }
-            if (point.rep1() != rep1) {
-              point.set_rep1(rep1);
-            }
-            if (point.base_x() != x || point.base_y() != y) {
-              point.set_base_xy(x, y);
-            }
-            switch (dxy) {
-              case 0:
-                if (point.dx() != 1 || point.dy() != 0) {
-                  point.set_dxy(dxy);
-                }
-                break;
-              case 1:
-                if (point.dx() != 1 || point.dy() != 1) {
-                  point.set_dxy(dxy);
-                }
-                break;
-              case 2:
-                if (point.dx() != 0 || point.dy() != 1) {
-                  point.set_dxy(dxy);
-                }
-                break;
-              case 3:
-                if (point.dx() != -1 || point.dy() != 1) {
-                  point.set_dxy(dxy);
-                }
-                break;
-            }
-            if (black_to_white != point.black_to_white()) {
-              point.set_black_to_white(black_to_white);
-            }
-
-            ASSERT_EQ(point.rep0(), rep0);
-            ASSERT_EQ(point.rep1(), rep1);
-            ASSERT_EQ(point.base_x(), x);
-            ASSERT_EQ(point.base_y(), y);
-            switch (dxy) {
-              case 0:
-                ASSERT_EQ(point.dx(), 1);
-                ASSERT_EQ(point.dy(), 0);
-                break;
-              case 1:
-                ASSERT_EQ(point.dx(), 1);
-                ASSERT_EQ(point.dy(), 1);
-                break;
-              case 2:
-                ASSERT_EQ(point.dx(), 0);
-                ASSERT_EQ(point.dy(), 1);
-                break;
-              case 3:
-                ASSERT_EQ(point.dx(), -1);
-                ASSERT_EQ(point.dy(), 1);
-                break;
-            }
-            ASSERT_EQ(point.x(), x * 2 + point.dx());
-            ASSERT_EQ(point.y(), y * 2 + point.dy());
-
-            ASSERT_EQ(point.black_to_white(), black_to_white);
-          }
-        }
-      }
-    }
-  }
-}
+                      "bfbs-capture-2013-01-18_08-54-09.501047728.bfbs",
+                      "bfbs-capture-2013-01-18_08-51-24.861065764.bfbs",
+                      "bfbs-capture-2013-01-18_08-52-01.846912552.bfbs",
+                      "bfbs-capture-2013-01-18_08-52-33.462848049.bfbs",
+                      "bfbs-capture-2013-01-18_08-54-24.931661979.bfbs",
+                      "bfbs-capture-2013-01-18_09-29-16.806073486.bfbs",
+                      "bfbs-capture-2013-01-18_09-33-00.993756514.bfbs",
+                      "bfbs-capture-2013-01-18_08-57-00.171120695.bfbs",
+                      "bfbs-capture-2013-01-18_08-57-17.858752817.bfbs",
+                      "bfbs-capture-2013-01-18_08-57-08.096597542.bfbs"));
 
 }  // namespace frc971::apriltag::testing
diff --git a/frc971/orin/output_iterator_test.cc b/frc971/orin/output_iterator_test.cc
new file mode 100644
index 0000000..d4087d0
--- /dev/null
+++ b/frc971/orin/output_iterator_test.cc
@@ -0,0 +1,72 @@
+#include <random>
+
+#include "gtest/gtest.h"
+
+#include "aos/testing/random_seed.h"
+#include "frc971/orin/transform_output_iterator.h"
+
+namespace frc971::apriltag::testing {
+
+struct Mul2 {
+  uint64_t operator()(const uint32_t num) const {
+    return static_cast<uint64_t>(num) * 2;
+  }
+};
+
+// Tests that the transform output iterator both transforms and otherwise acts
+// like a normal pointer
+TEST(TransformOutputIteratorTest, IntArr) {
+  std::mt19937 generator(aos::testing::RandomSeed());
+  std::uniform_int_distribution<uint32_t> random_uint32(0, UINT32_MAX);
+
+  uint32_t *nums_in = (uint32_t *)malloc(UINT32_WIDTH * 20);
+  uint64_t *nums_out = (uint64_t *)malloc(UINT64_WIDTH * 20);
+  uint64_t *expected_out = (uint64_t *)malloc(UINT64_WIDTH * 20);
+
+  for (size_t i = 0; i < 20; i++) {
+    nums_in[i] = random_uint32(generator);
+    expected_out[i] = 2 * static_cast<uint64_t>(nums_in[i]);
+  }
+
+  Mul2 convert_op;
+  TransformOutputIterator<uint32_t, uint64_t, Mul2> itr(nums_out, convert_op);
+
+  // check indirection, array index, increments, and decrements
+  EXPECT_EQ(itr == itr, true);
+  EXPECT_EQ(itr != itr, false);
+  *itr = *nums_in;          // [0]
+  *(++itr) = *(++nums_in);  // [1]
+  auto temp = itr;
+  auto temp2 = itr++;
+  EXPECT_EQ(temp, temp2);  // [2]
+  EXPECT_NE(temp, itr);
+  EXPECT_NE(temp2, itr);
+  nums_in++;        // [2]
+  *itr = *nums_in;  // [2]
+  auto temp3 = ++itr;
+  auto temp4 = itr;
+  EXPECT_EQ(temp3, temp4);  // [3]
+  EXPECT_EQ(temp3, itr);
+  itr--;  // [2]
+  auto temp5 = --itr;
+  auto temp6 = itr;
+  EXPECT_EQ(temp5, temp6);  // [1]
+  EXPECT_EQ(temp5, itr);
+  nums_in--;  // [1]
+  auto temp7 = itr;
+  auto temp8 = itr--;
+  EXPECT_EQ(temp7, temp8);  // [0]
+  EXPECT_NE(temp7, itr);
+  EXPECT_NE(temp8, itr);
+  nums_in--;  // [0]
+
+  for (size_t i = 3; i < 20; i++) {
+    itr[i] = nums_in[i];  // [3] -> [19]
+  }
+
+  // check expected out and converted out
+  for (size_t i = 0; i < 20; i++) {
+    EXPECT_EQ(expected_out[i], nums_out[i]);
+  }
+}
+}  // namespace frc971::apriltag::testing
diff --git a/frc971/orin/points.h b/frc971/orin/points.h
index 312d90a..d530a17 100644
--- a/frc971/orin/points.h
+++ b/frc971/orin/points.h
@@ -203,9 +203,17 @@
   }
 
   // See QuadBoundaryPoint for a description of the rest of these.
+  // Sets the 10 bit x and y.
+  __forceinline__ __host__ __device__ void set_base_xy(uint32_t x, uint32_t y) {
+    key = (key & 0xffffffffff00000full) |
+          (static_cast<uint64_t>(x & 0x3ff) << 14) |
+          (static_cast<uint64_t>(y & 0x3ff) << 4);
+  }
+
   __forceinline__ __host__ __device__ uint32_t base_x() const {
     return ((key >> 14) & 0x3ff);
   }
+
   __forceinline__ __host__ __device__ uint32_t base_y() const {
     return ((key >> 4) & 0x3ff);
   }
diff --git a/frc971/orin/points_test.cc b/frc971/orin/points_test.cc
new file mode 100644
index 0000000..852a1e4
--- /dev/null
+++ b/frc971/orin/points_test.cc
@@ -0,0 +1,204 @@
+#include "frc971/orin/points.h"
+
+#include <random>
+
+#include "gtest/gtest.h"
+
+#include "aos/testing/random_seed.h"
+
+namespace frc971::apriltag::testing {
+
+// Tests that QuadBoundaryPoint doesn't corrupt data.
+TEST(QuadBoundaryPoint, MasksWork) {
+  std::mt19937 generator(aos::testing::RandomSeed());
+  std::uniform_int_distribution<uint32_t> random_rep_scalar(0, 0xfffff);
+  std::uniform_int_distribution<uint32_t> random_point_scalar(0, 0x3ff);
+  std::uniform_int_distribution<uint32_t> random_dxy_scalar(0, 3);
+  std::uniform_int_distribution<uint32_t> random_bool(0, 1);
+
+  QuadBoundaryPoint point;
+
+  EXPECT_EQ(point.key, 0);
+
+  for (int i = 0; i < 25; ++i) {
+    const uint32_t rep0 = random_rep_scalar(generator);
+    for (int j = 0; j < 25; ++j) {
+      const uint32_t rep1 = random_rep_scalar(generator);
+      for (int k = 0; k < 25; ++k) {
+        const uint32_t x = random_point_scalar(generator);
+        const uint32_t y = random_point_scalar(generator);
+        for (int k = 0; k < 25; ++k) {
+          const uint32_t dxy = random_dxy_scalar(generator);
+          for (int m = 0; m < 2; ++m) {
+            const bool black_to_white = random_bool(generator) == 1;
+
+            if (point.rep0() != rep0) {
+              point.set_rep0(rep0);
+            }
+
+            if (point.rep1() != rep1) {
+              point.set_rep1(rep1);
+            }
+
+            if (point.base_x() != x || point.base_y() != y) {
+              point.set_base_xy(x, y);
+            }
+
+            switch (dxy) {
+              case 0:
+                if (point.dx() != 1 || point.dy() != 0) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 1:
+                if (point.dx() != 1 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 2:
+                if (point.dx() != 0 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 3:
+                if (point.dx() != -1 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+            }
+
+            if (black_to_white != point.black_to_white()) {
+              point.set_black_to_white(black_to_white);
+            }
+
+            EXPECT_EQ(point.rep0(), rep0);
+            EXPECT_EQ(point.rep1(), rep1);
+            EXPECT_EQ(point.base_x(), x);
+            EXPECT_EQ(point.base_y(), y);
+            switch (dxy) {
+              case 0:
+                EXPECT_EQ(point.dx(), 1);
+                EXPECT_EQ(point.dy(), 0);
+                break;
+              case 1:
+                EXPECT_EQ(point.dx(), 1);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+              case 2:
+                EXPECT_EQ(point.dx(), 0);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+              case 3:
+                EXPECT_EQ(point.dx(), -1);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+            }
+
+            EXPECT_EQ(point.x(), x * 2 + point.dx());
+            EXPECT_EQ(point.y(), y * 2 + point.dy());
+
+            EXPECT_EQ(point.black_to_white(), black_to_white);
+          }
+        }
+      }
+    }
+  }
+}
+
+// Tests that IndexPoint doesn't corrupt anything
+TEST(IndexPoint, MasksWork) {
+  std::mt19937 generator(
+      aos::testing::RandomSeed());  // random_uint32(generator)
+  std::uniform_int_distribution<uint32_t> random_blob_index(0, 0xfff);
+  std::uniform_int_distribution<uint32_t> random_theta(0, 0xfffffff);
+  std::uniform_int_distribution<uint32_t> random_point_scalar(0, 0x3ff);
+  std::uniform_int_distribution<uint32_t> random_dxy_scalar(0, 3);
+  std::uniform_int_distribution<uint32_t> random_bool(0, 1);
+
+  IndexPoint point;
+
+  for (int i = 0; i < 25; i++) {
+    const uint32_t blob_index = random_blob_index(generator);
+    for (int j = 0; j < 25; j++) {
+      const uint32_t theta = random_theta(generator);
+      for (int k = 0; k < 25; ++k) {
+        const uint32_t x = random_point_scalar(generator);
+        const uint32_t y = random_point_scalar(generator);
+        for (int k = 0; k < 25; ++k) {
+          const uint32_t dxy = random_dxy_scalar(generator);
+          for (int m = 0; m < 2; ++m) {
+            const bool black_to_white = random_bool(generator) == 1;
+
+            if (point.blob_index() != blob_index) {
+              point.set_blob_index(blob_index);
+            }
+
+            if (point.theta() != theta) {
+              point.set_theta(theta);
+            }
+
+            if (point.base_x() != x || point.base_y() != y) {
+              point.set_base_xy(x, y);
+            }
+
+            switch (dxy) {
+              case 0:
+                if (point.dx() != 1 || point.dy() != 0) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 1:
+                if (point.dx() != 1 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 2:
+                if (point.dx() != 0 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+              case 3:
+                if (point.dx() != -1 || point.dy() != 1) {
+                  point.set_dxy(dxy);
+                }
+                break;
+            }
+
+            if (black_to_white != point.black_to_white()) {
+              point.set_black_to_white(black_to_white);
+            }
+
+            EXPECT_EQ(point.blob_index(), blob_index);
+            EXPECT_EQ(point.theta(), theta);
+            EXPECT_EQ(point.base_x(), x);
+            EXPECT_EQ(point.base_y(), y);
+
+            switch (dxy) {
+              case 0:
+                EXPECT_EQ(point.dx(), 1);
+                EXPECT_EQ(point.dy(), 0);
+                break;
+              case 1:
+                EXPECT_EQ(point.dx(), 1);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+              case 2:
+                EXPECT_EQ(point.dx(), 0);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+              case 3:
+                EXPECT_EQ(point.dx(), -1);
+                EXPECT_EQ(point.dy(), 1);
+                break;
+            }
+            EXPECT_EQ(point.x(), x * 2 + point.dx());
+            EXPECT_EQ(point.y(), y * 2 + point.dy());
+
+            EXPECT_EQ(point.black_to_white(), black_to_white);
+          }
+        }
+      }
+    }
+  }
+}
+}  // namespace frc971::apriltag::testing
diff --git a/frc971/orin/transform_output_iterator.h b/frc971/orin/transform_output_iterator.h
new file mode 100644
index 0000000..051baea
--- /dev/null
+++ b/frc971/orin/transform_output_iterator.h
@@ -0,0 +1,92 @@
+#ifndef FRC971_TRANSFORM_OUTPUT_ITERATOR_
+#define FRC971_TRANSFORM_OUTPUT_ITERATOR_
+
+namespace frc971 {
+namespace apriltag {
+
+// template class that allows conversions at the output of a cub algorithm
+template <typename InputType, typename OutputType, typename ConversionOp,
+          typename OffsetT = ptrdiff_t>
+class TransformOutputIterator {
+ private:
+  // proxy object to be able to convert when assigning value
+  struct Reference {
+    OutputType *ptr;
+    ConversionOp convert_op;
+    __host__ __device__ Reference(OutputType *ptr, ConversionOp convert_op)
+        : ptr(ptr), convert_op(convert_op) {}
+    __host__ __device__ Reference operator=(InputType val) {
+      *ptr = convert_op(val);
+      return *this;
+    }
+  };
+
+ public:
+  // typedefs may not be neeeded for iterator to work but is here to maintain
+  // similarity to cub's CacheModifiedOutputIterator
+  typedef TransformOutputIterator self_type;
+  typedef OffsetT difference_type;
+  typedef void value_type;
+  typedef void *pointer;
+  typedef Reference reference;
+
+  TransformOutputIterator(OutputType *ptr, const ConversionOp convert_op)
+      : convert_op(convert_op), ptr(ptr) {}
+
+  // postfix addition
+  __host__ __device__ __forceinline__ self_type operator++(int) {
+    self_type retval = *this;
+    ptr++;
+    return retval;
+  }
+
+  // prefix addition
+  __host__ __device__ __forceinline__ self_type operator++() {
+    ptr++;
+    return *this;
+  }
+
+  // postfix subtraction
+  __host__ __device__ __forceinline__ self_type operator--(int) {
+    self_type retval = *this;
+    ptr--;
+    return retval;
+  }
+
+  // prefix subtraction
+  __host__ __device__ __forceinline__ self_type operator--() {
+    ptr--;
+    return *this;
+  }
+
+  // indirection
+  __host__ __device__ __forceinline__ reference operator*() const {
+    return Reference(ptr, convert_op);
+  }
+
+  // array index
+  __host__ __device__ __forceinline__ reference operator[](int num) const {
+    return Reference(ptr + num, convert_op);
+  }
+
+  // equal to
+  __host__ __device__ __forceinline__ bool operator==(
+      const TransformOutputIterator &rhs) const {
+    return ptr == rhs.ptr;
+  }
+
+  // not equal to
+  __host__ __device__ __forceinline__ bool operator!=(
+      const TransformOutputIterator &rhs) const {
+    return ptr != rhs.ptr;
+  }
+
+ private:
+  const ConversionOp convert_op;
+  OutputType *ptr;
+};
+
+}  // namespace apriltag
+}  // namespace frc971
+
+#endif  // FRC971_TRANSFORM_OUTPUT_ITERATOR_
diff --git a/third_party/apriltag/apriltag_quad_thresh.c b/third_party/apriltag/apriltag_quad_thresh.c
index e4c8b4e..1ce73ad 100644
--- a/third_party/apriltag/apriltag_quad_thresh.c
+++ b/third_party/apriltag/apriltag_quad_thresh.c
@@ -1014,7 +1014,9 @@
       int x = 0;
       if (v != 127) {
         DO_UNIONFIND2(0, -1);
-        DO_UNIONFIND2(1, -1);
+        if (v == 255) {
+            DO_UNIONFIND2(1, -1);
+        }
       }
     }