// NOTE(review): the system header names below are inferred from the APIs this
// file uses - confirm them against the build before merging.
#include "exception.h"
#include "lib/base_resample.h"
#include "log.h"
#include "stopwatch.h"
#include "tflite_wrapper.h"
#include "util.h"
#include <RenderScriptToolkit.h>
#include <algorithm>
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <iterator>
#include <jni.h>
#include <tensorflow/lite/c/c_api.h>
#include <vector>

#include "./filter/saturation.h"

using namespace plugin;
using namespace renderscript;
using namespace std;
using namespace tflite;

namespace {

constexpr const char *MODEL = "tf/lite-model_mobilenetv2-dm05-coco_dr_1.tflite";
// Edge lengths of the square bitmap fed to the model and of the segment map
// it produces
constexpr size_t WIDTH = 513;
constexpr size_t HEIGHT = 513;
// Number of scores per pixel in the model output, one per Label
constexpr unsigned LABEL_COUNT = 21;
constexpr const char *TAG = "deep_lap_3";

/**
 * Labels that may appear in the segment map, in model output order
 */
enum struct Label {
  BACKGROUND = 0,
  AEROPLANE,
  BICYCLE,
  BIRD,
  BOAT,
  BOTTLE,
  BUS,
  CAR,
  CAT,
  CHAIR,
  COW,
  DINING_TABLE,
  DOG,
  HORSE,
  MOTORBIKE,
  PERSON,
  POTTED_PLANT,
  SHEEP,
  SOFA,
  TRAIN,
  TV,
};

/**
 * Run the segmentation model over an RGB8 image
 */
class DeepLab3 {
public:
  explicit DeepLab3(AAssetManager *const aam);
  DeepLab3(const DeepLab3 &) = delete;
  DeepLab3(DeepLab3 &&) = default;

  /**
   * Segment an image
   *
   * @param image Packed RGB8 pixels, @a width * @a height * 3 bytes
   * @param width
   * @param height
   * @return A WIDTH * HEIGHT map holding one label index per pixel
   */
  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height);

private:
  Model model;

  static constexpr const char *TAG = "DeepLap3";
};

/**
 * Blur everything except the most significant segment of an image
 */
class DeepLab3Portrait {
public:
  explicit DeepLab3Portrait(DeepLab3 &&deepLab);

  /**
   * @param image Packed RGB8 pixels, @a width * @a height * 3 bytes
   * @param width
   * @param height
   * @param radius Radius passed to the blur applied outside the segment
   * @return The processed image as packed RGB8, same dimensions as the input
   */
  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height, const unsigned radius);

private:
  std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
                               const size_t height,
                               const std::vector<uint8_t> &segmentMap,
                               const unsigned radius);

  DeepLab3 deepLab;

  static constexpr const char *TAG = "DeepLab3Portrait";
};

/**
 * Desaturate everything except the most significant segment of an image
 */
class DeepLab3ColorPop {
public:
  explicit DeepLab3ColorPop(DeepLab3 &&deepLab);

  /**
   * @param image Packed RGB8 pixels, @a width * @a height * 3 bytes
   * @param width
   * @param height
   * @param weight Negated and passed to the saturation filter applied outside
   * the segment
   * @return The processed image as packed RGB8, same dimensions as the input
   */
  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height, const float weight);

private:
  std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
                               const size_t height,
                               const std::vector<uint8_t> &segmentMap,
                               const float weight);

  DeepLab3 deepLab;

  static constexpr const char *TAG = "DeepLab3ColorPop";
};

/**
 * Post-process the segment map.
 *
 * The resulting segment map will:
 * 1. Contain only the most significant label (the non-background label
 * covering the most pixels)
 * 2. Have that label's pixels set to 255
 * 3. Have every other pixel, including the background, set to 0
 *
 * If the map contains no non-background pixel at all, the background itself
 * is selected and the whole map becomes 255.
 *
 * @param segmentMap Map of label indices, rewritten in place
 */
void postProcessSegmentMap(std::vector<uint8_t> *segmentMap);

/**
 * Copy a native byte buffer into a newly created Java byte[]
 *
 * @param env
 * @param bytes
 * @return The new array, or nullptr if the array could not be created
 */
