Add LogReader ReplayChannels filtering

LogReader's constructor takes a new replay_channels argument: the name &
type pairs of the channels to replay (nullptr replays every channel).
During construction, LogReader resolves replay_channels into the indices
of those channels in the logged configuration and, when sending a
message, uses them to check whether the message's channel is included in
replay_channels. The filtering itself happens in TimestampMapper, which
takes a lambda that performs the check when Front() is called.

Change-Id: I614bc70f89afab2e7f6d00a36dc569518d1edc5a
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/events/logging/log_reader.cc b/aos/events/logging/log_reader.cc
index f80f28d..d3ce16b 100644
--- a/aos/events/logging/log_reader.cc
+++ b/aos/events/logging/log_reader.cc
@@ -233,13 +233,17 @@
 };
 
 LogReader::LogReader(std::string_view filename,
-                     const Configuration *replay_configuration)
-    : LogReader(SortParts({std::string(filename)}), replay_configuration) {}
+                     const Configuration *replay_configuration,
+                     const ReplayChannels *replay_channels)
+    : LogReader(SortParts({std::string{filename}}),
+                replay_configuration, replay_channels) {}
 
 LogReader::LogReader(std::vector<LogFile> log_files,
-                     const Configuration *replay_configuration)
+                     const Configuration *replay_configuration,
+                     const ReplayChannels *replay_channels)
     : log_files_(std::move(log_files)),
-      replay_configuration_(replay_configuration) {
+      replay_configuration_(replay_configuration),
+      replay_channels_(replay_channels) {
   SetStartTime(FLAGS_start_time);
   SetEndTime(FLAGS_end_time);
 
@@ -260,6 +264,11 @@
     }
   }
 
+  if (replay_channels_ != nullptr) {
+    CHECK(!replay_channels_->empty()) << "replay_channels is empty which means "
+                                         "no messages will get replayed.";
+  }
+
   MakeRemappedConfig();
 
   // Remap all existing remote timestamp channels.  They will be recreated, and
@@ -444,6 +453,7 @@
                 std::nullopt, false,
                 last_queued_message_ == BootTimestamp::max_time()};
           }
+
           TimestampedMessage *message = timestamp_mapper_->Front();
           // Upon reaching the end of the log, exit.
           if (message == nullptr) {
@@ -452,6 +462,7 @@
                                        BootTimestamp>::PushResult{std::nullopt,
                                                                   false, true};
           }
+
           last_queued_message_ = message->monotonic_event_time;
           const util::ThreadedQueue<TimestampedMessage,
                                     BootTimestamp>::PushResult result{
@@ -606,7 +617,8 @@
         filtered_parts.size() == 0u
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
-        filters_.get(), node, State::ThreadedBuffering::kNo);
+        filters_.get(), node, State::ThreadedBuffering::kNo,
+        MaybeMakeReplayChannelIndicies(node));
     State *state = states_[node_index].get();
     state->SetNodeEventLoopFactory(
         event_loop_factory_->GetNodeEventLoopFactory(node),
@@ -647,6 +659,7 @@
     if (state->SingleThreadedOldestMessageTime() == BootTimestamp::max_time()) {
       continue;
     }
+
     ++live_nodes_;
 
     NodeEventLoopFactory *node_factory =
@@ -795,7 +808,8 @@
         filtered_parts.size() == 0u
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
-        filters_.get(), node, State::ThreadedBuffering::kYes);
+        filters_.get(), node, State::ThreadedBuffering::kYes,
+        MaybeMakeReplayChannelIndicies(node));
     State *state = states_[node_index].get();
 
     state->SetChannelCount(logged_configuration()->channels()->size());
