mirror of
https://gitlab.com/nkming2/nc-photos.git
synced 2025-03-23 07:29:23 +01:00
Optimize deeplab segment map to include only the center segment
This commit is contained in:
parent
104025d424
commit
f4aeeb2c3f
2 changed files with 150 additions and 11 deletions
|
@ -9,6 +9,8 @@
|
||||||
#include <android/asset_manager.h>
|
#include <android/asset_manager.h>
|
||||||
#include <android/asset_manager_jni.h>
|
#include <android/asset_manager_jni.h>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <climits>
|
||||||
|
#include <deque>
|
||||||
#include <exception>
|
#include <exception>
|
||||||
#include <jni.h>
|
#include <jni.h>
|
||||||
#include <tensorflow/lite/c/c_api.h>
|
#include <tensorflow/lite/c/c_api.h>
|
||||||
|
@ -26,7 +28,6 @@ constexpr const char *MODEL = "tf/lite-model_mobilenetv2-dm05-coco_dr_1.tflite";
|
||||||
constexpr size_t WIDTH = 513;
|
constexpr size_t WIDTH = 513;
|
||||||
constexpr size_t HEIGHT = 513;
|
constexpr size_t HEIGHT = 513;
|
||||||
constexpr unsigned LABEL_COUNT = 21;
|
constexpr unsigned LABEL_COUNT = 21;
|
||||||
constexpr const char *TAG = "deep_lap_3";
|
|
||||||
|
|
||||||
enum struct Label {
|
enum struct Label {
|
||||||
BACKGROUND = 0,
|
BACKGROUND = 0,
|
||||||
|
@ -110,10 +111,41 @@ private:
|
||||||
* 1. Contain only the most significant label (the one with the most pixels)
|
* 1. Contain only the most significant label (the one with the most pixels)
|
||||||
* 2. The label value set to 255
|
* 2. The label value set to 255
|
||||||
* 3. The background set to 0
|
* 3. The background set to 0
|
||||||
*
|
|
||||||
* @param segmentMap
|
|
||||||
*/
|
*/
|
||||||
void postProcessSegmentMap(std::vector<uint8_t> *segmentMap);
|
class SegmentMapPostProcessor {
|
||||||
|
public:
|
||||||
|
explicit SegmentMapPostProcessor(vector<uint8_t> *segmentMap)
|
||||||
|
: segmentMapRef(*segmentMap) {}
|
||||||
|
|
||||||
|
void operator()();
|
||||||
|
|
||||||
|
private:
|
||||||
|
/**
|
||||||
|
* Set the most significant segment to kMostSignificant and all others to 0
|
||||||
|
*
|
||||||
|
* @return true if successful, false if no segments were found
|
||||||
|
*/
|
||||||
|
bool keepMostSignificantSegments();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find a point with the value @a value nearest to the center
|
||||||
|
*
|
||||||
|
* @param value
|
||||||
|
* @return Closest point, or Coord(INT_MAX, INT_MAX) if not found
|
||||||
|
*/
|
||||||
|
Coord findNearestPointToCenter(const uint8_t value) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set connecting points from @a from to @a to
|
||||||
|
*/
|
||||||
|
void flood(const int atX, const int atY, const uint8_t from,
|
||||||
|
const uint8_t to);
|
||||||
|
|
||||||
|
vector<uint8_t> &segmentMapRef;
|
||||||
|
|
||||||
|
static constexpr const char *TAG = "SegmentMapPostProcessor";
|
||||||
|
static constexpr int kMostSignificant = 0xFE;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
@ -210,7 +242,7 @@ vector<uint8_t> DeepLab3Portrait::infer(const uint8_t *image,
|
||||||
const size_t width, const size_t height,
|
const size_t width, const size_t height,
|
||||||
const unsigned radius) {
|
const unsigned radius) {
|
||||||
auto segmentMap = deepLab.infer(image, width, height);
|
auto segmentMap = deepLab.infer(image, width, height);
|
||||||
postProcessSegmentMap(&segmentMap);
|
SegmentMapPostProcessor{&segmentMap}();
|
||||||
return enhance(image, width, height, segmentMap, radius);
|
return enhance(image, width, height, segmentMap, radius);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +282,7 @@ vector<uint8_t> DeepLab3ColorPop::infer(const uint8_t *image,
|
||||||
const size_t width, const size_t height,
|
const size_t width, const size_t height,
|
||||||
const float weight) {
|
const float weight) {
|
||||||
auto segmentMap = deepLab.infer(image, width, height);
|
auto segmentMap = deepLab.infer(image, width, height);
|
||||||
postProcessSegmentMap(&segmentMap);
|
SegmentMapPostProcessor{&segmentMap}();
|
||||||
return enhance(image, width, height, segmentMap, weight);
|
return enhance(image, width, height, segmentMap, weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -284,29 +316,128 @@ vector<uint8_t> DeepLab3ColorPop::enhance(const uint8_t *image,
|
||||||
return rgba8ToRgb8(desaturate.data(), width, height);
|
return rgba8ToRgb8(desaturate.data(), width, height);
|
||||||
}
|
}
|
||||||
|
|
||||||
void postProcessSegmentMap(vector<uint8_t> *segmentMap) {
|
/**
 * Reduce the segment map to the single segment closest to the center
 *
 * The map is first reduced to the most significant label (marked with
 * kMostSignificant). The marked region nearest to the center is then promoted
 * to 0xFF and any marked pixels left over are cleared to 0. Nothing further is
 * done when no segment was kept or no marked pixel could be located.
 */
void SegmentMapPostProcessor::operator()() {
  const bool hasSegment = keepMostSignificantSegments();
  if (hasSegment) {
    const auto seed = findNearestPointToCenter(kMostSignificant);
    // Coord(INT_MAX, INT_MAX) is the "not found" sentinel
    const bool isMissing = (seed.x == INT_MAX && seed.y == INT_MAX);
    if (!isMissing) {
      flood(seed.x, seed.y, kMostSignificant, 0xFF);
      // whatever still carries the marker was not connected to the seed
      for (auto &pixel : segmentMapRef) {
        if (pixel == kMostSignificant) {
          pixel = 0;
        }
      }
    }
  }
}
|
||||||
|
|
||||||
|
bool SegmentMapPostProcessor::keepMostSignificantSegments() {
|
||||||
// keep only the largest segment
|
// keep only the largest segment
|
||||||
vector<uint8_t> &segmentMapRef = *segmentMap;
|
|
||||||
vector<int> count(LABEL_COUNT);
|
vector<int> count(LABEL_COUNT);
|
||||||
for (size_t i = 0; i < segmentMapRef.size(); ++i) {
|
for (size_t i = 0; i < segmentMapRef.size(); ++i) {
|
||||||
assert(segmentMapRef[i] < LABEL_COUNT);
|
assert(segmentMapRef[i] < LABEL_COUNT);
|
||||||
const auto label = std::min<unsigned>(segmentMapRef[i], LABEL_COUNT);
|
const auto label = std::min<unsigned>(segmentMapRef[i], LABEL_COUNT - 1);
|
||||||
if (label != static_cast<int>(Label::BACKGROUND)) {
|
if (label != static_cast<int>(Label::BACKGROUND)) {
|
||||||
++count[label];
|
++count[label];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const auto keep = distance(
|
const auto keep = distance(
|
||||||
count.data(), max_element(count.data(), count.data() + count.size()));
|
count.data(), max_element(count.data(), count.data() + count.size()));
|
||||||
LOGI(TAG, "[postProcessSegmentMap] Label to keep: %d",
|
LOGI(TAG, "[keepMostSignificantSegments] Label to keep: %d",
|
||||||
static_cast<int>(keep));
|
static_cast<int>(keep));
|
||||||
|
if (keep == static_cast<int>(Label::BACKGROUND)) {
|
||||||
|
// no segment found, keep all
|
||||||
|
std::fill(segmentMapRef.begin(), segmentMapRef.end(), 0xFF);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
for (size_t i = 0; i < segmentMapRef.size(); ++i) {
|
for (size_t i = 0; i < segmentMapRef.size(); ++i) {
|
||||||
if (segmentMapRef[i] == keep) {
|
if (segmentMapRef[i] == keep) {
|
||||||
segmentMapRef[i] = 0xFF;
|
segmentMapRef[i] = kMostSignificant;
|
||||||
} else {
|
} else {
|
||||||
segmentMapRef[i] = 0;
|
segmentMapRef[i] = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Find a point with the value @a value nearest to the center
 *
 * The search is a breadth-first expansion from (WIDTH / 2, HEIGHT / 2) over
 * the 8 surrounding pixels of each visited pixel, so "nearest" is measured in
 * rings around the center rather than by Euclidean distance.
 *
 * @param value Pixel value to look for
 * @return Closest point, or Coord(INT_MAX, INT_MAX) if not found (including
 * when the map holds fewer than WIDTH * HEIGHT pixels)
 */
Coord SegmentMapPostProcessor::findNearestPointToCenter(
    const uint8_t value) const {
  LOGI(TAG, "[findNearestPointToCenter] Find nearest point of: 0x%X", value);
  constexpr int kWidth = static_cast<int>(WIDTH);
  constexpr int kHeight = static_cast<int>(HEIGHT);
  const auto indexOf = [](const int x, const int y) {
    return static_cast<size_t>(y) * WIDTH + static_cast<size_t>(x);
  };

  // never index past the end of an undersized map
  if (segmentMapRef.size() >= WIDTH * HEIGHT) {
    // marked when a pixel is queued (not when it is popped) so every pixel
    // enters the queue at most once
    vector<uint8_t> queued(WIDTH * HEIGHT);
    deque<Coord> checks;
    checks.push_back(Coord(kWidth / 2, kHeight / 2));
    queued[indexOf(kWidth / 2, kHeight / 2)] = 1;
    while (!checks.empty()) {
      // copy the element: pop_front() destroys it, so a reference obtained
      // from front() would dangle
      const Coord c = checks.front();
      checks.pop_front();

      if (segmentMapRef[indexOf(c.x, c.y)] == value) {
        LOGI(TAG, "[findNearestPointToCenter] Found: (%d, %d)", c.x, c.y);
        return c;
      }
      // check surroundings
      for (int dy = -1; dy <= 1; ++dy) {
        const int ny = c.y + dy;
        if (ny < 0 || ny >= kHeight) {
          continue;
        }
        for (int dx = -1; dx <= 1; ++dx) {
          const int nx = c.x + dx;
          if (nx < 0 || nx >= kWidth) {
            continue;
          }
          // the center pixel itself is already flagged, no special case needed
          const auto ni = indexOf(nx, ny);
          if (!queued[ni]) {
            queued[ni] = 1;
            checks.push_back(Coord(nx, ny));
          }
        }
      }
    }
  }
  // no results
  LOGI(TAG, "[findNearestPointToCenter] Not found");
  return Coord(INT_MAX, INT_MAX);
}
|
||||||
|
|
||||||
|
void SegmentMapPostProcessor::flood(const int atX, const int atY,
|
||||||
|
const uint8_t from, const uint8_t to) {
|
||||||
|
LOGI(TAG, "[flood] At: (%d, %d), 0x%X -> 0x%X", atX, atY, from, to);
|
||||||
|
deque<Coord> checks;
|
||||||
|
checks.push_back(Coord(atX, atY));
|
||||||
|
while (!checks.empty()) {
|
||||||
|
const auto &c = checks.front();
|
||||||
|
checks.pop_front();
|
||||||
|
|
||||||
|
const auto i = c.y * WIDTH + c.x;
|
||||||
|
if (segmentMapRef[i] == from) {
|
||||||
|
segmentMapRef[i] = to;
|
||||||
|
// check surroundings
|
||||||
|
for (int dy = -1; dy <= 1; ++dy) {
|
||||||
|
if (c.y + dy < 0 || c.y + dy >= HEIGHT) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int dx = -1; dx <= 1 && c.x + dx >= 0 && c.x + dx < WIDTH; ++dx) {
|
||||||
|
if (c.x + dx < 0 || c.x + dx >= WIDTH) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (dx == 0 && dy == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (segmentMapRef[(c.y + dy) * WIDTH + (c.x + dx)] != to) {
|
||||||
|
checks.push_back(Coord(c.x + dx, c.y + dy));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -57,6 +57,14 @@ private:
|
||||||
AAsset *asset = nullptr;
|
AAsset *asset = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * Integer 2D coordinate
 *
 * The members are intentionally not const: const data members delete the
 * copy/move assignment operators, which prevents reassigning a Coord and
 * using it with container operations that assign elements.
 */
struct Coord {
  /** Construct the origin (0, 0) */
  Coord() : Coord(0, 0) {}
  /** Construct the point (@a x, @a y) */
  Coord(const int x, const int y) : x(x), y(y) {}

  int x;
  int y;
};
|
||||||
|
|
||||||
void initOpenMp();
|
void initOpenMp();
|
||||||
int getNumberOfProcessors();
|
int getNumberOfProcessors();
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue