Merge "Solve the Newton's method problem simultaneously when possible"
diff --git a/aos/network/multinode_timestamp_filter.cc b/aos/network/multinode_timestamp_filter.cc
index 07dffe8..87b5752 100644
--- a/aos/network/multinode_timestamp_filter.cc
+++ b/aos/network/multinode_timestamp_filter.cc
@@ -33,7 +33,7 @@
 
 TimestampProblem::TimestampProblem(size_t count) {
   CHECK_GT(count, 1u);
-  filters_.resize(count);
+  clock_offset_filter_for_node_.resize(count);
   base_clock_.resize(count);
   live_.resize(count, true);
   node_mapping_.resize(count, 0);
@@ -48,8 +48,8 @@
 
 bool TimestampProblem::ValidateSolution(std::vector<BootTimestamp> solution) {
   bool success = true;
-  for (size_t i = 0u; i < filters_.size(); ++i) {
-    for (const struct FilterPair &filter : filters_[i]) {
+  for (size_t i = 0u; i < clock_offset_filter_for_node_.size(); ++i) {
+    for (const struct FilterPair &filter : clock_offset_filter_for_node_[i]) {
       success = success && filter.filter->ValidateSolution(
                                solution[i], solution[filter.b_index]);
     }
@@ -60,8 +60,8 @@
 Eigen::VectorXd TimestampProblem::Gradient(
     const Eigen::Ref<Eigen::VectorXd> time_offsets) const {
   Eigen::VectorXd grad = Eigen::VectorXd::Zero(live_nodes_);
-  for (size_t i = 0; i < filters_.size(); ++i) {
-    for (const struct FilterPair &filter : filters_[i]) {
+  for (size_t i = 0; i < clock_offset_filter_for_node_.size(); ++i) {
+    for (const struct FilterPair &filter : clock_offset_filter_for_node_[i]) {
       // Reminder, our cost function has the following form.
       //   ((tb - (1 + ma) ta - ba)^2
       // We are ignoring the slope when taking the derivative and applying the
@@ -87,8 +87,8 @@
     const Eigen::Ref<Eigen::VectorXd> /*time_offsets*/) const {
   Eigen::MatrixXd hessian = Eigen::MatrixXd::Zero(live_nodes_, live_nodes_);
 
-  for (size_t i = 0; i < filters_.size(); ++i) {
-    for (const struct FilterPair &filter : filters_[i]) {
+  for (size_t i = 0; i < clock_offset_filter_for_node_.size(); ++i) {
+    for (const struct FilterPair &filter : clock_offset_filter_for_node_[i]) {
       // Reminder, our cost function has the following form.
       //   ((tb - (1 + ma) ta - ba)^2
       // We are ignoring the slope when taking the derivative and applying the
@@ -107,8 +107,9 @@
   return hessian;
 }
 
-Eigen::VectorXd TimestampProblem::Newton(
-    const Eigen::Ref<Eigen::VectorXd> time_offsets) const {
+std::tuple<Eigen::VectorXd, size_t> TimestampProblem::Newton(
+    const Eigen::Ref<Eigen::VectorXd> time_offsets,
+    const std::vector<logger::BootTimestamp> &points) const {
   CHECK_GT(live_nodes_, 0u) << ": No live nodes to solve for.";
   // TODO(austin): Each of the DCost functions does a binary search of the
   // timestamps list.  By the time we have computed the gradient and Hessian,
@@ -161,10 +162,6 @@
   // This ends up working surprisingly well.  A toy problem with 2 line segments
   // and 2 nodes converges in 2 iterations.
   //
-  // TODO(austin): Maybe drive the distributed so we drive the min clock?  This
-  // will solve the for loop at the same time, making things faster.
-  //
-  //
   // To ensure reliable convergence, we want to make 1 adjustment to the above
   // problem statement.
   //
@@ -192,23 +189,56 @@
   Eigen::VectorXd b = Eigen::VectorXd::Zero(live_nodes_ + 1);
   b.block(0, 0, live_nodes_, 1) = -grad;
 
-  // Since we are driving the clock on the solution node to the base_clock, that
-  // is equivalent to driving the solution node's offset to 0.
-  b(live_nodes_) = -time_offsets(NodeToFullSolutionIndex(solution_node_));
+  // Now, we want to set b(live_nodes_) to be -time_offset for the earliest
+  // clock.
+  //
+  // To save ourselves a fair amount of compute, we can take the min here.  That
+  // will drive us back the furthest back in time for all provided nodes without
+  // having to solve N times and look for the earliest solution.
+  size_t solution_node = std::numeric_limits<size_t>::max();
+  for (size_t i = 0; i < points.size(); ++i) {
+    if (points[i] == logger::BootTimestamp::max_time()) {
+      continue;
+    }
 
-  return a.colPivHouseholderQr().solve(b);
+    CHECK_EQ(points[i].boot, base_clock(i).boot);
+    const double candidate_b =
+        chrono::duration<double, std::nano>(points[i].time - base_clock(i).time)
+            .count() -
+        time_offsets(NodeToFullSolutionIndex(i));
+    if (candidate_b < b(live_nodes_) ||
+        solution_node == std::numeric_limits<size_t>::max()) {
+      VLOG(2) << "Node " << i << ", solution time " << points[i]
+              << ", base_clock " << base_clock(i) << ", error " << candidate_b
+              << " time offset " << time_offsets(NodeToFullSolutionIndex(i));
+      b(live_nodes_) = candidate_b;
+      solution_node = i;
+    }
+  }
+
+  CHECK_NE(solution_node, std::numeric_limits<size_t>::max())
+      << ": No solution nodes, please investigate";
+
+  return std::tuple<Eigen::VectorXd, size_t>(a.colPivHouseholderQr().solve(b),
+                                             solution_node);
 }
 
-std::vector<BootTimestamp> TimestampProblem::SolveNewton() {
+std::tuple<std::vector<BootTimestamp>, size_t> TimestampProblem::SolveNewton(
+    const std::vector<logger::BootTimestamp> &points) {
   constexpr int kMaxIterations = 200;
   MaybeUpdateNodeMapping();
-  VLOG(2) << "Solving for node " << solution_node_ << " at "
-          << base_clock(solution_node_);
+  for (size_t i = 0; i < points.size(); ++i) {
+    if (points[i] != logger::BootTimestamp::max_time()) {
+      VLOG(2) << "Solving for node " << i << " at " << points[i];
+    }
+  }
   Eigen::VectorXd data = Eigen::VectorXd::Zero(live_nodes_);
 
   int solution_number = 0;
+  size_t solution_node;
   while (true) {
-    Eigen::VectorXd step = Newton(data);
+    Eigen::VectorXd step;
+    std::tie(step, solution_node) = Newton(data, points);
 
     if (VLOG_IS_ON(2)) {
       // Print out the gradient ignoring the component removed by the equality
@@ -232,7 +262,11 @@
     // gradient since the Hessian is constant), and our solution node's time is
     // also close.
     if (step.block(0, 0, live_nodes_, 1).lpNorm<Eigen::Infinity>() < 1e-4 &&
-        std::abs(data(NodeToFullSolutionIndex(solution_node_))) < 1e-4) {
+        std::abs(
+            chrono::duration<double, std::nano>(points[solution_node].time -
+                                                base_clock(solution_node).time)
+                .count() -
+            data(NodeToFullSolutionIndex(solution_node))) < 1e-4) {
       break;
     }
 
@@ -251,8 +285,7 @@
     // since it makes it hard to debug as the data keeps jumping around.
     for (size_t j = 0; j < size(); ++j) {
       const size_t solution_index = NodeToFullSolutionIndex(j);
-      if (j != solution_node_ && live(j) &&
-          std::abs(data(solution_index)) > 1000) {
+      if (live(j) && std::abs(data(solution_index)) > 1000) {
         int64_t dsolution =
             static_cast<int64_t>(std::round(data(solution_index)));
         base_clock_[j].time += chrono::nanoseconds(dsolution);
@@ -267,9 +300,12 @@
     }
   }
 
-  VLOG(2) << "Solving for node " << solution_node_ << " of "
-          << base_clock(solution_node_) << " in " << solution_number
-          << " cycles";
+  for (size_t i = 0; i < points.size(); ++i) {
+    if (points[i] != logger::BootTimestamp::max_time()) {
+      VLOG(2) << "Solving for node " << i << " of " << base_clock(i) << " in "
+              << solution_number << " cycles";
+    }
+  }
   std::vector<BootTimestamp> result(size());
   for (size_t i = 0; i < size(); ++i) {
     if (live(i)) {
@@ -288,17 +324,33 @@
     LOG(FATAL) << "Failed to converge.";
   }
 
-  return result;
+  return std::make_tuple(std::move(result), solution_node);
+}
+
+void TimestampProblem::MaybeUpdateNodeMapping() {
+  if (node_mapping_valid_) {
+    return;
+  }
+  size_t live_node_index = 0;
+  for (size_t i = 0; i < node_mapping_.size(); ++i) {
+    if (live(i)) {
+      node_mapping_[i] = live_node_index;
+      ++live_node_index;
+    } else {
+      node_mapping_[i] = std::numeric_limits<size_t>::max();
+    }
+  }
+  live_nodes_ = live_node_index;
+  node_mapping_valid_ = true;
 }
 
 void TimestampProblem::Debug() {
   MaybeUpdateNodeMapping();
-  LOG(INFO) << "Solving for node " << solution_node_ << " at "
-            << base_clock_[solution_node_];
 
-  std::vector<std::vector<std::string>> gradients(filters_.size());
-  for (size_t i = 0u; i < filters_.size(); ++i) {
-    for (const struct FilterPair &filter : filters_[i]) {
+  std::vector<std::vector<std::string>> gradients(
+      clock_offset_filter_for_node_.size());
+  for (size_t i = 0u; i < clock_offset_filter_for_node_.size(); ++i) {
+    for (const struct FilterPair &filter : clock_offset_filter_for_node_[i]) {
       if (live(i) && live(filter.b_index)) {
         // TODO(austin): This should be right, but I haven't gone and spent a
         // bunch of time making sure it all matches perfectly.  We aren't
@@ -317,13 +369,13 @@
     }
   }
 
-  for (size_t i = 0u; i < filters_.size(); ++i) {
+  for (size_t i = 0u; i < clock_offset_filter_for_node_.size(); ++i) {
     LOG(INFO) << (live(i) ? "live" : "dead") << " Grad[" << i << "] = "
               << (gradients[i].empty() ? std::string("0.0")
                                        : absl::StrJoin(gradients[i], " + "));
   }
 
-  for (size_t i = 0u; i < filters_.size(); ++i) {
+  for (size_t i = 0u; i < clock_offset_filter_for_node_.size(); ++i) {
     LOG(INFO) << (live(i) ? "live" : "dead") << " base_clock[" << i
               << "] = " << base_clock_[i];
   }
@@ -456,8 +508,6 @@
   size_t index = times_.size() - 2u;
   while (index > 0u) {
     // TODO(austin): Binary search.
-    //LOG(INFO) << std::get<1>(times_[index])[node_index] << " <= " << time
-              //<< "?";
     if (std::get<1>(times_[index])[node_index] <= time) {
       break;
     }
@@ -533,8 +583,6 @@
   // where most of the times we care about will be.
   size_t index = times_.size() - 2u;
   while (index > 0u) {
-    //LOG(INFO) << "Considering " << std::get<0>(times_[index + 1]) << " index "
-              //<< index << " vs " << time;
     // If we are searching across a reboot, we want both the before and after
     // time.  We will be asked to solve for the after, so make sure when a time
     // matches exactly, we pick the time before, not the time after.
@@ -565,8 +613,6 @@
     }
   }
 
-  //LOG(INFO) << "d0 " << d0 << " time " << time << " d1 " << d1 << " t0 " << t0
-            //<< " t1 " << t1;
   if (time > d1) {
     const BootTimestamp result = t1 + (time - d1);
     VLOG(3) << "FromDistributedClock(" << node_index << ", " << time << ", "
@@ -1138,7 +1184,8 @@
           }
           all_live_nodes.Set(node_a_index, true);
           all_live_nodes.Set(filter.b_index, true);
-          problem.add_filter(node_a_index, filter.filter, filter.b_index);
+          problem.add_clock_offset_filter(node_a_index, filter.filter,
+                                          filter.b_index);
 
           if (timestamp_mappers_[node_a_index] != nullptr) {
             // Now, we have cases at startup where we have a couple of points
@@ -1248,106 +1295,214 @@
   return problem;
 }
 
-std::tuple<NoncausalTimestampFilter *, std::vector<BootTimestamp>, int>
-MultiNodeNoncausalOffsetEstimator::NextSolution(
-    TimestampProblem *problem, const std::vector<BootTimestamp> &base_times) {
-  // Ok, now solve for the minimum time on each channel.
-  std::vector<BootTimestamp> result_times;
+std::tuple<std::vector<MultiNodeNoncausalOffsetEstimator::CandidateTimes>, bool>
+MultiNodeNoncausalOffsetEstimator::MakeCandidateTimes() const {
+  bool boots_all_match = true;
+  std::vector<CandidateTimes> candidate_times;
+  candidate_times.resize(last_monotonics_.size());
 
-  struct CandidateTimes {
+  size_t node_a_index = 0;
+  size_t last_boot = std::numeric_limits<size_t>::max();
+  for (const auto &filters : filters_per_node_) {
+    VLOG(2) << "Investigating filter for node " << node_a_index;
     BootTimestamp next_node_time = BootTimestamp::max_time();
     BootDuration next_node_duration;
     NoncausalTimestampFilter *next_node_filter = nullptr;
-  };
+    // Find the oldest time for each node in each filter, and solve for that
+    // time.  That gives us the next timestamp for this node.
+    size_t filter_index = 0;
+    for (const auto &filter : filters) {
+      std::optional<std::tuple<BootTimestamp, BootDuration>> candidate =
+          filter.filter->Observe();
 
-  std::vector<CandidateTimes> candidate_times;
-  candidate_times.resize(base_times.size());
-  {
-    size_t node_a_index = 0;
-    for (const auto &filters : filters_per_node_) {
-      VLOG(2) << "Investigating filter for node " << node_a_index;
-      BootTimestamp next_node_time = BootTimestamp::max_time();
-      BootDuration next_node_duration;
-      NoncausalTimestampFilter *next_node_filter = nullptr;
-      // Find the oldest time for each node in each filter, and solve for that
-      // time.  That gives us the next timestamp for this node.
-      size_t filter_index = 0;
-      for (const auto &filter : filters) {
-        std::optional<std::tuple<BootTimestamp, BootDuration>> candidate =
-            filter.filter->Observe();
-
-        if (candidate) {
-          VLOG(2) << "Candidate for node " << node_a_index << " filter "
-                  << filter_index << " is " << std::get<0>(*candidate);
-          if (std::get<0>(*candidate) < next_node_time) {
-            next_node_time = std::get<0>(*candidate);
-            next_node_duration = std::get<1>(*candidate);
-            next_node_filter = filter.filter;
-          }
-        }
-        ++filter_index;
-      }
-
-      // Found no active filters.  Either this node is off, or disconnected, or
-      // we are before the log file starts or after the log file ends.
-      if (next_node_time == BootTimestamp::max_time()) {
-        candidate_times[node_a_index] =
-            CandidateTimes{.next_node_time = next_node_time,
-                           .next_node_duration = next_node_duration,
-                           .next_node_filter = next_node_filter};
-        ++node_a_index;
-        continue;
-      }
-
-      // We want to make sure we solve explicitly for the start time for each
-      // log.  This is useless (though not all that expensive) if it is in the
-      // middle of a set of data since we are just adding an extra point in the
-      // middle of a line, but very useful if the start time is before any
-      // points and we need to force a node to reboot.
-      //
-      // We can only do this meaningfully if there are data points on this node
-      // before or after this node to solve for.
-      const size_t next_boot = last_monotonics_[node_a_index].boot + 1;
-      if (next_boot < boots_->boots[node_a_index].size() &&
-          timestamp_mappers_[node_a_index] != nullptr) {
-        const BootTimestamp next_start_time = BootTimestamp{
-            .boot = next_boot,
-            .time = timestamp_mappers_[node_a_index]->monotonic_start_time(
-                next_boot)};
-        if (next_start_time < next_node_time) {
-          VLOG(1) << "Candidate for node " << node_a_index
-                  << " is the next startup time, " << next_start_time;
-          next_node_time = next_start_time;
-          next_node_filter = nullptr;
-        }
-
-        // We need to make sure we have solutions as well for any local messages
-        // published before remote messages.  Find the oldest message for each
-        // boot and make sure there's a time there.  Boots can't overlap, so if
-        // we have evidence that there has been a reboot, we need to get that
-        // into the interpolation function.
-        const BootTimestamp next_oldest_time = BootTimestamp{
-            .boot = next_boot,
-            .time = timestamp_mappers_[node_a_index]->monotonic_oldest_time(
-                next_boot)};
-        if (next_oldest_time < next_node_time) {
-          VLOG(1) << "Candidate for node " << node_a_index
-                  << " is the next oldest time, " << next_oldest_time
-                  << " not applying yet";
-          next_node_time = next_oldest_time;
-          next_node_filter = nullptr;
+      if (candidate) {
+        VLOG(2) << "Candidate for node " << node_a_index << " filter "
+                << filter_index << " is " << std::get<0>(*candidate);
+        if (std::get<0>(*candidate) < next_node_time) {
+          next_node_time = std::get<0>(*candidate);
+          next_node_duration = std::get<1>(*candidate);
+          next_node_filter = filter.filter;
         }
       }
+      ++filter_index;
+    }
+
+    // Found no active filters.  Either this node is off, or disconnected, or
+    // we are before the log file starts or after the log file ends.
+    if (next_node_time == BootTimestamp::max_time()) {
       candidate_times[node_a_index] =
           CandidateTimes{.next_node_time = next_node_time,
                          .next_node_duration = next_node_duration,
                          .next_node_filter = next_node_filter};
       ++node_a_index;
+      continue;
     }
+
+    // We want to make sure we solve explicitly for the start time for each
+    // log.  This is useless (though not all that expensive) if it is in the
+    // middle of a set of data since we are just adding an extra point in the
+    // middle of a line, but very useful if the start time is before any
+    // points and we need to force a node to reboot.
+    //
+    // We can only do this meaningfully if there are data points on this node
+    // before or after this node to solve for.
+    const size_t next_boot = last_monotonics_[node_a_index].boot + 1;
+    if (next_boot < boots_->boots[node_a_index].size() &&
+        timestamp_mappers_[node_a_index] != nullptr) {
+      const BootTimestamp next_start_time = BootTimestamp{
+          .boot = next_boot,
+          .time = timestamp_mappers_[node_a_index]->monotonic_start_time(
+              next_boot)};
+      if (next_start_time < next_node_time) {
+        VLOG(1) << "Candidate for node " << node_a_index
+                << " is the next startup time, " << next_start_time;
+        next_node_time = next_start_time;
+        next_node_filter = nullptr;
+      }
+
+      // We need to make sure we have solutions as well for any local messages
+      // published before remote messages.  Find the oldest message for each
+      // boot and make sure there's a time there.  Boots can't overlap, so if
+      // we have evidence that there has been a reboot, we need to get that
+      // into the interpolation function.
+      const BootTimestamp next_oldest_time = BootTimestamp{
+          .boot = next_boot,
+          .time = timestamp_mappers_[node_a_index]->monotonic_oldest_time(
+              next_boot)};
+      if (next_oldest_time < next_node_time) {
+        VLOG(1) << "Candidate for node " << node_a_index
+                << " is the next oldest time, " << next_oldest_time
+                << " not applying yet";
+        next_node_time = next_oldest_time;
+        next_node_filter = nullptr;
+      }
+    }
+    if (last_boot != std::numeric_limits<size_t>::max()) {
+      boots_all_match &= (next_node_time.boot == last_boot);
+    }
+    last_boot = next_node_time.boot;
+    candidate_times[node_a_index] =
+        CandidateTimes{.next_node_time = next_node_time,
+                       .next_node_duration = next_node_duration,
+                       .next_node_filter = next_node_filter};
+    ++node_a_index;
   }
 
+  return std::make_tuple(candidate_times, boots_all_match);
+}
+
+std::tuple<NoncausalTimestampFilter *, std::vector<BootTimestamp>, int>
+MultiNodeNoncausalOffsetEstimator::SimultaneousSolution(
+    TimestampProblem *problem,
+    std::vector<CandidateTimes> candidate_times,
+    const std::vector<logger::BootTimestamp> &base_times) {
+  std::vector<BootTimestamp> result_times;
   NoncausalTimestampFilter *next_filter = nullptr;
   size_t solution_index = 0;
+
+  // Now, build up the solution points that we care about.
+  size_t valid_nodes = 0;
+  // We know that time advances at about 1 seconds/second.  So, a good
+  // approximation for the next solution is going to be to compute the amount
+  // of time that will elapse for each node to go to the points to solve, and
+  // advance the minimum amount.  This should hopefully save an iteration or
+  // two on the solver for minimal compute.
+  chrono::nanoseconds dt{0};
+  std::vector<BootTimestamp> points(problem->size(), BootTimestamp::max_time());
+
+  for (size_t node_a_index = 0; node_a_index < candidate_times.size();
+       ++node_a_index) {
+    BootTimestamp next_node_time = candidate_times[node_a_index].next_node_time;
+    if (next_node_time == BootTimestamp::max_time()) {
+      continue;
+    }
+    CHECK_EQ(next_node_time.boot, base_times[node_a_index].boot);
+
+    const chrono::nanoseconds this_dt =
+        next_node_time.time - base_times[node_a_index].time;
+    if (valid_nodes == 0 || this_dt < dt) {
+      dt = this_dt;
+    }
+
+    ++valid_nodes;
+    points[node_a_index] = next_node_time;
+  }
+
+  // Only solve if there are nodes to solve for.  Otherwise the defaults will
+  // report 'no solution' which is exactly what we want.
+  if (valid_nodes > 0) {
+    // Apply our dt offset.
+    for (size_t node_index = 0; node_index < base_times.size(); ++node_index) {
+      problem->set_base_clock(node_index, {base_times[node_index].boot,
+                                           base_times[node_index].time + dt});
+    }
+    std::tuple<std::vector<BootTimestamp>, size_t> solution =
+        problem->SolveNewton(points);
+
+    if (!problem->ValidateSolution(std::get<0>(solution))) {
+      LOG(WARNING) << "Invalid solution, constraints not met.";
+      for (size_t i = 0; i < std::get<0>(solution).size(); ++i) {
+        LOG(INFO) << "  " << std::get<0>(solution)[i];
+      }
+      problem->Debug();
+      if (!skip_order_validation_) {
+        LOG(FATAL) << "Bailing, use --skip_order_validation to continue.  "
+                      "Use at your own risk.";
+      }
+    }
+
+    result_times = std::move(std::get<0>(solution));
+    next_filter = candidate_times[std::get<1>(solution)].next_node_filter;
+    solution_index = std::get<1>(solution);
+  }
+
+  return std::make_tuple(next_filter, std::move(result_times), solution_index);
+}
+
+void MultiNodeNoncausalOffsetEstimator::CheckInvalidDistance(
+    const std::vector<BootTimestamp> &result_times,
+    const std::vector<BootTimestamp> &solution) {
+  // If times are close enough, drop the invalid time.
+  const chrono::nanoseconds invalid_distance =
+      InvalidDistance(result_times, solution);
+  if (invalid_distance <= chrono::nanoseconds(FLAGS_max_invalid_distance_ns)) {
+    VLOG(1) << "Times can't be compared by " << invalid_distance.count()
+            << "ns";
+    for (size_t i = 0; i < result_times.size(); ++i) {
+      VLOG(1) << "  " << result_times[i] << " vs " << solution[i] << " -> "
+              << (result_times[i].time - solution[i].time).count() << "ns";
+    }
+    VLOG(1) << "Ignoring because it is close enough.";
+    return;
+  }
+  // Somehow the new solution is better *and* worse than the old
+  // solution...  This is an internal failure because that means time
+  // goes backwards on a node.
+  CHECK_EQ(result_times.size(), solution.size());
+  LOG(INFO) << "Times can't be compared by " << invalid_distance.count()
+            << "ns";
+  for (size_t i = 0; i < result_times.size(); ++i) {
+    LOG(INFO) << "  " << result_times[i] << " vs " << solution[i] << " -> "
+              << (result_times[i].time - solution[i].time).count() << "ns";
+  }
+
+  if (skip_order_validation_) {
+    LOG(ERROR) << "Skipping because --skip_order_validation";
+  } else {
+    LOG(FATAL) << "Please investigate.  Use --max_invalid_distance_ns="
+               << invalid_distance.count() << " to ignore this.";
+  }
+}
+
+std::tuple<NoncausalTimestampFilter *, std::vector<BootTimestamp>, int>
+MultiNodeNoncausalOffsetEstimator::SequentialSolution(
+    TimestampProblem *problem,
+    std::vector<CandidateTimes> candidate_times,
+    const std::vector<logger::BootTimestamp> &base_times) {
+  std::vector<BootTimestamp> result_times;
+  NoncausalTimestampFilter *next_filter = nullptr;
+  size_t solution_index = 0;
+
   for (size_t node_a_index = 0; node_a_index < candidate_times.size();
        ++node_a_index) {
     VLOG(2) << "Investigating filter for node " << node_a_index;
@@ -1371,8 +1526,6 @@
     // timestamps, we might need to change our assumptions around
     // BootTimestamp and BootDuration.
 
-    // If we haven't rebooted, we can seed the optimization problem with a
-    // pretty good initial guess.
     if (next_node_time.boot == base_times[node_a_index].boot) {
       // Optimize, and save the time into times if earlier than time.
       for (size_t node_index = 0; node_index < base_times.size();
@@ -1392,23 +1545,27 @@
            ++node_index) {
         problem->set_base_clock(node_index, base_times[node_index]);
       }
+      // And we know our solution node will have the wrong boot, so replace
+      // it entirely.
+      problem->set_base_clock(node_a_index, next_node_time);
     }
 
-    problem->set_solution_node(node_a_index);
-    problem->set_base_clock(problem->solution_node(), next_node_time);
+    std::vector<BootTimestamp> points(problem->size(),
+                                      BootTimestamp::max_time());
     if (VLOG_IS_ON(2)) {
       problem->Debug();
     }
-    // TODO(austin): Solve all problems at once :)
-    std::vector<BootTimestamp> solution = problem->SolveNewton();
+    points[node_a_index] = next_node_time;
+    std::tuple<std::vector<BootTimestamp>, size_t> solution =
+        problem->SolveNewton(points);
 
     // Bypass checking if order validation is turned off.  This lets us dump a
     // CSV file so we can view the problem and figure out what to do.  The
     // results won't make sense.
-    if (!problem->ValidateSolution(solution)) {
+    if (!problem->ValidateSolution(std::get<0>(solution))) {
       LOG(WARNING) << "Invalid solution, constraints not met.";
-      for (size_t i = 0; i < solution.size(); ++i) {
-        LOG(INFO) << "  " << solution[i];
+      for (size_t i = 0; i < std::get<0>(solution).size(); ++i) {
+        LOG(INFO) << "  " << std::get<0>(solution)[i];
       }
       problem->Debug();
       if (!skip_order_validation_) {
@@ -1418,21 +1575,22 @@
     }
 
     if (VLOG_IS_ON(1)) {
-      VLOG(1) << "Candidate solution for node " << node_a_index << " is";
-      for (size_t i = 0; i < solution.size(); ++i) {
-        VLOG(1) << "  " << solution[i];
+      VLOG(1) << "Candidate solution for node " << node_a_index << " is";
+      for (size_t i = 0; i < std::get<0>(solution).size(); ++i) {
+        VLOG(1) << "  " << std::get<0>(solution)[i];
       }
     }
 
     if (result_times.empty()) {
       // This is the first solution candidate, so don't bother comparing.
-      result_times = std::move(solution);
+      result_times = std::move(std::get<0>(solution));
       next_filter = next_node_filter;
       solution_index = node_a_index;
       continue;
     }
 
-    switch (CompareTimes(result_times, solution)) {
+    switch (CompareTimes(result_times, std::get<0>(solution))) {
       // The old solution is before or at the new solution.  This means that
       // the old solution is a better result, so ignore this one.
       case TimeComparison::kBefore:
@@ -1440,66 +1598,53 @@
         break;
       case TimeComparison::kAfter:
         // The new solution is better!  Save it.
-        result_times = std::move(solution);
+        result_times = std::move(std::get<0>(solution));
         next_filter = next_node_filter;
         solution_index = node_a_index;
         break;
       case TimeComparison::kInvalid: {
-        // If times are close enough, drop the invalid time.
-        const chrono::nanoseconds invalid_distance =
-            InvalidDistance(result_times, solution);
-        if (invalid_distance <=
-            chrono::nanoseconds(FLAGS_max_invalid_distance_ns)) {
-          VLOG(1) << "Times can't be compared by " << invalid_distance.count()
-                  << "ns";
-          for (size_t i = 0; i < result_times.size(); ++i) {
-            VLOG(1) << "  " << result_times[i] << " vs " << solution[i]
-                    << " -> "
-                    << (result_times[i].time - solution[i].time).count()
-                    << "ns";
-          }
-          VLOG(1) << "Ignoring because it is close enough.";
-          if (next_node_filter) {
-            std::optional<
-                std::tuple<logger::BootTimestamp, logger::BootDuration>>
-                result = next_node_filter->Consume();
-            CHECK(result);
-            next_node_filter->Pop(std::get<0>(*result) -
-                                  time_estimation_buffer_seconds_);
-          }
-          break;
-        }
-        // Somehow the new solution is better *and* worse than the old
-        // solution...  This is an internal failure because that means time
-        // goes backwards on a node.
-        CHECK_EQ(result_times.size(), solution.size());
-        LOG(INFO) << "Times can't be compared by " << invalid_distance.count()
-                  << "ns";
-        for (size_t i = 0; i < result_times.size(); ++i) {
-          LOG(INFO) << "  " << result_times[i] << " vs " << solution[i]
-                    << " -> "
-                    << (result_times[i].time - solution[i].time).count()
-                    << "ns";
-        }
-
-        if (skip_order_validation_) {
-          if (next_node_filter) {
-            std::optional<
-                std::tuple<logger::BootTimestamp, logger::BootDuration>>
-                result = next_node_filter->Consume();
-            CHECK(result);
-            next_node_filter->Pop(std::get<0>(*result) -
-                                  time_estimation_buffer_seconds_);
-          }
-          LOG(ERROR) << "Skipping because --skip_order_validation";
-          break;
-        } else {
-          LOG(FATAL) << "Please investigate.  Use --max_invalid_distance_ns="
-                     << invalid_distance.count() << " to ignore this.";
+        CheckInvalidDistance(result_times, std::get<0>(solution));
+        if (next_node_filter) {
+          std::optional<std::tuple<logger::BootTimestamp, logger::BootDuration>>
+              result = next_node_filter->Consume();
+          CHECK(result);
+          next_node_filter->Pop(std::get<0>(*result) -
+                                time_estimation_buffer_seconds_);
         }
       } break;
     }
   }
+
+  return std::make_tuple(next_filter, std::move(result_times), solution_index);
+}
+
+std::tuple<NoncausalTimestampFilter *, std::vector<BootTimestamp>, int>
+MultiNodeNoncausalOffsetEstimator::NextSolution(
+    TimestampProblem *problem, const std::vector<BootTimestamp> &base_times) {
+  // Ok, now solve for the minimum time on each channel.
+  std::vector<BootTimestamp> result_times;
+
+  bool boots_all_match = true;
+  std::vector<CandidateTimes> candidate_times;
+  std::tie(candidate_times, boots_all_match) = MakeCandidateTimes();
+
+  NoncausalTimestampFilter *next_filter = nullptr;
+  size_t solution_index = 0;
+
+  // We can significantly speed things up if we know that all the boots match by
+  // solving for everything at once.  If the boots don't match, the combined min
+  // that happens inside the solver doesn't make a lot of sense since we are
+  // actually using the boot from the candidate times to figure out which
+  // interpolation function to use under the hood.
+  if (boots_all_match) {
+    std::tie(next_filter, result_times, solution_index) =
+        SimultaneousSolution(problem, std::move(candidate_times), base_times);
+  } else {
+    // If all the boots don't match, fall back to the old method of comparing
+    // all the solutions individually.
+    std::tie(next_filter, result_times, solution_index) =
+        SequentialSolution(problem, std::move(candidate_times), base_times);
+  }
   if (VLOG_IS_ON(1)) {
     VLOG(1) << "Best solution is for node " << solution_index;
     for (size_t i = 0; i < result_times.size(); ++i) {
diff --git a/aos/network/multinode_timestamp_filter.h b/aos/network/multinode_timestamp_filter.h
index d9c6f59..7259fc8 100644
--- a/aos/network/multinode_timestamp_filter.h
+++ b/aos/network/multinode_timestamp_filter.h
@@ -39,26 +39,24 @@
 
   size_t size() const { return base_clock_.size(); }
 
-  // Sets node to fix time for and not solve for.
-  void set_solution_node(size_t solution_node) {
-    solution_node_ = solution_node;
-  }
-  size_t solution_node() const { return solution_node_; }
-
   // Sets and gets the base time for a node.
   void set_base_clock(size_t i, logger::BootTimestamp t) { base_clock_[i] = t; }
   logger::BootTimestamp base_clock(size_t i) const { return base_clock_[i]; }
 
   // Adds a timestamp filter from a -> b.
   //   filter[a_index]->Offset(ta) + ta => t(b_index);
-  void add_filter(size_t a_index, const NoncausalTimestampFilter *filter,
-                  size_t b_index) {
-    filters_[a_index].emplace_back(filter, b_index);
+  void add_clock_offset_filter(size_t a_index,
+                               const NoncausalTimestampFilter *filter,
+                               size_t b_index) {
+    clock_offset_filter_for_node_[a_index].emplace_back(filter, b_index);
   }
 
   // Solves the optimization problem phrased using the symmetric Netwon's method
-  // solver and returns the optimal time on each node.
-  std::vector<logger::BootTimestamp> SolveNewton();
+  // solver and returns the optimal time on each node, along with the node which
+  // constrained the problem.  points is the list of potential constraint
+  // points, and the solver uses the earliest point.
+  std::tuple<std::vector<logger::BootTimestamp>, size_t> SolveNewton(
+      const std::vector<logger::BootTimestamp> &points);
 
   // Validates the solution, returning true if it meets all the constraints, and
   // false otherwise.
@@ -85,33 +83,33 @@
   }
 
  private:
+  size_t SolutionNode(const std::vector<logger::BootTimestamp> &points) const {
+    size_t solution_node = std::numeric_limits<size_t>::max();
+    for (size_t i = 0; i < points.size(); ++i) {
+      if (points[i] != logger::BootTimestamp::max_time()) {
+        CHECK_EQ(solution_node, std::numeric_limits<size_t>::max());
+        solution_node = i;
+      }
+    }
+    CHECK_NE(solution_node, std::numeric_limits<size_t>::max());
+    return solution_node;
+  }
+
   // Returns the Hessian of the cost function at time_offsets.
   Eigen::MatrixXd Hessian(const Eigen::Ref<Eigen::VectorXd> time_offsets) const;
   // Returns the gradient of the cost function at time_offsets.
   Eigen::VectorXd Gradient(
       const Eigen::Ref<Eigen::VectorXd> time_offsets) const;
 
-  // Returns the newton step of the timestamp problem.  The last term is the
-  // scalar on the equality constraint.  This needs to be removed from the
-  // solution to get the actual newton step.
-  Eigen::VectorXd Newton(const Eigen::Ref<Eigen::VectorXd> time_offsets) const;
+  // Returns the newton step of the timestamp problem, and the node which was
+  // used for the equality constraint.  The last term is the scalar on the
+  // equality constraint.  This needs to be removed from the solution to get the
+  // actual newton step.
+  std::tuple<Eigen::VectorXd, size_t> Newton(
+      const Eigen::Ref<Eigen::VectorXd> time_offsets,
+      const std::vector<logger::BootTimestamp> &points) const;
 
-  void MaybeUpdateNodeMapping() {
-    if (node_mapping_valid_) {
-      return;
-    }
-    size_t live_node_index = 0;
-    for (size_t i = 0; i < node_mapping_.size(); ++i) {
-      if (live(i)) {
-        node_mapping_[i] = live_node_index;
-        ++live_node_index;
-      } else {
-        node_mapping_[i] = std::numeric_limits<size_t>::max();
-      }
-    }
-    live_nodes_ = live_node_index;
-    node_mapping_valid_ = true;
-  }
+  void MaybeUpdateNodeMapping();
 
   // Converts from a node index to an index in the solution without skipping the
   // solution node.
@@ -120,9 +118,6 @@
     return node_mapping_[node_index];
   }
 
-  // The node to hold fixed when solving.
-  size_t solution_node_ = 0;
-
   // The optimization problem is solved as base_clock + time_offsets to minimize
   // numerical precision problems.  This contains all the base times.  The base
   // time corresponding to solution_node is fixed and not solved.
@@ -146,7 +141,7 @@
   };
 
   // List of filters indexed by node.
-  std::vector<std::vector<FilterPair>> filters_;
+  std::vector<std::vector<FilterPair>> clock_offset_filter_for_node_;
 };
 
 // Helpers to convert times between the monotonic and distributed clocks for
@@ -334,8 +329,17 @@
   const aos::Configuration *configuration() const { return configuration_; }
 
  private:
+  struct CandidateTimes {
+    logger::BootTimestamp next_node_time = logger::BootTimestamp::max_time();
+    logger::BootDuration next_node_duration;
+    NoncausalTimestampFilter *next_node_filter = nullptr;
+  };
+
   TimestampProblem MakeProblem();
 
+  // Returns the list of candidate times to solve for.
+  std::tuple<std::vector<CandidateTimes>, bool> MakeCandidateTimes() const;
+
   // Returns the next solution, the filter which has the control point for it
   // (or nullptr if a start time triggered this to be returned), and the node
   // which triggered it.
@@ -344,6 +348,27 @@
   NextSolution(TimestampProblem *problem,
                const std::vector<logger::BootTimestamp> &base_times);
 
+  // Returns the solution (if there is one) for the list of candidate times by
+  // solving all the problems simultaneously.  They must be from the same boot.
+  std::tuple<NoncausalTimestampFilter *, std::vector<logger::BootTimestamp>,
+             int>
+  SimultaneousSolution(TimestampProblem *problem,
+                       const std::vector<CandidateTimes> candidate_times,
+                       const std::vector<logger::BootTimestamp> &base_times);
+
+  // Returns the solution (if there is one) for the list of candidate times by
+  // solving the problems one after another.  They can be from any boot.
+  std::tuple<NoncausalTimestampFilter *, std::vector<logger::BootTimestamp>,
+             int>
+  SequentialSolution(TimestampProblem *problem,
+                     const std::vector<CandidateTimes> candidate_times,
+                     const std::vector<logger::BootTimestamp> &base_times);
+
+  // Explodes if the invalid distance is too far.
+  void CheckInvalidDistance(
+      const std::vector<logger::BootTimestamp> &result_times,
+      const std::vector<logger::BootTimestamp> &solution);
+
   // Writes all samples to disk.
   void FlushAllSamples(bool finish);
 
@@ -352,7 +377,7 @@
 
   std::shared_ptr<const logger::Boots> boots_;
 
-  // If true, skip any validation which would trigger if we see evidance that
+  // If true, skip any validation which would trigger if we see evidence that
   // time estimation between nodes was incorrect.
   const bool skip_order_validation_;
 
diff --git a/aos/network/multinode_timestamp_filter_test.cc b/aos/network/multinode_timestamp_filter_test.cc
index c1fb456..c2a3b71 100644
--- a/aos/network/multinode_timestamp_filter_test.cc
+++ b/aos/network/multinode_timestamp_filter_test.cc
@@ -361,30 +361,37 @@
   TimestampProblem problem(2);
   problem.set_base_clock(0, ta);
   problem.set_base_clock(1, e);
-  problem.set_solution_node(0);
-  problem.add_filter(0, &a, 1);
-  problem.add_filter(1, &b, 0);
+  problem.add_clock_offset_filter(0, &a, 1);
+  problem.add_clock_offset_filter(1, &b, 0);
 
   problem.Debug();
 
   problem.set_base_clock(0, e + chrono::seconds(1));
   problem.set_base_clock(1, e);
 
-  problem.set_solution_node(0);
-  std::vector<BootTimestamp> result1 = problem.SolveNewton();
+  std::vector<BootTimestamp> points1(problem.size(), BootTimestamp::max_time());
+  points1[0] = e + chrono::seconds(1);
 
-  problem.set_base_clock(1, result1[1]);
-  problem.set_solution_node(1);
-  std::vector<BootTimestamp> result2 = problem.SolveNewton();
+  std::tuple<std::vector<BootTimestamp>, size_t> result1 =
+      problem.SolveNewton(points1);
+  EXPECT_EQ(std::get<1>(result1), 0u);
 
-  EXPECT_EQ(result1[0], e + chrono::seconds(1));
-  EXPECT_EQ(result1[0], result2[0]);
-  EXPECT_EQ(result1[1], result2[1]);
+  std::vector<BootTimestamp> points2(problem.size(), BootTimestamp::max_time());
+  points2[1] = std::get<0>(result1)[1];
+  std::tuple<std::vector<BootTimestamp>, size_t> result2 =
+      problem.SolveNewton(points2);
+  EXPECT_EQ(std::get<1>(result2), 1u);
+
+  EXPECT_EQ(std::get<0>(result1)[0], e + chrono::seconds(1));
+  EXPECT_EQ(std::get<0>(result1)[0], std::get<0>(result2)[0]);
+  EXPECT_EQ(std::get<0>(result1)[1], std::get<0>(result2)[1]);
 
   // Confirm that the error is almost equal for both directions.  The solution
   // is an integer solution, so there will be a little bit of error left over.
-  EXPECT_NEAR(a.OffsetError(result1[0], 0.0, result1[1], 0.0) -
-                  b.OffsetError(result1[1], 0.0, result1[0], 0.0),
+  EXPECT_NEAR(a.OffsetError(std::get<0>(result1)[0], 0.0,
+                            std::get<0>(result1)[1], 0.0) -
+                  b.OffsetError(std::get<0>(result1)[1], 0.0,
+                                std::get<0>(result1)[0], 0.0),
               0.0, 0.5);
 }