Work around TensorFlow FlatBuffers issues
TensorFlow uses an old version of FlatBuffers that
doesn't work with our FlatBuffers version. This patch
isolates TensorFlow's FlatBuffers from the rest of
the codebase.
Signed-off-by: Ravago Jones <ravagojones@gmail.com>
Signed-off-by: Filip Kujawa <filip.j.kujawa@gmail.com>
Change-Id: I65087decc156cce301d5ce0f360571400c68ef21
diff --git a/y2023/vision/yolov5.cc b/y2023/vision/yolov5.cc
index df03d12..7f5aa2a 100644
--- a/y2023/vision/yolov5.cc
+++ b/y2023/vision/yolov5.cc
@@ -1,5 +1,10 @@
#include "yolov5.h"
+#include <tensorflow/lite/interpreter.h>
+#include <tensorflow/lite/kernels/register.h>
+#include <tensorflow/lite/model.h>
+#include <tflite/public/edgetpu_c.h>
+
#include <opencv2/core.hpp>
#include "gflags/gflags.h"
@@ -19,7 +24,63 @@
namespace y2023 {
namespace vision {
-void YOLOV5::LoadModel(const std::string path) {
+class YOLOV5Impl : public YOLOV5 {
+ public:
+  // Takes a model path as a string and loads a pre-trained
+  // YOLOv5 model from the specified path.
+  void LoadModel(const std::string path);
+
+  // Takes an image and returns a vector of Detections.
+  std::vector<Detection> ProcessImage(cv::Mat image);
+
+ private:
+  // Convert an OpenCV Mat object to a tensor input
+  // that can be fed to the TensorFlow Lite model.
+  void ConvertCVMatToTensor(const cv::Mat &src, uint8_t *in);
+
+  // Resizes, converts color space, and converts
+  // image data type before inference.
+  void Preprocess(cv::Mat image);
+
+  // Converts a TensorFlow Lite tensor to a 2D vector.
+  std::vector<std::vector<float>> TensorToVector2D(TfLiteTensor *src_tensor,
+                                                   const int rows,
+                                                   const int columns);
+
+  // Performs non-maximum suppression to remove overlapping bounding boxes.
+  void NonMaximumSupression(const std::vector<std::vector<float>> &orig_preds,
+                            const int rows, const int columns,
+                            std::vector<Detection> *detections,
+                            std::vector<int> *indices);
+  // TFLite model, the interpreter built for it, and the error reporter.
+  std::unique_ptr<tflite::FlatBufferModel> model_;
+  std::unique_ptr<tflite::Interpreter> interpreter_;
+  tflite::StderrReporter error_reporter_;
+
+  // Parameters of interpreter's input (explicitly zero-initialized).
+  int input_ = 0;
+  int in_height_ = 0;
+  int in_width_ = 0;
+  int in_channels_ = 0;
+  int in_type_ = 0;
+
+  // Parameters of original image (explicitly zero-initialized).
+  int img_height_ = 0;
+  int img_width_ = 0;
+
+  // Input of the interpreter; null until assigned.
+  uint8_t *input_8_ = nullptr;
+
+  // Subtract this offset from class labels to get the actual label.
+  static constexpr int kClassIdOffset = 5;
+};
+
+std::unique_ptr<YOLOV5> MakeYOLOV5() {
+  // Returns the concrete YOLOV5Impl behind the YOLOV5 interface type.
+  return std::make_unique<YOLOV5Impl>();
+}
+
+void YOLOV5Impl::LoadModel(const std::string path) {
model_ = tflite::FlatBufferModel::BuildFromFile(path.c_str());
CHECK(model_);
size_t num_devices;
@@ -49,13 +110,13 @@
interpreter_->SetNumThreads(FLAGS_nthreads);
}
-void YOLOV5::Preprocess(cv::Mat image) {
+void YOLOV5Impl::Preprocess(cv::Mat image) {
cv::resize(image, image, cv::Size(in_height_, in_width_), cv::INTER_CUBIC);
cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
image.convertTo(image, CV_8U);
}
-void YOLOV5::ConvertCVMatToTensor(const cv::Mat &src, uint8_t *in) {
+void YOLOV5Impl::ConvertCVMatToTensor(const cv::Mat &src, uint8_t *in) {
CHECK(src.type() == CV_8UC3);
int n = 0, nc = src.channels(), ne = src.elemSize();
for (int y = 0; y < src.rows; ++y)
@@ -64,7 +125,7 @@
in[n++] = src.data[y * src.step + x * ne + c];
}
-std::vector<std::vector<float>> YOLOV5::TensorToVector2D(
+std::vector<std::vector<float>> YOLOV5Impl::TensorToVector2D(
TfLiteTensor *src_tensor, const int rows, const int columns) {
auto scale = src_tensor->params.scale;
auto zero_point = src_tensor->params.zero_point;
@@ -83,7 +144,7 @@
return result_vec;
}
-void YOLOV5::NonMaximumSupression(
+void YOLOV5Impl::NonMaximumSupression(
const std::vector<std::vector<float>> &orig_preds, const int rows,
const int columns, std::vector<Detection> *detections,
std::vector<int> *indices)
@@ -125,7 +186,7 @@
FLAGS_nms_threshold, *indices);
}
-std::vector<Detection> YOLOV5::ProcessImage(cv::Mat frame) {
+std::vector<Detection> YOLOV5Impl::ProcessImage(cv::Mat frame) {
img_height_ = frame.rows;
img_width_ = frame.cols;