Add LogReader ReplayChannels filtering

LogReader's constructors take a new optional argument, replay_channels,
containing the name & type pairs of the channels to replay. When replay is
set up, LogReader resolves these pairs to per-node channel indices and uses
them to drop any message whose channel is not in replay_channels before it
is replayed. The filtering itself lives in TimestampMapper, which takes a
lambda that decides which matched messages Front() returns.

Change-Id: I614bc70f89afab2e7f6d00a36dc569518d1edc5a
Signed-off-by: James Kuszmaul <james.kuszmaul@bluerivertech.com>
diff --git a/aos/events/logging/BUILD b/aos/events/logging/BUILD
index 7e8daf9..87ebfc5 100644
--- a/aos/events/logging/BUILD
+++ b/aos/events/logging/BUILD
@@ -543,6 +543,7 @@
     name = "realtime_replay_test",
     srcs = ["realtime_replay_test.cc"],
     data = [
+        ":multinode_pingpong_combined_config",
         "//aos/events:pingpong_config",
     ],
     target_compatible_with = ["@platforms//os:linux"],
diff --git a/aos/events/logging/log_reader.cc b/aos/events/logging/log_reader.cc
index f80f28d..d3ce16b 100644
--- a/aos/events/logging/log_reader.cc
+++ b/aos/events/logging/log_reader.cc
@@ -233,13 +233,17 @@
 };
 
 LogReader::LogReader(std::string_view filename,
-                     const Configuration *replay_configuration)
-    : LogReader(SortParts({std::string(filename)}), replay_configuration) {}
+                     const Configuration *replay_configuration,
+                     const ReplayChannels *replay_channels)
+    : LogReader(SortParts({std::string(filename)}), replay_configuration,
+                replay_channels) {}
 
 LogReader::LogReader(std::vector<LogFile> log_files,
-                     const Configuration *replay_configuration)
+                     const Configuration *replay_configuration,
+                     const ReplayChannels *replay_channels)
     : log_files_(std::move(log_files)),
-      replay_configuration_(replay_configuration) {
+      replay_configuration_(replay_configuration),
+      replay_channels_(replay_channels) {
   SetStartTime(FLAGS_start_time);
   SetEndTime(FLAGS_end_time);
 
@@ -260,6 +264,11 @@
     }
   }
 
+  if (replay_channels_ != nullptr) {
+    CHECK(!replay_channels_->empty()) << "replay_channels is empty which means "
+                                         "no messages will get replayed.";
+  }
+
   MakeRemappedConfig();
 
   // Remap all existing remote timestamp channels.  They will be recreated, and
@@ -444,6 +453,7 @@
                 std::nullopt, false,
                 last_queued_message_ == BootTimestamp::max_time()};
           }
+
           TimestampedMessage *message = timestamp_mapper_->Front();
           // Upon reaching the end of the log, exit.
           if (message == nullptr) {
@@ -452,6 +462,7 @@
                                        BootTimestamp>::PushResult{std::nullopt,
                                                                   false, true};
           }
+
           last_queued_message_ = message->monotonic_event_time;
           const util::ThreadedQueue<TimestampedMessage,
                                     BootTimestamp>::PushResult result{
@@ -606,7 +617,8 @@
         filtered_parts.size() == 0u
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
-        filters_.get(), node, State::ThreadedBuffering::kNo);
+        filters_.get(), node, State::ThreadedBuffering::kNo,
+        MaybeMakeReplayChannelIndicies(node));
     State *state = states_[node_index].get();
     state->SetNodeEventLoopFactory(
         event_loop_factory_->GetNodeEventLoopFactory(node),
@@ -647,6 +659,7 @@
     if (state->SingleThreadedOldestMessageTime() == BootTimestamp::max_time()) {
       continue;
     }
+
     ++live_nodes_;
 
     NodeEventLoopFactory *node_factory =
@@ -795,7 +808,8 @@
         filtered_parts.size() == 0u
             ? nullptr
             : std::make_unique<TimestampMapper>(std::move(filtered_parts)),
-        filters_.get(), node, State::ThreadedBuffering::kYes);
+        filters_.get(), node, State::ThreadedBuffering::kYes,
+        MaybeMakeReplayChannelIndicies(node));
     State *state = states_[node_index].get();
 
     state->SetChannelCount(logged_configuration()->channels()->size());