jbyteArray toJavaByteArray(JNIEnv *env, const std::vector<uint8_t> &bytes) {
  const auto length = static_cast<jsize>(bytes.size());
  auto array = env->NewByteArray(length);
  if (array == nullptr) {
    // creation failed; the array must not be passed to SetByteArrayRegion
    return nullptr;
  }
  env->SetByteArrayRegion(array, 0, length,
                          reinterpret_cast<const jbyte *>(bytes.data()));
  return array;
}

/**
 * Give the memory held by a buffer back right away
 *
 * vector::clear() only resets the size and keeps the capacity, swapping with
 * an empty temporary actually frees the allocation
 *
 * @param buffer
 */
template <typename T> void freeBuffer(std::vector<T> &buffer) {
  std::vector<T>().swap(buffer);
}

} // namespace

extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3Portrait_inferNative(
    JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jint radius) {
  try {
    initOpenMp();
    auto aam = AAssetManager_fromJava(env, assetManager);
    DeepLab3Portrait model(DeepLab3{aam});
    // JNI_ABORT: the input pixels are only read, nothing to copy back
    RaiiContainer<jbyte> cImage(
        [&]() { return env->GetByteArrayElements(image, nullptr); },
        [&](jbyte *obj) {
          env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
        });
    const auto result =
        model.infer(reinterpret_cast<const uint8_t *>(cImage.get()), width,
                    height, radius);
    return toJavaByteArray(env, result);
  } catch (const exception &e) {
    throwJavaException(env, e.what());
    return nullptr;
  }
}

extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3ColorPop_inferNative(
    JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jfloat weight) {
  try {
    initOpenMp();
    auto aam = AAssetManager_fromJava(env, assetManager);
    DeepLab3ColorPop model(DeepLab3{aam});
    // JNI_ABORT: the input pixels are only read, nothing to copy back
    RaiiContainer<jbyte> cImage(
        [&]() { return env->GetByteArrayElements(image, nullptr); },
        [&](jbyte *obj) {
          env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
        });
    const auto result =
        model.infer(reinterpret_cast<const uint8_t *>(cImage.get()), width,
                    height, weight);
    return toJavaByteArray(env, result);
  } catch (const exception &e) {
    throwJavaException(env, e.what());
    return nullptr;
  }
}

namespace {

DeepLab3::DeepLab3(AAssetManager *const aam) : model(Asset(aam, MODEL)) {}

vector<uint8_t> DeepLab3::infer(const uint8_t *image, const size_t width,
                                const size_t height) {
  InterpreterOptions options;
  options.setNumThreads(getNumberOfProcessors());
  Interpreter interpreter(model, options);
  interpreter.allocateTensors();

  LOGI(TAG, "[infer] Convert bitmap to input");
  // scale the image to the fixed WIDTH x HEIGHT input
  vector<uint8_t> inputBitmap(WIDTH * HEIGHT * 3);
  base::ResampleImage24(image, width, height, inputBitmap.data(), WIDTH,
                        HEIGHT, base::KernelTypeLanczos3);
  const auto input =
      rgb8ToRgbFloat(inputBitmap.data(), inputBitmap.size(), true);
  auto inputTensor = interpreter.getInputTensor(0);
  assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
                             input.size() * sizeof(float));

