// Mirror of https://gitlab.com/nkming2/nc-photos.git
// synced 2025-03-25
#include "exception.h"
|
|
#include "lib/base_resample.h"
|
|
#include "log.h"
|
|
#include "stopwatch.h"
|
|
#include "tflite_wrapper.h"
|
|
#include "util.h"
|
|
#include <RenderScriptToolkit.h>
|
|
#include <algorithm>
|
|
#include <android/asset_manager.h>
|
|
#include <android/asset_manager_jni.h>
|
|
#include <cassert>
|
|
#include <exception>
|
|
#include <jni.h>
|
|
#include <tensorflow/lite/c/c_api.h>
|
|
|
|
#include "./filter/saturation.h"
|
|
|
|
using namespace plugin;
|
|
using namespace renderscript;
|
|
using namespace std;
|
|
using namespace tflite;
|
|
|
|
namespace {
|
|
|
|
// Asset path of the DeepLab v3 (MobileNetV2 backbone) TFLite model
constexpr const char *MODEL = "tf/lite-model_mobilenetv2-dm05-coco_dr_1.tflite";
// Model input/output spatial dimensions; images are resampled to this size
constexpr size_t WIDTH = 513;
constexpr size_t HEIGHT = 513;
// Number of output classes (background + 20 object labels, see Label below)
constexpr unsigned LABEL_COUNT = 21;
// Log tag for this translation unit
constexpr const char *TAG = "deep_lap_3";
|
|
|
|
// Segmentation classes produced by the model; the numeric value of each
// enumerator matches its channel index in the model's output tensor
enum struct Label {
	BACKGROUND = 0,
	AEROPLANE,
	BICYCLE,
	BIRD,
	BOAT,
	BOTTLE,
	BUS,
	CAR,
	CAT,
	CHAIR,
	COW,
	DINING_TABLE,
	DOG,
	HORSE,
	MOTORBIKE,
	PERSON,
	POTTED_PLANT,
	SHEEP,
	SOFA,
	TRAIN,
	TV,
};
|
|
|
|
/**
 * Wrapper around the DeepLab v3 TFLite model for semantic segmentation.
 *
 * Move-only: owns the loaded model.
 */
class DeepLab3 {
public:
	/**
	 * @param aam Asset manager used to load the model file (MODEL)
	 */
	explicit DeepLab3(AAssetManager *const aam);
	DeepLab3(const DeepLab3 &) = delete;
	DeepLab3(DeepLab3 &&) = default;

	/**
	 * Run segmentation on an RGB8 image.
	 *
	 * @param image RGB8 pixel data
	 * @param width Width of @a image
	 * @param height Height of @a image
	 * @return Per-pixel label map of WIDTH * HEIGHT entries (model
	 * resolution, not the input resolution)
	 */
	std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
			const size_t height);

private:
	// The loaded TFLite model
	Model model;

	static constexpr const char *TAG = "DeepLap3";
};
|
|
|
|
/**
 * Portrait effect: keep the segmented subject sharp and blur the rest of the
 * image.
 */
class DeepLab3Portrait {
public:
	/**
	 * @param deepLab Segmentation model; ownership is taken
	 */
	explicit DeepLab3Portrait(DeepLab3 &&deepLab);

	/**
	 * Apply the portrait effect.
	 *
	 * @param image RGB8 pixel data
	 * @param width Width of @a image
	 * @param height Height of @a image
	 * @param radius Blur radius applied to the background
	 * @return RGB8 result at the input resolution
	 */
	std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
			const size_t height, const unsigned radius);

private:
	// Composite the sharp subject over a blurred copy of the image
	std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
			const size_t height,
			const std::vector<uint8_t> &segmentMap,
			const unsigned radius);

	DeepLab3 deepLab;

	static constexpr const char *TAG = "DeepLab3Portrait";
};
|
|
|
|
/**
 * Color-pop effect: keep the segmented subject in color and desaturate the
 * rest of the image.
 */
class DeepLab3ColorPop {
public:
	/**
	 * @param deepLab Segmentation model; ownership is taken
	 */
	explicit DeepLab3ColorPop(DeepLab3 &&deepLab);

	/**
	 * Apply the color-pop effect.
	 *
	 * @param image RGB8 pixel data
	 * @param width Width of @a image
	 * @param height Height of @a image
	 * @param weight Strength of the desaturation applied to the background
	 * @return RGB8 result at the input resolution
	 */
	std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
			const size_t height, const float weight);

private:
	// Composite the colored subject over a desaturated copy of the image
	std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
			const size_t height,
			const std::vector<uint8_t> &segmentMap,
			const float weight);

	DeepLab3 deepLab;

	static constexpr const char *TAG = "DeepLab3ColorPop";
};
|
|
|
|
/**
 * Post-process the segment map in place.
 *
 * The resulting segment map will:
 * 1. Contain only the most significant label (the one covering the most
 *    pixels)
 * 2. Have the kept label's pixels set to 255
 * 3. Have every other pixel (including the background) set to 0
 *
 * @param segmentMap The segment map to modify
 */
void postProcessSegmentMap(std::vector<uint8_t> *segmentMap);
|
|
|
|
} // namespace
|
|
|
|
/**
 * JNI entry point for DeepLab3Portrait.inferNative.
 *
 * @param image RGB8 pixel data as a Java byte array
 * @param radius Background blur radius
 * @return RGB8 result as a new Java byte array, or null after throwing a
 * Java exception on failure
 */
extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3Portrait_inferNative(
		// NOTE: the JNI receiver parameter is `jobject`, not `jobject *`;
		// the original declaration only worked because jobject is itself a
		// pointer type
		JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
		jint width, jint height, jint radius) {
	try {
		initOpenMp();
		auto aam = AAssetManager_fromJava(env, assetManager);
		DeepLab3Portrait model(DeepLab3{aam});
		// Pin the Java array; JNI_ABORT on release since we never write back
		RaiiContainer<jbyte> cImage(
				[&]() { return env->GetByteArrayElements(image, nullptr); },
				[&](jbyte *obj) {
					env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
				});
		const auto result = model.infer(reinterpret_cast<uint8_t *>(cImage.get()),
				width, height, radius);
		auto resultAry = env->NewByteArray(result.size());
		env->SetByteArrayRegion(resultAry, 0, result.size(),
				reinterpret_cast<const int8_t *>(result.data()));
		return resultAry;
	} catch (const exception &e) {
		// Re-throw into the Java layer instead of crashing the process
		throwJavaException(env, e.what());
		return nullptr;
	}
}
|
|
|
|
/**
 * JNI entry point for DeepLab3ColorPop.inferNative.
 *
 * @param image RGB8 pixel data as a Java byte array
 * @param weight Desaturation strength for the background
 * @return RGB8 result as a new Java byte array, or null after throwing a
 * Java exception on failure
 */
extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3ColorPop_inferNative(
		// NOTE: the JNI receiver parameter is `jobject`, not `jobject *`;
		// the original declaration only worked because jobject is itself a
		// pointer type
		JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
		jint width, jint height, jfloat weight) {
	try {
		initOpenMp();
		auto aam = AAssetManager_fromJava(env, assetManager);
		DeepLab3ColorPop model(DeepLab3{aam});
		// Pin the Java array; JNI_ABORT on release since we never write back
		RaiiContainer<jbyte> cImage(
				[&]() { return env->GetByteArrayElements(image, nullptr); },
				[&](jbyte *obj) {
					env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
				});
		const auto result = model.infer(reinterpret_cast<uint8_t *>(cImage.get()),
				width, height, weight);
		auto resultAry = env->NewByteArray(result.size());
		env->SetByteArrayRegion(resultAry, 0, result.size(),
				reinterpret_cast<const int8_t *>(result.data()));
		return resultAry;
	} catch (const exception &e) {
		// Re-throw into the Java layer instead of crashing the process
		throwJavaException(env, e.what());
		return nullptr;
	}
}
|
|
|
|
namespace {
|
|
|
|
// Load the TFLite model file from the app's assets
DeepLab3::DeepLab3(AAssetManager *const aam) : model(Asset(aam, MODEL)) {}
|
|
|
|
/**
 * Run semantic segmentation on an RGB8 image.
 *
 * @param image RGB8 pixel data
 * @param width Width of @a image
 * @param height Height of @a image
 * @return Per-pixel label map of WIDTH * HEIGHT entries (argmax over the
 * LABEL_COUNT output channels)
 */
vector<uint8_t> DeepLab3::infer(const uint8_t *image, const size_t width,
		const size_t height) {
	InterpreterOptions options;
	options.setNumThreads(getNumberOfProcessors());
	Interpreter interpreter(model, options);
	interpreter.allocateTensors();

	LOGI(TAG, "[infer] Convert bitmap to input");
	// Resample to the model's fixed input size, then convert to float RGB
	vector<uint8_t> inputBitmap(WIDTH * HEIGHT * 3);
	base::ResampleImage24(image, width, height, inputBitmap.data(), WIDTH, HEIGHT,
			base::KernelTypeLanczos3);
	const auto input =
			rgb8ToRgbFloat(inputBitmap.data(), inputBitmap.size(), true);
	auto inputTensor = interpreter.getInputTensor(0);
	assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
	TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
			input.size() * sizeof(float));

	LOGI(TAG, "[infer] Inferring");
	Stopwatch stopwatch;
	interpreter.invoke();
	LOGI(TAG, "[infer] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

	// Copy out the per-pixel class scores
	auto outputTensor = interpreter.getOutputTensor(0);
	vector<float> output(WIDTH * HEIGHT * LABEL_COUNT);
	assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
	TfLiteTensorCopyToBuffer(outputTensor, output.data(),
			output.size() * sizeof(float));
	// (removed an unused debug leftover: `const auto i1 = (200 * 513 + 260)
	// * LABEL_COUNT;`)
	// Pick the most probable label for each pixel
	return argmax(output.data(), WIDTH, HEIGHT, LABEL_COUNT);
}
|
|
|
|
// Take ownership of the segmentation model
DeepLab3Portrait::DeepLab3Portrait(DeepLab3 &&deepLab)
		: deepLab(move(deepLab)) {}
|
|
|
|
/**
 * Apply the portrait effect to an RGB8 image.
 *
 * @param image RGB8 pixel data
 * @param width Width of @a image
 * @param height Height of @a image
 * @param radius Blur radius applied to the background
 * @return RGB8 result at the input resolution
 */
vector<uint8_t> DeepLab3Portrait::infer(const uint8_t *image,
		const size_t width, const size_t height,
		const unsigned radius) {
	// Segment the image, reduce the map to the dominant subject, then
	// composite the sharp subject over a blurred background
	auto segments = deepLab.infer(image, width, height);
	postProcessSegmentMap(&segments);
	return enhance(image, width, height, segments, radius);
}
|
|
|
|
/**
 * Composite the sharp subject over a blurred copy of the image.
 *
 * @param image RGB8 pixel data
 * @param width Width of @a image
 * @param height Height of @a image
 * @param segmentMap Post-processed map at model resolution (subject = 255,
 * rest = 0)
 * @param radius Blur radius for the background
 * @return RGB8 result of width * height pixels
 */
vector<uint8_t> DeepLab3Portrait::enhance(const uint8_t *image,
		const size_t width,
		const size_t height,
		const vector<uint8_t> &segmentMap,
		const unsigned radius) {
	LOGI(TAG, "[enhance] Enhancing image");
	// resize alpha to input size
	vector<uint8_t> alpha(width * height);
	base::ResampleImage<1>(segmentMap.data(), WIDTH, HEIGHT, alpha.data(), width,
			height, base::KernelTypeLanczos3);
	// smoothen the edge
	vector<uint8_t> alphaFiltered(width * height);
	getToolkitInst().blur(alpha.data(), alphaFiltered.data(), width, height, 1,
			16);
	// no longer needed; note clear() keeps capacity, it only drops the
	// elements
	alpha.clear();

	// blur input
	auto rgba8 = rgb8ToRgba8(image, width, height);
	vector<uint8_t> blur(width * height * 4);
	getToolkitInst().blur(rgba8.data(), blur.data(), width, height, 4, radius);

	// draw input on top of blurred image, with alpha map
	// subject mask goes into rgba8's alpha channel (channel 3)
	replaceChannel<4>(rgba8.data(), alphaFiltered.data(), width, height, 3);
	alphaFiltered.clear();
	// blend rgba8 over blur; the composited result ends up in `blur`,
	// which is what gets converted and returned below
	alphaBlend(rgba8.data(), blur.data(), width, height);
	rgba8.clear();
	return rgba8ToRgb8(blur.data(), width, height);
}
|
|
|
|
// Take ownership of the segmentation model
DeepLab3ColorPop::DeepLab3ColorPop(DeepLab3 &&deepLab)
		: deepLab(move(deepLab)) {}
|
|
|
|
/**
 * Apply the color-pop effect to an RGB8 image.
 *
 * @param image RGB8 pixel data
 * @param width Width of @a image
 * @param height Height of @a image
 * @param weight Desaturation strength for the background
 * @return RGB8 result at the input resolution
 */
vector<uint8_t> DeepLab3ColorPop::infer(const uint8_t *image,
		const size_t width, const size_t height,
		const float weight) {
	// Segment the image, reduce the map to the dominant subject, then
	// composite the colored subject over a desaturated background
	auto segments = deepLab.infer(image, width, height);
	postProcessSegmentMap(&segments);
	return enhance(image, width, height, segments, weight);
}
|
|
|
|
/**
 * Composite the colored subject over a desaturated copy of the image.
 *
 * @param image RGB8 pixel data
 * @param width Width of @a image
 * @param height Height of @a image
 * @param segmentMap Post-processed map at model resolution (subject = 255,
 * rest = 0)
 * @param weight Desaturation strength; passed negated to the saturation
 * filter
 * @return RGB8 result of width * height pixels
 */
vector<uint8_t> DeepLab3ColorPop::enhance(const uint8_t *image,
		const size_t width,
		const size_t height,
		const vector<uint8_t> &segmentMap,
		const float weight) {
	LOGI(TAG, "[enhance] Enhancing image");
	// resize alpha to input size
	vector<uint8_t> alpha(width * height);
	base::ResampleImage<1>(segmentMap.data(), WIDTH, HEIGHT, alpha.data(), width,
			height, base::KernelTypeLanczos3);
	// smoothen the edge
	vector<uint8_t> alphaFiltered(width * height);
	getToolkitInst().blur(alpha.data(), alphaFiltered.data(), width, height, 1,
			4);
	alpha.clear();

	// desaturate input
	auto rgba8 = rgb8ToRgba8(image, width, height);
	plugin::filter::Saturation saturation;
	// initialize directly from the filter output; the original pre-allocated
	// a width*height*4 buffer that was immediately discarded by assignment
	auto desaturate = saturation.apply(rgba8.data(), width, height, -1 * weight);

	// draw input on top of the desaturated image, with alpha map
	replaceChannel<4>(rgba8.data(), alphaFiltered.data(), width, height, 3);
	alphaFiltered.clear();
	// blend rgba8 over desaturate; the composited result ends up in
	// `desaturate`
	alphaBlend(rgba8.data(), desaturate.data(), width, height);
	rgba8.clear();
	return rgba8ToRgb8(desaturate.data(), width, height);
}
|
|
|
|
/**
 * Post-process the segment map in place: keep only the non-background label
 * covering the most pixels, set its pixels to 255 and every other pixel
 * to 0.
 *
 * @param segmentMap The segment map to modify
 */
void postProcessSegmentMap(vector<uint8_t> *segmentMap) {
	// keep only the largest segment
	vector<uint8_t> &segmentMapRef = *segmentMap;
	vector<int> count(LABEL_COUNT);
	for (size_t i = 0; i < segmentMapRef.size(); ++i) {
		assert(segmentMapRef[i] < LABEL_COUNT);
		// clamp to the last valid label; the original clamped to
		// LABEL_COUNT itself, which allowed an out-of-bounds write to
		// count[LABEL_COUNT] on malformed input when NDEBUG disables the
		// assert
		const auto label = std::min<unsigned>(segmentMapRef[i], LABEL_COUNT - 1);
		if (label != static_cast<unsigned>(Label::BACKGROUND)) {
			++count[label];
		}
	}
	// index of the most populous label (BACKGROUND stays 0 so it can only
	// win when no other label is present)
	const auto keep = distance(
			count.data(), max_element(count.data(), count.data() + count.size()));
	LOGI(TAG, "[postProcessSegmentMap] Label to keep: %d",
			static_cast<int>(keep));
#pragma omp parallel for
	for (size_t i = 0; i < segmentMapRef.size(); ++i) {
		// binarize: kept label -> 255, everything else -> 0
		if (segmentMapRef[i] == keep) {
			segmentMapRef[i] = 0xFF;
		} else {
			segmentMapRef[i] = 0;
		}
	}
}
|
|
|
|
} // namespace
|