@@ -1316,6 +1330,13 @@
                                    std::string_view add_prefix,
                                    std::string_view new_type,
                                    RemapConflict conflict_handling) {
+  if (replay_channels_ != nullptr) {
+    CHECK(std::find(replay_channels_->begin(), replay_channels_->end(),
+                    std::make_pair(name, type)) != replay_channels_->end())
+        << "Attempted to remap channel " << name << " " << type
+        << " which is not included in the replay channels passed to LogReader.";
+  }
+
   for (size_t ii = 0; ii < logged_configuration()->channels()->size(); ++ii) {
     const Channel *const channel = logged_configuration()->channels()->Get(ii);
     if (channel->name()->str() == name &&
@@ -1646,6 +1667,31 @@
   // TODO(austin): Lazily re-build to save CPU?
 }
 
+std::unique_ptr<const ReplayChannelIndicies>
+LogReader::MaybeMakeReplayChannelIndicies(const Node *node) {
+  if (replay_channels_ == nullptr) {
+    return nullptr;
+  }
+  auto replay_channel_indicies = std::make_unique<ReplayChannelIndicies>();
+  for (const auto &[name, type] : *replay_channels_) {
+    const Channel *ch =
+        configuration::GetChannel(logged_configuration(), name, type, "", node);
+    if (ch == nullptr) {
+      // node may be nullptr (e.g. single-node logs); don't dereference it.
+      LOG(WARNING) << "Channel: " << name << " " << type
+                   << " not found in configuration for node: "
+                   << (node == nullptr ? std::string_view("")
+                                       : node->name()->string_view())
+                   << " Skipping ...";
+      continue;
+    }
+    replay_channel_indicies->emplace_back(
+        configuration::ChannelIndex(logged_configuration(), ch));
+  }
+  std::sort(replay_channel_indicies->begin(), replay_channel_indicies->end());
+  return replay_channel_indicies;
+}
+
 std::vector<const Channel *> LogReader::RemappedChannels() const {
   std::vector<const Channel *> result;
   result.reserve(remapped_channels_.size());
@@ -1699,11 +1745,24 @@
 LogReader::State::State(
     std::unique_ptr<TimestampMapper> timestamp_mapper,
     message_bridge::MultiNodeNoncausalOffsetEstimator *multinode_filters,
-    const Node *node, LogReader::State::ThreadedBuffering threading)
+    const Node *node, LogReader::State::ThreadedBuffering threading,
+    std::unique_ptr<const ReplayChannelIndicies> replay_channel_indicies)
     : timestamp_mapper_(std::move(timestamp_mapper)),
       node_(node),
       multinode_filters_(multinode_filters),
-      threading_(threading) {}
+      threading_(threading),
+      replay_channel_indicies_(std::move(replay_channel_indicies)) {
+  if (replay_channel_indicies_ != nullptr && timestamp_mapper_ != nullptr) {
+    timestamp_mapper_->set_replay_channels_callback(
+        [filter = replay_channel_indicies_.get()](
+            const TimestampedMessage &message) -> bool {
+          // The indices are sorted by MaybeMakeReplayChannelIndicies, so a
+          // binary search is valid here.
+          return std::binary_search(filter->cbegin(), filter->cend(),
+                                    message.channel_index);
+        });
+  }
+}
 
 void LogReader::State::AddPeer(State *peer) {
   if (timestamp_mapper_ && peer->timestamp_mapper_) {
@@ -2131,6 +2190,7 @@
 }
 
 TimestampedMessage LogReader::State::PopOldest() {
+  // multithreaded
   if (message_queuer_.has_value()) {
     std::optional<TimestampedMessage> message = message_queuer_->Pop();
     CHECK(message.has_value()) << ": Unexpectedly ran out of messages.";
@@ -2139,7 +2199,7 @@
         std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::duration<double>(FLAGS_threaded_look_ahead_seconds)));
     return message.value();
-  } else {
+  } else {  // single threaded
     CHECK(timestamp_mapper_ != nullptr);
     TimestampedMessage *result_ptr = timestamp_mapper_->Front();
     CHECK(result_ptr != nullptr);
diff --git a/aos/events/logging/log_reader.h b/aos/events/logging/log_reader.h
index 07cca48..0d50fb9 100644
--- a/aos/events/logging/log_reader.h
+++ b/aos/events/logging/log_reader.h
@@ -31,6 +31,12 @@
 
 class EventNotifier;
 
+// Vector of pair of name and type of the channel
+using ReplayChannels =
+    std::vector<std::pair<std::string_view, std::string_view>>;
+// Vector of channel indices
+using ReplayChannelIndicies = std::vector<size_t>;
+
 // We end up with one of the following 3 log file types.
 //
 // Single node logged as the source node.
@@ -67,11 +73,16 @@
   // pass it in here. It must provide all the channels that the original logged
   // config did.
   //
+  // If non-null, replay_channels is the allow-list of (name, type) channels to
+  // replay; it must be non-empty and outlive this LogReader (it is not copied).
+  //
   // The single file constructor calls SortParts internally.
   LogReader(std::string_view filename,
-            const Configuration *replay_configuration = nullptr);
+            const Configuration *replay_configuration = nullptr,
+            const ReplayChannels *replay_channels = nullptr);
   LogReader(std::vector<LogFile> log_files,
-            const Configuration *replay_configuration = nullptr);
+            const Configuration *replay_configuration = nullptr,
+            const ReplayChannels *replay_channels = nullptr);
   ~LogReader();
 
   // Registers all the callbacks to send the log file data out on an event loop
@@ -332,7 +343,8 @@
     enum class ThreadedBuffering { kYes, kNo };
     State(std::unique_ptr<TimestampMapper> timestamp_mapper,
           message_bridge::MultiNodeNoncausalOffsetEstimator *multinode_filters,
-          const Node *node, ThreadedBuffering threading);
+          const Node *node, ThreadedBuffering threading,
+          std::unique_ptr<const ReplayChannelIndicies> replay_channel_indicies);
 
     // Connects up the timestamp mappers.
     void AddPeer(State *peer);