  LOGI(TAG, "[infer] Inferring");
  Stopwatch stopwatch;
  interpreter.invoke();
  LOGI(TAG, "[infer] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

  // LABEL_COUNT scores per pixel, reduced to a single label index by argmax
  auto outputTensor = interpreter.getOutputTensor(0);
  vector<float> output(WIDTH * HEIGHT * LABEL_COUNT);
  assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
  TfLiteTensorCopyToBuffer(outputTensor, output.data(),
                           output.size() * sizeof(float));
  return argmax(output.data(), WIDTH, HEIGHT, LABEL_COUNT);
}

DeepLab3Portrait::DeepLab3Portrait(DeepLab3 &&deepLab)
    : deepLab(move(deepLab)) {}

vector<uint8_t> DeepLab3Portrait::infer(const uint8_t *image,
                                        const size_t width,
                                        const size_t height,
                                        const unsigned radius) {
  auto segmentMap = deepLab.infer(image, width, height);
  postProcessSegmentMap(&segmentMap);
  return enhance(image, width, height, segmentMap, radius);
}

vector<uint8_t> DeepLab3Portrait::enhance(const uint8_t *image,
                                          const size_t width,
                                          const size_t height,
                                          const vector<uint8_t> &segmentMap,
                                          const unsigned radius) {
  LOGI(TAG, "[enhance] Enhancing image");
  // resize alpha to input size
  vector<uint8_t> alpha(width * height);
  base::ResampleImage<1>(segmentMap.data(), WIDTH, HEIGHT, alpha.data(), width,
                         height, base::KernelTypeLanczos3);
  // smoothen the edge
  vector<uint8_t> alphaFiltered(width * height);
  getToolkitInst().blur(alpha.data(), alphaFiltered.data(), width, height, 1,
                        16);
  freeBuffer(alpha);

  // blur input
  auto rgba8 = rgb8ToRgba8(image, width, height);
  vector<uint8_t> blur(width * height * 4);
  getToolkitInst().blur(rgba8.data(), blur.data(), width, height, 4, radius);

  // draw input on top of blurred image, with alpha map
  replaceChannel<4>(rgba8.data(), alphaFiltered.data(), width, height, 3);
  freeBuffer(alphaFiltered);
  alphaBlend(rgba8.data(), blur.data(), width, height);
  freeBuffer(rgba8);
  return rgba8ToRgb8(blur.data(), width, height);
}

DeepLab3ColorPop::DeepLab3ColorPop(DeepLab3 &&deepLab)
    : deepLab(move(deepLab)) {}

vector<uint8_t> DeepLab3ColorPop::infer(const uint8_t *image,
                                        const size_t width,
                                        const size_t height,
                                        const float weight) {
  auto segmentMap = deepLab.infer(image, width, height);
  postProcessSegmentMap(&segmentMap);
  return enhance(image, width, height, segmentMap, weight);
}

vector<uint8_t> DeepLab3ColorPop::enhance(const uint8_t *image,
                                          const size_t width,
                                          const size_t height,
                                          const vector<uint8_t> &segmentMap,
                                          const float weight) {
  LOGI(TAG, "[enhance] Enhancing image");
  // resize alpha to input size
  vector<uint8_t> alpha(width * height);
  base::ResampleImage<1>(segmentMap.data(), WIDTH, HEIGHT, alpha.data(), width,
                         height, base::KernelTypeLanczos3);
  // smoothen the edge
  vector<uint8_t> alphaFiltered(width * height);
  getToolkitInst().blur(alpha.data(), alphaFiltered.data(), width, height, 1,
                        4);
  freeBuffer(alpha);

  // desaturate input, the filter result is used directly instead of being
  // assigned over a preallocated buffer
  auto rgba8 = rgb8ToRgba8(image, width, height);
  plugin::filter::Saturation saturation;
  vector<uint8_t> desaturate =
      saturation.apply(rgba8.data(), width, height, -1 * weight);

  // draw input on top of desaturated image, with alpha map
  replaceChannel<4>(rgba8.data(), alphaFiltered.data(), width, height, 3);
  freeBuffer(alphaFiltered);
  alphaBlend(rgba8.data(), desaturate.data(), width, height);
  freeBuffer(rgba8);
  return rgba8ToRgb8(desaturate.data(), width, height);
}

void postProcessSegmentMap(vector<uint8_t> *segmentMap) {
  // keep only the largest segment
  vector<uint8_t> &segmentMapRef = *segmentMap;
  vector<size_t> count(LABEL_COUNT);
  for (const auto value : segmentMapRef) {
    assert(value < LABEL_COUNT);
    // clamp to the last valid index so that an unexpected value can never
    // write past the end of count when asserts are compiled out
    const auto label = std::min<unsigned>(value, LABEL_COUNT - 1);
    if (label != static_cast<unsigned>(Label::BACKGROUND)) {
      ++count[label];
    }
  }
  // max_element returns the first maximum, so with no foreground pixel at all
  // (every count is 0) this resolves to index 0, i.e., the background
  const auto keep =
      distance(count.begin(), max_element(count.begin(), count.end()));
  LOGI(TAG, "[postProcessSegmentMap] Label to keep: %d",
       static_cast<int>(keep));
#pragma omp parallel for
  for (size_t i = 0; i < segmentMapRef.size(); ++i) {
    segmentMapRef[i] = (segmentMapRef[i] == keep) ? 0xFF : 0;
  }
}

} // namespace