From f4aeeb2c3fe08495908c2aa9a6166c1be3846a5e Mon Sep 17 00:00:00 2001
From: Ming Ming
Date: Mon, 12 Sep 2022 11:34:52 +0800
Subject: [PATCH] Optimize deeplab segment map to include only the center segment

---
 plugin/android/src/main/cpp/deep_lap_3.cpp | 153 +++++++++++++++++++--
 plugin/android/src/main/cpp/util.h         |   8 ++
 2 files changed, 150 insertions(+), 11 deletions(-)

diff --git a/plugin/android/src/main/cpp/deep_lap_3.cpp b/plugin/android/src/main/cpp/deep_lap_3.cpp
index 31906475..dc4fe944 100644
--- a/plugin/android/src/main/cpp/deep_lap_3.cpp
+++ b/plugin/android/src/main/cpp/deep_lap_3.cpp
@@ -9,6 +9,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -26,7 +28,6 @@ constexpr const char *MODEL = "tf/lite-model_mobilenetv2-dm05-coco_dr_1.tflite";
 constexpr size_t WIDTH = 513;
 constexpr size_t HEIGHT = 513;
 constexpr unsigned LABEL_COUNT = 21;
-constexpr const char *TAG = "deep_lap_3";
 
 enum struct Label {
   BACKGROUND = 0,
@@ -110,10 +111,41 @@ private:
  * 1. Contain only the most significant label (the one with the most pixel)
  * 2. The label value set to 255
  * 3. The background set to 0
- *
- * @param segmentMap
  */
-void postProcessSegmentMap(std::vector *segmentMap);
+class SegmentMapPostProcessor {
+public:
+  explicit SegmentMapPostProcessor(vector *segmentMap)
+      : segmentMapRef(*segmentMap) {}
+
+  void operator()();
+
+private:
+  /**
+   * Set the most significant segment to kMostSignificant and all others to 0
+   *
+   * @return true if successful, false if no segments were found
+   */
+  bool keepMostSignificantSegments();
+
+  /**
+   * Find a point with the value @a value nearest to the center
+   *
+   * @param value
+   * @return Closest point, or Coord(INT_MAX, INT_MAX) if not found
+   */
+  Coord findNearestPointToCenter(const uint8_t value) const;
+
+  /**
+   * Set connecting points from @a from to @a to
+   */
+  void flood(const int atX, const int atY, const uint8_t from,
+             const uint8_t to);
+
+  vector &segmentMapRef;
+
+  static constexpr const char *TAG = "SegmentMapPostProcessor";
+  static constexpr int kMostSignificant = 0xFE;
+};
 
 } // namespace
 
@@ -210,7 +242,7 @@ vector DeepLab3Portrait::infer(const uint8_t *image,
                                const size_t width, const size_t height,
                                const unsigned radius) {
   auto segmentMap = deepLab.infer(image, width, height);
-  postProcessSegmentMap(&segmentMap);
+  SegmentMapPostProcessor{&segmentMap}();
   return enhance(image, width, height, segmentMap, radius);
 }
 
@@ -250,7 +282,7 @@ vector DeepLab3ColorPop::infer(const uint8_t *image,
                                const size_t width, const size_t height,
                                const float weight) {
   auto segmentMap = deepLab.infer(image, width, height);
-  postProcessSegmentMap(&segmentMap);
+  SegmentMapPostProcessor{&segmentMap}();
   return enhance(image, width, height, segmentMap, weight);
 }
 
@@ -284,29 +316,128 @@ vector DeepLab3ColorPop::enhance(const uint8_t *image,
   return rgba8ToRgb8(desaturate.data(), width, height);
 }
 