@@ -728,8 +740,18 @@
     std::optional<BootTimestamp> last_queued_message_;
     std::optional<util::ThreadedQueue<TimestampedMessage, BootTimestamp>>
         message_queuer_;
+
+    // If a ReplayChannels was passed to LogReader, this will hold the
+    // indices of the channels to replay for the Node represented by
+    // the instance of LogReader::State.
+    std::unique_ptr<const ReplayChannelIndicies> replay_channel_indicies_;
   };
 
+  // If a ReplayChannels was passed to LogReader then creates a
+  // ReplayChannelIndicies for the given node. Otherwise, returns a nullptr.
+  std::unique_ptr<const ReplayChannelIndicies> MaybeMakeReplayChannelIndicies(
+      const Node *node);
+
   // Node index -> State.
   std::vector<std::unique_ptr<State>> states_;
 
@@ -766,6 +788,10 @@
   const Configuration *remapped_configuration_ = nullptr;
   const Configuration *replay_configuration_ = nullptr;
 
+  // If a ReplayChannels was passed to LogReader, this will hold the
+  // name and type of channels to replay which is used when creating States.
+  const ReplayChannels *replay_channels_ = nullptr;
+
   // If true, the replay timer will ignore any missing data.  This is used
   // during startup when we are bootstrapping everything and trying to get to
   // the start of all the log files.
diff --git a/aos/events/logging/logfile_utils.cc b/aos/events/logging/logfile_utils.cc
index 94337d3..ab56356 100644
--- a/aos/events/logging/logfile_utils.cc
+++ b/aos/events/logging/logfile_utils.cc
@@ -1804,12 +1804,30 @@
 }
 
 bool TimestampMapper::QueueMatched() {
+  MatchResult result = MatchResult::kEndOfFile;
+  do {
+    result = MaybeQueueMatched();
+  } while (result == MatchResult::kSkipped);
+  return result == MatchResult::kQueued;
+}
+
+bool TimestampMapper::CheckReplayChannelsAndMaybePop(
+    const TimestampedMessage & /*message*/) {
+  if (replay_channels_callback_ &&
+      !replay_channels_callback_(matched_messages_.back())) {
+    matched_messages_.pop_back();
+    return true;
+  }
+  return false;
+}
+
+TimestampMapper::MatchResult TimestampMapper::MaybeQueueMatched() {
   if (nodes_data_.empty()) {
     // Simple path.  We are single node, so there are no timestamps to match!
     CHECK_EQ(messages_.size(), 0u);
     Message *m = boot_merger_.Front();
     if (!m) {
-      return false;
+      return MatchResult::kEndOfFile;
     }
     // Enqueue this message into matched_messages_ so we have a place to
     // associate remote timestamps, and return it.
@@ -1821,7 +1839,10 @@
     // We are thin wrapper around node_merger.  Call it directly.
     boot_merger_.PopFront();
     timestamp_callback_(&matched_messages_.back());
-    return true;
+    if (CheckReplayChannelsAndMaybePop(matched_messages_.back())) {
+      return MatchResult::kSkipped;
+    }
+    return MatchResult::kQueued;
   }
 
   // We need to only add messages to the list so they get processed for
@@ -1830,7 +1851,7 @@
   if (messages_.empty()) {
     if (!Queue()) {
       // Found nothing to add, we are out of data!
-      return false;
+      return MatchResult::kEndOfFile;
     }
 
     // Now that it has been added (and cannibalized), forget about it
@@ -1847,7 +1868,10 @@
     last_message_time_ = matched_messages_.back().monotonic_event_time;
     messages_.pop_front();
     timestamp_callback_(&matched_messages_.back());
-    return true;
+    if (CheckReplayChannelsAndMaybePop(matched_messages_.back())) {
+      return MatchResult::kSkipped;
+    }
+    return MatchResult::kQueued;
   } else {
     // Got a timestamp, find the matching remote data, match it, and return
     // it.
@@ -1874,7 +1898,10 @@
     // Since messages_ holds the data, drop it.
     messages_.pop_front();
     timestamp_callback_(&matched_messages_.back());
-    return true;
+    if (CheckReplayChannelsAndMaybePop(matched_messages_.back())) {
+      return MatchResult::kSkipped;
+    }
+    return MatchResult::kQueued;
   }
 }
 