@@ -1316,6 +1330,13 @@
                                    std::string_view add_prefix,
                                    std::string_view new_type,
                                    RemapConflict conflict_handling) {
+  if (replay_channels_ != nullptr) {
+    CHECK(std::find(replay_channels_->begin(), replay_channels_->end(),
+                    std::make_pair(name, type)) != replay_channels_->end())
+        << "Attempted to remap channel " << name << " " << type
+        << " which is not included in the replay channels passed to LogReader.";
+  }
+
   for (size_t ii = 0; ii < logged_configuration()->channels()->size(); ++ii) {
     const Channel *const channel = logged_configuration()->channels()->Get(ii);
     if (channel->name()->str() == name &&
@@ -1646,6 +1667,31 @@
   // TODO(austin): Lazily re-build to save CPU?
 }
 
+// Returns sorted logged_configuration() indices of replay_channels_, or null.
+std::unique_ptr<const ReplayChannelIndicies>
+LogReader::MaybeMakeReplayChannelIndicies(const Node *node) {
+  if (replay_channels_ == nullptr) {
+    return nullptr;
+  }
+  auto replay_channel_indicies = std::make_unique<ReplayChannelIndicies>();
+  for (const auto &channel : *replay_channels_) {
+    const Channel *ch = configuration::GetChannel(
+        logged_configuration(), channel.first, channel.second, "", node);
+    if (ch == nullptr) {
+      // Guard against a null node before reading its name.
+      LOG(WARNING) << "Channel: " << channel.first << " " << channel.second
+                   << " not found in configuration for node: "
+                   << (node ? node->name()->string_view() : std::string_view())
+                   << " Skipping ...";
+      continue;
+    }
+    replay_channel_indicies->emplace_back(
+        configuration::ChannelIndex(logged_configuration(), ch));
+  }
+  std::sort(replay_channel_indicies->begin(), replay_channel_indicies->end());
+  return replay_channel_indicies;
+}
+
 std::vector<const Channel *> LogReader::RemappedChannels() const {
   std::vector<const Channel *> result;
   result.reserve(remapped_channels_.size());
@@ -1699,11 +1745,24 @@
 LogReader::State::State(
     std::unique_ptr<TimestampMapper> timestamp_mapper,
     message_bridge::MultiNodeNoncausalOffsetEstimator *multinode_filters,
-    const Node *node, LogReader::State::ThreadedBuffering threading)
+    const Node *node, LogReader::State::ThreadedBuffering threading,
+    std::unique_ptr<const ReplayChannelIndicies> replay_channel_indicies)
     : timestamp_mapper_(std::move(timestamp_mapper)),
       node_(node),
       multinode_filters_(multinode_filters),
-      threading_(threading) {}
+      threading_(threading),
+      replay_channel_indicies_(std::move(replay_channel_indicies)) {
+  // timestamp_mapper_ is null for nodes with no log parts; nothing to filter.
+  if (replay_channel_indicies_ != nullptr && timestamp_mapper_ != nullptr) {
+    timestamp_mapper_->set_replay_channels_callback(
+        [filter = replay_channel_indicies_.get()](
+            const TimestampedMessage &message) -> bool {
+          // TODO: benchmark strategies for channel_index matching
+          return std::binary_search(filter->cbegin(), filter->cend(),
+                                    message.channel_index);
+        });
+  }
+}
 
 void LogReader::State::AddPeer(State *peer) {
   if (timestamp_mapper_ && peer->timestamp_mapper_) {
@@ -2131,6 +2190,7 @@
 }
 
 TimestampedMessage LogReader::State::PopOldest() {
+  // Multi-threaded path: messages come from message_queuer_.
   if (message_queuer_.has_value()) {
     std::optional<TimestampedMessage> message = message_queuer_->Pop();
     CHECK(message.has_value()) << ": Unexpectedly ran out of messages.";
@@ -2139,7 +2199,7 @@
         std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::duration<double>(FLAGS_threaded_look_ahead_seconds)));
     return message.value();
-  } else {
+  } else {  // single threaded
     CHECK(timestamp_mapper_ != nullptr);
     TimestampedMessage *result_ptr = timestamp_mapper_->Front();
     CHECK(result_ptr != nullptr);