-void postProcessSegmentMap(vector *segmentMap) {
+void SegmentMapPostProcessor::operator()() {
+  if (!keepMostSignificantSegments()) {
+    return;
+  }
+  const auto pt = findNearestPointToCenter(kMostSignificant);
+  if (pt.x == INT_MAX && pt.y == INT_MAX) {
+    // no segment?
+    return;
+  }
+  flood(pt.x, pt.y, kMostSignificant, 0xFF);
+  for (size_t i = 0; i < segmentMapRef.size(); ++i) {
+    if (segmentMapRef[i] == kMostSignificant) {
+      segmentMapRef[i] = 0;
+    }
+  }
+}
+
+bool SegmentMapPostProcessor::keepMostSignificantSegments() {
   // keep only the largest segment
-  vector &segmentMapRef = *segmentMap;
   vector count(LABEL_COUNT);
   for (size_t i = 0; i < segmentMapRef.size(); ++i) {
     assert(segmentMapRef[i] < LABEL_COUNT);
-    const auto label = std::min(segmentMapRef[i], LABEL_COUNT);
+    const auto label = std::min(segmentMapRef[i], LABEL_COUNT - 1);
     if (label != static_cast(Label::BACKGROUND)) {
       ++count[label];
     }
   }
   const auto keep = distance(
       count.data(), max_element(count.data(), count.data() + count.size()));
-  LOGI(TAG, "[postProcessSegmentMap] Label to keep: %d",
+  LOGI(TAG, "[keepMostSignificantSegments] Label to keep: %d",
        static_cast(keep));
+  if (keep == static_cast(Label::BACKGROUND)) {
+    // no segment found, keep all
+    std::fill(segmentMapRef.begin(), segmentMapRef.end(), 0xFF);
+    return false;
+  }
+
 #pragma omp parallel for
   for (size_t i = 0; i < segmentMapRef.size(); ++i) {
     if (segmentMapRef[i] == keep) {
-      segmentMapRef[i] = 0xFF;
+      segmentMapRef[i] = kMostSignificant;
     } else {
       segmentMapRef[i] = 0;
     }
   }
+  return true;
+}
+
+Coord SegmentMapPostProcessor::findNearestPointToCenter(
+    const uint8_t value) const {
+  LOGI(TAG, "[findNearestPointToCenter] Find nearest point of: 0x%X", value);
+  deque checks;
+  vector done(WIDTH * HEIGHT);
+  checks.push_back(Coord(WIDTH / 2, HEIGHT / 2));
+  while (!checks.empty()) {
+    const auto c = checks.front(); // copy: pop_front() destroys the element
+    checks.pop_front();
+
+    const auto i = c.y * WIDTH + c.x;
+    if (done[i]) {
+      continue;
+    }
+    done[i] = true;
+    if (segmentMapRef[i] == value) {
+      LOGI(TAG, "[findNearestPointToCenter] Found: (%d, %d)", c.x, c.y);
+      return c;
+    } else {
+      // check surroundings
+      for (int dy = -1; dy <= 1; ++dy) {
+        if (c.y + dy < 0 || c.y + dy >= HEIGHT) {
+          continue;
+        }
+        for (int dx = -1; dx <= 1; ++dx) {
+          if (c.x + dx < 0 || c.x + dx >= WIDTH) {
+            continue;
+          }
+          if (dx == 0 && dy == 0) {
+            continue;
+          }
+          if (!done[(c.y + dy) * WIDTH + (c.x + dx)]) {
+            checks.push_back(Coord(c.x + dx, c.y + dy));
+          }
+        }
+      }
+    }
+  }
+  // no results
+  LOGI(TAG, "[findNearestPointToCenter] Not found");
+  return Coord(INT_MAX, INT_MAX);
+}
+
+void SegmentMapPostProcessor::flood(const int atX, const int atY,
+                                    const uint8_t from, const uint8_t to) {
+  LOGI(TAG, "[flood] At: (%d, %d), 0x%X -> 0x%X", atX, atY, from, to);
+  deque checks;
+  checks.push_back(Coord(atX, atY));
+  while (!checks.empty()) {
+    const auto c = checks.front(); // copy: pop_front() destroys the element
+    checks.pop_front();
+
+    const auto i = c.y * WIDTH + c.x;
+    if (segmentMapRef[i] == from) {
+      segmentMapRef[i] = to;
+      // check surroundings
+      for (int dy = -1; dy <= 1; ++dy) {
+        if (c.y + dy < 0 || c.y + dy >= HEIGHT) {
+          continue;
+        }
+        for (int dx = -1; dx <= 1; ++dx) {
+          if (c.x + dx < 0 || c.x + dx >= WIDTH) { // skip, don't end the scan
+            continue;
+          }
+          if (dx == 0 && dy == 0) {
+            continue;
+          }
+          if (segmentMapRef[(c.y + dy) * WIDTH + (c.x + dx)] != to) {
+            checks.push_back(Coord(c.x + dx, c.y + dy));
+          }
+        }
+      }
+    }
+  }
 }
 
 } // namespace
diff --git a/plugin/android/src/main/cpp/util.h b/plugin/android/src/main/cpp/util.h
index 7ae8cb8f..7e5c4f4d 100644
--- a/plugin/android/src/main/cpp/util.h
+++ b/plugin/android/src/main/cpp/util.h
@@ -57,6 +57,14 @@ private:
   AAsset *asset = nullptr;
 };
 
+struct Coord {
+  Coord() : Coord(0, 0) {}
+  Coord(const int x, const int y) : x(x), y(y) {}
+
+  const int x;
+  const int y;
+};
+
 void initOpenMp();
 
 int getNumberOfProcessors();