diff --git a/aos/events/logging/logfile_utils.h b/aos/events/logging/logfile_utils.h
index 7105b0c..346a36e 100644
--- a/aos/events/logging/logfile_utils.h
+++ b/aos/events/logging/logfile_utils.h
@@ -787,12 +787,25 @@
     }
   }
 
+  // Sets a filter on matched messages; returning false drops the message.
+  void set_replay_channels_callback(
+      std::function<bool(const TimestampedMessage &)> fn) {
+    replay_channels_callback_ = std::move(fn);
+  }
+
   // Sets a callback to be called whenever a full message is queued.
   void set_timestamp_callback(std::function<void(TimestampedMessage *)> fn) {
     timestamp_callback_ = fn;
   }
 
  private:
+  // Result of MaybeQueueMatched
+  enum class MatchResult : uint8_t {
+    kEndOfFile,  // End of the log file being read
+    kQueued,     // Message was queued
+    kSkipped     // Message was skipped over
+  };
+
   // The state for a remote node.  This holds the data that needs to be matched
   // with the remote node's timestamps.
   struct NodeData {
@@ -835,6 +848,10 @@
   // true if one was queued, and false otherwise.
   bool QueueMatched();
 
+  // Queues a message if the replay_channels_callback is passed and the end of
+  // the log file has not been reached.
+  MatchResult MaybeQueueMatched();
+
   // Queues up data until we have at least one message >= to time t.
   // Useful for triggering a remote node to read enough data to have the
   // timestamp you care about available.
@@ -843,6 +860,11 @@
   // Queues m into matched_messages_.
   void QueueMessage(Message *m);
 
+  // If a replay_channels_callback was set and the callback returns false, a
+  // matched message is popped and true is returned. Otherwise false is
+  // returned.
+  bool CheckReplayChannelsAndMaybePop(const TimestampedMessage &message);
+
   // Returns the name of the node this class is sorting for.
   std::string_view node_name() const {
     return configuration_->has_nodes() ? configuration_->nodes()
@@ -886,6 +908,7 @@
   BootTimestamp queued_until_ = BootTimestamp::min_time();
 
   std::function<void(TimestampedMessage *)> timestamp_callback_;
+  std::function<bool(const TimestampedMessage &)> replay_channels_callback_;
 };
 
 // Returns the node name with a trailing space, or an empty string if we are on
diff --git a/aos/events/logging/logfile_utils_test.cc b/aos/events/logging/logfile_utils_test.cc
index 3d99757..f453798 100644
--- a/aos/events/logging/logfile_utils_test.cc
+++ b/aos/events/logging/logfile_utils_test.cc
@@ -24,8 +24,8 @@
 namespace logger {
 namespace testing {
 namespace chrono = std::chrono;
-using aos::testing::ArtifactPath;
 using aos::message_bridge::RemoteMessage;
+using aos::testing::ArtifactPath;
 
 // Adapter class to make it easy to test DetachedBufferWriter without adding
 // test only boilerplate to DetachedBufferWriter.
@@ -1036,6 +1036,143 @@
   }
 }
 
+// Tests that messages rejected by the replay_channels callback are dropped.
+TEST_F(TimestampMapperTest, ReplayChannelsCallbackTest) {
+  const aos::monotonic_clock::time_point e = monotonic_clock::epoch();
+  {
+    TestDetachedBufferWriter writer0(logfile0_);
+    writer0.QueueSpan(config0_.span());
+    TestDetachedBufferWriter writer1(logfile1_);
+    writer1.QueueSpan(config2_.span());
+
+    writer0.WriteSizedFlatbuffer(
+        MakeLogMessage(e + chrono::milliseconds(1000), 0, 0x005));
+    writer1.WriteSizedFlatbuffer(MakeTimestampMessage(
+        e + chrono::milliseconds(1000), 0, chrono::seconds(100)));
+
+    writer0.WriteSizedFlatbuffer(
+        MakeLogMessage(e + chrono::milliseconds(2000), 0, 0x006));
+    writer1.WriteSizedFlatbuffer(MakeTimestampMessage(
+        e + chrono::milliseconds(2000), 0, chrono::seconds(100)));
+
+    writer0.WriteSizedFlatbuffer(
+        MakeLogMessage(e + chrono::milliseconds(3000), 0, 0x007));
+    writer1.WriteSizedFlatbuffer(MakeTimestampMessage(
+        e + chrono::milliseconds(3000), 0, chrono::seconds(100)));
+  }
+
+  const std::vector<LogFile> parts = SortParts({logfile0_, logfile1_});
+
+  ASSERT_EQ(parts[0].logger_node, "pi1");
+  ASSERT_EQ(parts[1].logger_node, "pi2");
+
+  // Each callback rejects the message queued while its count is 2, so both
+  // mappers drop their second message and return only the first and third.
+  size_t mapper0_count = 0;
+  TimestampMapper mapper0(FilterPartsForNode(parts, "pi1"));
+  mapper0.set_timestamp_callback(
+      [&](TimestampedMessage *) { ++mapper0_count; });
+  mapper0.set_replay_channels_callback(
+      [&](const TimestampedMessage &) -> bool { return mapper0_count != 2; });
+  size_t mapper1_count = 0;
+  TimestampMapper mapper1(FilterPartsForNode(parts, "pi2"));
+  mapper1.set_timestamp_callback(
+      [&](TimestampedMessage *) { ++mapper1_count; });
+  mapper1.set_replay_channels_callback(
+      [&](const TimestampedMessage &) -> bool { return mapper1_count != 2; });
+
+  mapper0.AddPeer(&mapper1);
+  mapper1.AddPeer(&mapper0);
+
+  {
+    std::deque<TimestampedMessage> output0;
+
+    EXPECT_EQ(mapper0_count, 0u);
+    EXPECT_EQ(mapper1_count, 0u);
+
+    ASSERT_TRUE(mapper0.Front() != nullptr);
+    EXPECT_EQ(mapper0_count, 1u);
+    EXPECT_EQ(mapper1_count, 0u);
+    output0.emplace_back(std::move(*mapper0.Front()));
+    mapper0.PopFront();
+
+    EXPECT_TRUE(mapper0.started());
+    EXPECT_EQ(mapper0_count, 1u);
+    EXPECT_EQ(mapper1_count, 0u);
+
+    // mapper0_count is now 3: timestamp_callback ran for the second message
+    // before replay_channels_callback rejected it, and then again for the
+    // third message, which Front() returns.
+    ASSERT_TRUE(mapper0.Front() != nullptr);
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 0u);
+    output0.emplace_back(std::move(*mapper0.Front()));
+    mapper0.PopFront();
+
+    EXPECT_TRUE(mapper0.started());
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 0u);
+
+    ASSERT_TRUE(mapper0.Front() == nullptr);
+    EXPECT_TRUE(mapper0.started());
+
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 0u);
+
+    EXPECT_EQ(output0[0].monotonic_event_time.boot, 0u);
+    EXPECT_EQ(output0[0].monotonic_event_time.time,
+              e + chrono::milliseconds(1000));
+    EXPECT_TRUE(output0[0].data != nullptr);
+
+    EXPECT_EQ(output0[1].monotonic_event_time.boot, 0u);
+    EXPECT_EQ(output0[1].monotonic_event_time.time,
+              e + chrono::milliseconds(3000));
+    EXPECT_TRUE(output0[1].data != nullptr);
+  }
+
+  {
+    SCOPED_TRACE("Trying node1 now");
+    std::deque<TimestampedMessage> output1;
+
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 0u);
+
+    ASSERT_TRUE(mapper1.Front() != nullptr);
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 1u);
+    output1.emplace_back(std::move(*mapper1.Front()));
+    mapper1.PopFront();
+    EXPECT_TRUE(mapper1.started());
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 1u);
+
+    // mapper1_count reaches 3 here: timestamp_callback runs for the second
+    // message before replay_channels_callback rejects it, and then again for
+    // the third message, which Front() returns.
+    ASSERT_TRUE(mapper1.Front() != nullptr);
+    output1.emplace_back(std::move(*mapper1.Front()));
+    mapper1.PopFront();
+    EXPECT_TRUE(mapper1.started());
+
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 3u);
+
+    ASSERT_TRUE(mapper1.Front() == nullptr);
+
+    EXPECT_EQ(mapper0_count, 3u);
+    EXPECT_EQ(mapper1_count, 3u);
+
+    EXPECT_EQ(output1[0].monotonic_event_time.boot, 0u);
+    EXPECT_EQ(output1[0].monotonic_event_time.time,
+              e + chrono::seconds(100) + chrono::milliseconds(1000));
+    EXPECT_TRUE(output1[0].data != nullptr);
+
+    EXPECT_EQ(output1[1].monotonic_event_time.boot, 0u);
+    EXPECT_EQ(output1[1].monotonic_event_time.time,
+              e + chrono::seconds(100) + chrono::milliseconds(3000));
+    EXPECT_TRUE(output1[1].data != nullptr);
+  }
+}
 // Tests that a MessageHeader with monotonic_timestamp_time set gets properly
 // returned.
 TEST_F(TimestampMapperTest, MessageWithTimestampTime) {
diff --git a/aos/events/logging/logger_test.cc b/aos/events/logging/logger_test.cc
index c99a7c8..18c8aa1 100644
--- a/aos/events/logging/logger_test.cc
+++ b/aos/events/logging/logger_test.cc
@@ -3708,7 +3708,6 @@
 
     event_loop_factory_.RunFor(chrono::milliseconds(95));
 
-
     StartLogger(&pi1_logger);
     aos::monotonic_clock::time_point last_rotation_time =
         pi1_logger.event_loop->monotonic_now();
diff --git a/aos/events/logging/realtime_replay_test.cc b/aos/events/logging/realtime_replay_test.cc
index 0cdf9fb..888b043 100644
--- a/aos/events/logging/realtime_replay_test.cc
+++ b/aos/events/logging/realtime_replay_test.cc
@@ -1,12 +1,15 @@
 #include "aos/events/logging/log_reader.h"
 #include "aos/events/logging/log_writer.h"
 #include "aos/events/ping_lib.h"
+#include "aos/events/pong_lib.h"
 #include "aos/events/shm_event_loop.h"
 #include "aos/json_to_flatbuffer.h"
 #include "aos/testing/path.h"
 #include "aos/testing/tmpdir.h"
 #include "gtest/gtest.h"
 
+DECLARE_string(override_hostname);
+
 namespace aos::logger::testing {
 
 class RealtimeLoggerTest : public ::testing::Test {
@@ -18,12 +21,54 @@
         config_(aos::configuration::ReadConfig(config_file_)),
         event_loop_factory_(&config_.message()),
         ping_event_loop_(event_loop_factory_.MakeEventLoop("ping")),
-        ping_(ping_event_loop_.get()) {
+        pong_event_loop_(event_loop_factory_.MakeEventLoop("pong")),
+        ping_(ping_event_loop_.get()),
+        pong_(pong_event_loop_.get()),
+        tmpdir_(aos::testing::TestTmpDir()),
+        base_name_(tmpdir_ + "/logfile/") {
     FLAGS_shm_base = shm_dir_;
 
-    // Nuke the shm dir, to ensure we aren't being affected by any preexisting
-    // tests.
+    // Nuke the shm and log dirs, to ensure we aren't being affected by any
+    // preexisting tests.
     aos::util::UnlinkRecursive(shm_dir_);
+    aos::util::UnlinkRecursive(base_name_);
+  }
+
+  gflags::FlagSaver flag_saver_;
+  std::string shm_dir_;
+
+  const std::string config_file_;
+  const aos::FlatbufferDetachedBuffer<aos::Configuration> config_;
+
+  // Factory and Ping class to generate a test logfile.
+  SimulatedEventLoopFactory event_loop_factory_;
+  std::unique_ptr<EventLoop> ping_event_loop_;
+  std::unique_ptr<EventLoop> pong_event_loop_;
+  Ping ping_;
+  Pong pong_;
+  const std::string tmpdir_;
+  const std::string base_name_;
+};
+
+class RealtimeMultiNodeLoggerTest : public ::testing::Test {
+ protected:
+  RealtimeMultiNodeLoggerTest()
+      : shm_dir_(aos::testing::TestTmpDir() + "/aos"),
+        config_file_(aos::testing::ArtifactPath(
+            "aos/events/logging/multinode_pingpong_combined_config.json")),
+        config_(aos::configuration::ReadConfig(config_file_)),
+        event_loop_factory_(&config_.message()),
+        ping_event_loop_(event_loop_factory_.MakeEventLoop(
+            "pi1", configuration::GetNode(&config_.message(), "pi1"))),
+        ping_(ping_event_loop_.get()),
+        tmpdir_(aos::testing::TestTmpDir()),
+        base_name_(tmpdir_ + "/logfile/") {
+    FLAGS_shm_base = shm_dir_;
+
+    // Nuke the shm and log dirs, to ensure we aren't being affected by any
+    // preexisting tests.
+    aos::util::UnlinkRecursive(shm_dir_);
+    aos::util::UnlinkRecursive(base_name_);
   }
 
   gflags::FlagSaver flag_saver_;
@@ -36,12 +81,11 @@
   SimulatedEventLoopFactory event_loop_factory_;
   std::unique_ptr<EventLoop> ping_event_loop_;
   Ping ping_;
+  const std::string tmpdir_;
+  const std::string base_name_;
 };
 
 TEST_F(RealtimeLoggerTest, RealtimeReplay) {
-  const std::string tmpdir = aos::testing::TestTmpDir();
-  const std::string base_name = tmpdir + "/logfile/";
-  aos::util::UnlinkRecursive(base_name);
   {
     std::unique_ptr<EventLoop> logger_event_loop =
         event_loop_factory_.MakeEventLoop("logger");
@@ -51,11 +95,11 @@
     Logger logger(logger_event_loop.get());
     logger.set_separate_config(false);
     logger.set_polling_period(std::chrono::milliseconds(100));
-    logger.StartLoggingOnRun(base_name);
+    logger.StartLoggingOnRun(base_name_);
     event_loop_factory_.RunFor(std::chrono::milliseconds(2000));
   }
 
-  LogReader reader(logger::SortParts(logger::FindLogs(base_name)));
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)));
   ShmEventLoop shm_event_loop(reader.configuration());
   reader.Register(&shm_event_loop);
   reader.OnEnd(shm_event_loop.node(),
@@ -73,4 +117,193 @@
   ASSERT_TRUE(ping_fetcher.Fetch());
   ASSERT_EQ(ping_fetcher->value(), 210);
 }
+
+namespace {
+
+// Runs the simulation for 95ms, then starts a logger on node "pi1" which
+// writes to base_name through a MultiNodeLogNamer, and runs the simulation
+// for another 2 seconds.
+void LogOnPi1(SimulatedEventLoopFactory *event_loop_factory,
+              const Configuration *config, const std::string &base_name) {
+  const Node *const pi1 = configuration::GetNode(config, "pi1");
+  std::unique_ptr<EventLoop> logger_event_loop =
+      event_loop_factory->MakeEventLoop("logger", pi1);
+
+  event_loop_factory->RunFor(std::chrono::milliseconds(95));
+
+  Logger logger(logger_event_loop.get());
+  logger.set_separate_config(false);
+  logger.set_polling_period(std::chrono::milliseconds(100));
+  logger.StartLogging(std::make_unique<MultiNodeLogNamer>(
+      base_name, config, logger_event_loop.get(), pi1));
+  event_loop_factory->RunFor(std::chrono::milliseconds(2000));
+}
+
+// Makes event_loop exit once `duration` has elapsed after it starts running.
+//
+// TODO(#21): LogReader::OnEnd() is not working as expected when replay
+// channels are used, so tests which filter channels rely on this timeout to
+// stop, which also gives any wrongly replayed message time to show up.
+void ExitAfter(ShmEventLoop *event_loop, std::chrono::seconds duration) {
+  auto *const exit_timer = event_loop->AddTimer([event_loop]() {
+    LOG(INFO) << "All done, quitting now";
+    event_loop->Exit();
+  });
+  event_loop->OnRun([event_loop, exit_timer, duration]() {
+    LOG(INFO) << "Quitting in: " << duration.count();
+    exit_timer->Setup(event_loop->monotonic_now() + duration);
+  });
+}
+
+}  // namespace
+
+// Tests that ReplayChannels causes no messages to be replayed other than what
+// is included, using a single node config.
+TEST_F(RealtimeLoggerTest, SingleNodeReplayChannels) {
+  {
+    std::unique_ptr<EventLoop> logger_event_loop =
+        event_loop_factory_.MakeEventLoop("logger");
+
+    event_loop_factory_.RunFor(std::chrono::milliseconds(95));
+
+    Logger logger(logger_event_loop.get());
+    logger.set_separate_config(false);
+    logger.set_polling_period(std::chrono::milliseconds(100));
+    logger.StartLoggingOnRun(base_name_);
+    event_loop_factory_.RunFor(std::chrono::milliseconds(2000));
+  }
+
+  ReplayChannels replay_channels{{"/test", "aos.examples.Ping"}};
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)),
+                   &config_.message(), &replay_channels);
+  ShmEventLoop shm_event_loop(reader.configuration());
+  reader.Register(&shm_event_loop);
+  reader.OnEnd(shm_event_loop.node(),
+               [&shm_event_loop]() { shm_event_loop.Exit(); });
+
+  Fetcher<examples::Ping> ping_fetcher =
+      shm_event_loop.MakeFetcher<examples::Ping>("/test");
+  Fetcher<examples::Pong> pong_fetcher =
+      shm_event_loop.MakeFetcher<examples::Pong>("/test");
+
+  shm_event_loop.AddTimer([]() { LOG(INFO) << "Hello, World!"; })
+      ->Setup(shm_event_loop.monotonic_now(), std::chrono::seconds(1));
+
+  // Keep running for 3 seconds so that any message which should have been
+  // filtered out has a chance to show up.
+  ExitAfter(&shm_event_loop, std::chrono::seconds(3));
+  shm_event_loop.Run();
+  reader.Deregister();
+
+  ASSERT_TRUE(ping_fetcher.Fetch());
+  ASSERT_EQ(ping_fetcher->value(), 210);
+  ASSERT_FALSE(pong_fetcher.Fetch());
+}
+
+// Tests that messages on a channel included in ReplayChannels are replayed
+// when using a multi node config.
+TEST_F(RealtimeMultiNodeLoggerTest, ReplayChannelsPingTest) {
+  FLAGS_override_hostname = "raspberrypi";
+  LogOnPi1(&event_loop_factory_, &config_.message(), base_name_);
+
+  ReplayChannels replay_channels{{"/test", "aos.examples.Ping"}};
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)),
+                   &config_.message(), &replay_channels);
+  ShmEventLoop shm_event_loop(reader.configuration());
+  reader.Register(&shm_event_loop);
+  reader.OnEnd(shm_event_loop.node(),
+               [&shm_event_loop]() { shm_event_loop.Exit(); });
+
+  Fetcher<examples::Ping> ping_fetcher =
+      shm_event_loop.MakeFetcher<examples::Ping>("/test");
+
+  shm_event_loop.AddTimer([]() { LOG(INFO) << "Hello, World!"; })
+      ->Setup(shm_event_loop.monotonic_now(), std::chrono::seconds(1));
+
+  shm_event_loop.Run();
+  reader.Deregister();
+
+  ASSERT_TRUE(ping_fetcher.Fetch());
+  ASSERT_EQ(ping_fetcher->value(), 210);
+}
+
+// Tests that when a channel included in ReplayChannels is remapped, its
+// messages are sent on the remapped channel rather than the original one.
+TEST_F(RealtimeMultiNodeLoggerTest, RemappedReplayChannelsTest) {
+  FLAGS_override_hostname = "raspberrypi";
+  LogOnPi1(&event_loop_factory_, &config_.message(), base_name_);
+
+  ReplayChannels replay_channels{{"/test", "aos.examples.Ping"}};
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)),
+                   &config_.message(), &replay_channels);
+  reader.RemapLoggedChannel<aos::examples::Ping>("/test", "/original");
+  ShmEventLoop shm_event_loop(reader.configuration());
+  reader.Register(&shm_event_loop);
+  reader.OnEnd(shm_event_loop.node(),
+               [&shm_event_loop]() { shm_event_loop.Exit(); });
+
+  Fetcher<examples::Ping> original_ping_fetcher =
+      shm_event_loop.MakeFetcher<examples::Ping>("/original/test");
+
+  Fetcher<examples::Ping> ping_fetcher =
+      shm_event_loop.MakeFetcher<examples::Ping>("/test");
+
+  shm_event_loop.AddTimer([]() { LOG(INFO) << "Hello, World!"; })
+      ->Setup(shm_event_loop.monotonic_now(), std::chrono::seconds(1));
+
+  shm_event_loop.Run();
+  reader.Deregister();
+
+  ASSERT_TRUE(original_ping_fetcher.Fetch());
+  ASSERT_EQ(original_ping_fetcher->value(), 210);
+  ASSERT_FALSE(ping_fetcher.Fetch());
+}
+
+// Tests that messages are not replayed on channels missing from the
+// ReplayChannels provided to LogReader. The logged "/test" Ping channel is not
+// listed. Of the listed channels, "/test" Pong has no logged messages (nothing
+// sends Pong in this test) and the other two are not channels in the config.
+TEST_F(RealtimeMultiNodeLoggerTest, DoesNotExistInReplayChannelsTest) {
+  FLAGS_override_hostname = "raspberrypi";
+  LogOnPi1(&event_loop_factory_, &config_.message(), base_name_);
+
+  ReplayChannels replay_channels{{"/test", "aos.examples.Pong"},
+                                 {"/test", "fake"},
+                                 {"fake", "aos.examples.Ping"}};
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)),
+                   &config_.message(), &replay_channels);
+  ShmEventLoop shm_event_loop(reader.configuration());
+  reader.Register(&shm_event_loop);
+  reader.OnEnd(shm_event_loop.node(),
+               [&shm_event_loop]() { shm_event_loop.Exit(); });
+
+  Fetcher<examples::Ping> ping_fetcher =
+      shm_event_loop.MakeFetcher<examples::Ping>("/test");
+
+  // Keep running for 3 seconds so that any message which should have been
+  // filtered out has a chance to show up.
+  ExitAfter(&shm_event_loop, std::chrono::seconds(3));
+  shm_event_loop.Run();
+  reader.Deregister();
+  ASSERT_FALSE(ping_fetcher.Fetch());
+}
+
+using RealtimeMultiNodeLoggerDeathTest = RealtimeMultiNodeLoggerTest;
+
+// Tests that remapping a channel not included in the replay channels passed to
+// LogReader is fatal, since it would indicate the user is trying to use the
+// channel being remapped.
+TEST_F(RealtimeMultiNodeLoggerDeathTest,
+       RemapLoggedChannelNotIncludedInReplayChannels) {
+  FLAGS_override_hostname = "raspberrypi";
+  LogOnPi1(&event_loop_factory_, &config_.message(), base_name_);
+
+  ReplayChannels replay_channels{{"/test", "aos.examples.Ping"}};
+  LogReader reader(logger::SortParts(logger::FindLogs(base_name_)),
+                   &config_.message(), &replay_channels);
+  EXPECT_DEATH(
+      reader.RemapLoggedChannel<aos::examples::Ping>("/fake", "/original"),
+      "which is not included in the replay channels passed to LogReader");
+}
+
 }  // namespace aos::logger::testing