Add enhancement: super resolution with ESRGAN

This commit is contained in:
Ming Ming 2022-05-21 03:40:14 +08:00
parent b75c4dd2f6
commit 038f2a3b39
14 changed files with 438 additions and 7 deletions

View file

@ -12,3 +12,5 @@ const enhanceZeroDceUrl =
"https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/zero-dce";
const enhanceDeepLabPortraitBlurUrl =
"https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/portrait-blur-(deeplab)";
// Wiki help page describing the ESRGAN super-resolution enhancement
const enhanceEsrganUrl =
"https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/esrgan";

View file

@ -1209,6 +1209,10 @@
"@enhancePortraitBlurParamBlurLabel": {
"description": "This parameter sets the radius of the blur filter"
},
"enhanceSuperResolution4xTitle": "Super-resolution (4x)",
"@enhanceSuperResolution4xTitle": {
"description": "Upscale an image. The algorithm implemented in the app will upscale to 4x the original resolution (e.g., 100x100 to 400x400)"
},
"errorUnauthenticated": "Unauthenticated access. Please sign-in again if the problem continues",
"@errorUnauthenticated": {

View file

@ -98,6 +98,7 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"errorAlbumDowngrade"
],
@ -214,6 +215,7 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"errorAlbumDowngrade"
],
@ -385,6 +387,7 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"errorAlbumDowngrade"
],
@ -406,7 +409,12 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"fi": [
"enhanceSuperResolution4xTitle"
],
"fr": [
@ -427,7 +435,8 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"pl": [
@ -465,7 +474,8 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"pt": [
@ -482,7 +492,8 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"ru": [
@ -499,7 +510,8 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"zh": [
@ -516,7 +528,8 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
],
"zh_Hant": [
@ -533,6 +546,7 @@
"collectionEnhancedPhotosLabel",
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel"
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
]
}

View file

@ -74,6 +74,18 @@ class EnhanceHandler {
},
);
break;
case _Algorithm.esrgan:
await ImageProcessor.esrgan(
"${account.url}/${file.path}",
file.filename,
Pref().getEnhanceMaxWidthOr(),
Pref().getEnhanceMaxHeightOr(),
headers: {
"Authorization": Api.getAuthorizationHeaderValue(account),
},
);
break;
}
}
@ -175,6 +187,13 @@ class EnhanceHandler {
link: enhanceDeepLabPortraitBlurUrl,
algorithm: _Algorithm.deepLab3Portrait,
),
if (platform_k.isAndroid)
_Option(
title: L10n.global().enhanceSuperResolution4xTitle,
subtitle: "ESRGAN",
link: enhanceEsrganUrl,
algorithm: _Algorithm.esrgan,
),
];
Future<Map<String, dynamic>?> _getArgs(
@ -185,6 +204,9 @@ class EnhanceHandler {
case _Algorithm.deepLab3Portrait:
return _getDeepLab3PortraitArgs(context);
case _Algorithm.esrgan:
return {};
}
}
@ -301,6 +323,7 @@ class EnhanceHandler {
enum _Algorithm {
zeroDce,
deepLab3Portrait,
esrgan,
}
class _Option {

Binary file not shown.

View file

@ -34,7 +34,9 @@ add_library( # Sets the name of the library.
# Provides a relative path to your source file(s).
deep_lap_3.cpp
esrgan.cpp
exception.cpp
image_splitter.cpp
stopwatch.cpp
tflite_wrapper.cpp
util.cpp

View file

@ -0,0 +1,153 @@
// Native implementation of the ESRGAN 4x super-resolution enhancement.
// Exposed to Kotlin through the JNI entry point inferNative() below; the
// upscaling is done with a bundled TFLite model, run tile by tile to bound
// peak memory usage.
#include "exception.h"
#include "image_splitter.h"
#include "log.h"
#include "stopwatch.h"
#include "tflite_wrapper.h"
#include "util.h"
#include <algorithm>
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <exception>
#include <jni.h>
#include <omp.h>
#include <tensorflow/lite/c/c_api.h>
#include <vector>
using namespace plugin;
using namespace std;
using namespace tflite;
namespace {
// TFLite model file bundled in the app assets
constexpr const char *MODEL = "esrgan-tf2_1-dr.tflite";
// inputs are split into TILE_SIZE x TILE_SIZE tiles, each carrying
// TILE_PADDING extra pixels taken from the previous tile (see ImageSplitter)
constexpr const size_t TILE_SIZE = 118;
constexpr const size_t TILE_PADDING = 10;
// Wraps the ESRGAN TFLite model and runs 4x upscaling on RGB8 images
class Esrgan {
public:
explicit Esrgan(AAssetManager *const aam);
// Upscale an RGB8 image (3 bytes/pixel, row-major) to 4x width and height
std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
const size_t height);
private:
// Run the model on a single tile
std::vector<uint8_t> inferSingle(const uint8_t *image, const size_t width,
const size_t height);
// NOTE(review): declared but never defined in this file -- the joining is
// done inline in infer(); likely dead and removable, confirm
std::vector<uint8_t>
joinTiles(const std::vector<std::vector<ImageTile>> &tiles);
Model model;
static constexpr const char *TAG = "Esrgan";
};
} // namespace
// JNI entry point backing the Kotlin `external fun inferNative()` of
// com.nkming.nc_photos.plugin.image_processor.Esrgan.
// @image holds RGB8 pixels of size @width x @height; returns a new byte
// array of RGB8 pixels at 4x those dimensions, or nullptr with a pending
// Java exception on failure.
// NOTE(review): the receiver is declared `jobject *thiz`, but JNI passes a
// plain `jobject` for instance methods; harmless while unused -- confirm
// and fix together with esrgan.h
extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_Esrgan_inferNative(
JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image,
jint width, jint height) {
try {
initOpenMp();
auto aam = AAssetManager_fromJava(env, assetManager);
// initializes from a temporary; presumably equivalent to `Esrgan
// model(aam);` (copy elided since C++17) -- could be simplified
Esrgan model(Esrgan{aam});
// pin the Java byte array; released with JNI_ABORT (no copy-back) since
// the input is only read
RaiiContainer<jbyte> cImage(
[&]() { return env->GetByteArrayElements(image, nullptr); },
[&](jbyte *obj) {
env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
});
const auto result =
model.infer(reinterpret_cast<uint8_t *>(cImage.get()), width, height);
auto resultAry = env->NewByteArray(result.size());
env->SetByteArrayRegion(resultAry, 0, result.size(),
reinterpret_cast<const int8_t *>(result.data()));
return resultAry;
} catch (const exception &e) {
// convert any native failure into a Java exception for the Kotlin caller
throwJavaException(env, e.what());
return nullptr;
}
}
namespace {
// Load the model from the app's asset manager
Esrgan::Esrgan(AAssetManager *const aam) : model(Asset(aam, MODEL)) {}
vector<uint8_t> Esrgan::infer(const uint8_t *image, const size_t width,
const size_t height) {
// doing ESRGAN in one pass requires loads of memory, so we split the image
// into smaller tiles
const ImageSplitter splitter(TILE_SIZE, TILE_SIZE, TILE_PADDING);
auto tiles = splitter(image, width, height, 3);
const size_t tileCount = tiles.size() * (tiles.empty() ? 0 : tiles[0].size());
auto i = 0;
// upscale each tile in place; each output tile is 4x its input in both axes
for (auto &row : tiles) {
for (auto &t : row) {
LOGI(TAG, "[infer] Tile#%d/%zu", i++, tileCount);
auto result = inferSingle(t.data().data(), t.width(), t.height());
t = ImageTile(move(result), t.width() * 4, t.height() * 4, 3);
}
}
// when joining tiles, we use half of paddings from next tile to cover the
// prev tile
vector<uint8_t> output(width * 4 * height * 4 * 3);
#pragma omp parallel for
for (size_t ty = 0; ty < tiles.size(); ++ty) {
// padding sizes are 4x here because the tiles have been upscaled already
const auto thisTilePaddingH = ty == 0 ? 0 : TILE_PADDING * 4;
const auto thisTileBegY = thisTilePaddingH / 2;
for (size_t tx = 0; tx < tiles[ty].size(); ++tx) {
const auto thisTilePaddingW = tx == 0 ? 0 : TILE_PADDING * 4;
const auto thisTileBegX = thisTilePaddingW / 2;
const ImageTile &tile = tiles[ty][tx];
// the last row/column keeps its full extent; the others give up half a
// padding to be covered by the next tile
const auto thisTileEndY =
tile.height() - (ty == tiles.size() - 1 ? 0 : TILE_PADDING * 4 / 2);
const auto thisTileEndX =
tile.width() -
(tx == tiles[ty].size() - 1 ? 0 : TILE_PADDING * 4 / 2);
// copy the retained rows of this tile into the joined output image
for (size_t dy = thisTileBegY; dy < thisTileEndY; ++dy) {
const auto srcOffset = (dy * tile.width() + thisTileBegX) * 3;
const auto dstOffset =
((ty * TILE_SIZE * 4 - thisTilePaddingH + dy) * width * 4 +
tx * TILE_SIZE * 4 - thisTilePaddingW + thisTileBegX) *
3;
memcpy(output.data() + dstOffset, tile.data().data() + srcOffset,
(thisTileEndX - thisTileBegX) * 3);
}
}
}
return output;
}
// Run the model on a single RGB8 tile and return the 4x-upscaled RGB8 result
vector<uint8_t> Esrgan::inferSingle(const uint8_t *image, const size_t width,
const size_t height) {
InterpreterOptions options;
options.setNumThreads(getNumberOfProcessors());
Interpreter interpreter(model, options);
// resize the model input to [1, h, w, 3] for this tile
const int dims[] = {1, static_cast<int>(height), static_cast<int>(width), 3};
interpreter.resizeInputTensor(0, dims, 4);
interpreter.allocateTensors();
LOGI(TAG, "[inferSingle] Convert bitmap to input");
const auto input = rgb8ToRgbFloat(image, width * height * 3, false);
auto inputTensor = interpreter.getInputTensor(0);
assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
input.size() * sizeof(float));
LOGI(TAG, "[inferSingle] Inferring");
Stopwatch stopwatch;
interpreter.invoke();
LOGI(TAG, "[inferSingle] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);
auto outputTensor = interpreter.getOutputTensor(0);
// output is 4x the tile size in both axes, float RGB
vector<float> output(width * 4 * height * 4 * 3);
assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
TfLiteTensorCopyToBuffer(outputTensor, output.data(),
output.size() * sizeof(float));
return rgbFloatToRgb8(output.data(), output.size(), false);
}
} // namespace

View file

@ -0,0 +1,16 @@
#pragma once
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif
// JNI entry point backing the Kotlin `external fun inferNative()` of
// com.nkming.nc_photos.plugin.image_processor.Esrgan.
// @image: RGB8 pixel data (3 bytes/pixel, row-major) of size @width x
// @height; returns a new byte array of RGB8 pixels at 4x those dimensions.
// NOTE(review): the receiver is declared `jobject *thiz`, but JNI passes a
// plain `jobject` for instance methods; harmless while unused -- confirm
// and fix together with esrgan.cpp
JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_Esrgan_inferNative(
JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image,
jint width, jint height);
#ifdef __cplusplus
}
#endif

View file

@ -0,0 +1,75 @@
// Implementation of ImageTile / ImageSplitter: splitting a raw interleaved
// image buffer into overlapping tiles for tiled model inference.
#include "image_splitter.h"
#include "log.h"
#include "util.h"
#include <cmath>
#include <cstdint>
#include <cstring>
#include <deque>
#include <vector>
using namespace std;
namespace plugin {
// An empty placeholder tile; real content is assigned later via move
// assignment
ImageTile::ImageTile() : _width(0), _height(0), _channel(3) {}
// Take ownership of @data, which holds width * height * channel bytes in
// row-major, channel-interleaved order
ImageTile::ImageTile(vector<uint8_t> &&data, const size_t width,
                     const size_t height, const unsigned channel)
    : _data(move(data)), _width(width), _height(height), _channel(channel) {}
// Move assignment; the moved-from tile is left with zero width/height
ImageTile &ImageTile::operator=(ImageTile &&rhs) {
  if (this != &rhs) {
    _data = move(rhs._data);
    _width = rhs._width;
    rhs._width = 0;
    _height = rhs._height;
    rhs._height = 0;
    _channel = rhs._channel;
  }
  return *this;
}
ImageSplitter::ImageSplitter(const size_t tileWidth, const size_t tileHeight,
                             const size_t padding)
    : _tileWidth(tileWidth), _tileHeight(tileHeight), _padding(padding) {}
/**
 * Split @a image (@a width x @a height pixels, @a channel bytes per pixel,
 * row-major) into a grid of tiles of at most _tileWidth x _tileHeight
 * pixels. Every tile except those in the first row/column additionally
 * carries _padding extra pixels taken from the previous tile on each axis,
 * so neighboring tiles overlap.
 *
 * @return tiles indexed as [row][column]
 */
deque<deque<ImageTile>>
ImageSplitter::operator()(const uint8_t *image, const size_t width,
                          const size_t height, const unsigned channel) const {
  const size_t tileHoriz = ceil(static_cast<float>(width) / _tileWidth);
  const size_t tileVert = ceil(static_cast<float>(height) / _tileHeight);
  deque<deque<ImageTile>> product(tileVert);
  for (auto &r : product) {
    // size_t loop var: was `int`, which mixed signed/unsigned comparison
    for (size_t i = 0; i < tileHoriz; ++i) {
      r.emplace_back();
    }
  }
  // (typo fix: "Spliting" -> "Splitting")
  LOGD(TAG,
       "[split] Splitting %zux%zu into %zux%zu tiles, each %zux%zu in size",
       width, height, tileHoriz, tileVert, _tileWidth, _tileHeight);
  for (size_t ty = 0; ty < tileVert; ++ty) {
    // all but the first row/column include _padding rows/cols of overlap
    // with the previous tile
    const size_t thisTilePaddingY = ty == 0 ? 0 : _padding;
    const size_t thisTileH =
        std::min(_tileHeight, height - _tileHeight * ty) + thisTilePaddingY;
    for (size_t tx = 0; tx < tileHoriz; ++tx) {
      const size_t thisTilePaddingX = tx == 0 ? 0 : _padding;
      const size_t thisTileW =
          std::min(_tileWidth, width - _tileWidth * tx) + thisTilePaddingX;
      LOGI(TAG, "[split] Tile[%zu][%zu]: %zux%zu", ty, tx, thisTileW,
           thisTileH);
      // copy this tile's rows (shifted back by the padding) out of the
      // source image, one memcpy per row
      vector<uint8_t> tile(thisTileW * thisTileH * channel);
      for (size_t dy = 0; dy < thisTileH; ++dy) {
        const auto srcOffset =
            ((ty * _tileHeight + dy - thisTilePaddingY) * width +
             (tx * _tileWidth - thisTilePaddingX)) *
            channel;
        const auto dstOffset = dy * thisTileW * channel;
        memcpy(tile.data() + dstOffset, image + srcOffset, thisTileW * channel);
      }
      product[ty][tx] = ImageTile(move(tile), thisTileW, thisTileH, channel);
    }
  }
  return product;
}
} // namespace plugin

View file

@ -0,0 +1,54 @@
#pragma once // was missing; sibling headers in this plugin use #pragma once
#include <cstddef> // size_t (was relied on transitively via <cstdint>)
#include <cstdint>
#include <deque>
#include <vector>
namespace plugin {
/**
 * A rectangular slice of an image, owning its pixel data (row-major,
 * `channel` bytes per pixel). Move-only for assignment; the default
 * constructor yields an empty placeholder tile.
 */
class ImageTile {
public:
  ImageTile();
  ImageTile(std::vector<uint8_t> &&data, const size_t width,
            const size_t height, const unsigned channel);
  ImageTile &operator=(const ImageTile &) = delete;
  ImageTile &operator=(ImageTile &&rhs);
  const std::vector<uint8_t> &data() const { return _data; }
  size_t width() const { return _width; }
  size_t height() const { return _height; }
  unsigned channel() const { return _channel; }
private:
  std::vector<uint8_t> _data;
  size_t _width;
  size_t _height;
  unsigned _channel;
};
/**
 * Split an image to smaller tiles with optional padding
 *
 * If padding > 0, each tile will have extra pixels belonging to the previous
 * tiles in both axis
 */
class ImageSplitter {
public:
  ImageSplitter(const size_t tileWidth, const size_t tileHeight)
      : ImageSplitter(tileWidth, tileHeight, 0) {}
  ImageSplitter(const size_t tileWidth, const size_t tileHeight,
                const size_t padding);
  // Returns tiles indexed as [row][column]; see image_splitter.cpp
  std::deque<std::deque<ImageTile>> operator()(const uint8_t *image,
                                               const size_t width,
                                               const size_t height,
                                               const unsigned channel) const;
private:
  const size_t _tileWidth;
  const size_t _tileHeight;
  const size_t _padding;
  static constexpr const char *TAG = "ImageSplitter";
};
} // namespace plugin

View file

@ -49,6 +49,21 @@ class ImageProcessorChannelHandler(context: Context) :
}
}
"esrgan" -> {
try {
esrgan(
call.argument("fileUrl")!!,
call.argument("headers"),
call.argument("filename")!!,
call.argument("maxWidth")!!,
call.argument("maxHeight")!!,
result
)
} catch (e: Throwable) {
result.error("systemException", e.toString(), null)
}
}
else -> result.notImplemented()
}
}
@ -82,6 +97,14 @@ class ImageProcessorChannelHandler(context: Context) :
}
)
/**
 * Handle the "esrgan" method-channel call by delegating to the shared
 * [method] helper with [ImageProcessorService.METHOD_ESRGAN]; ESRGAN takes
 * no algorithm-specific arguments beyond the common ones.
 */
private fun esrgan(
fileUrl: String, headers: Map<String, String>?, filename: String,
maxWidth: Int, maxHeight: Int, result: MethodChannel.Result
) = method(
fileUrl, headers, filename, maxWidth, maxHeight,
ImageProcessorService.METHOD_ESRGAN, result
)
private fun method(
fileUrl: String, headers: Map<String, String>?, filename: String,
maxWidth: Int, maxHeight: Int, method: String,

View file

@ -18,6 +18,7 @@ import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.exifinterface.media.ExifInterface
import com.nkming.nc_photos.plugin.image_processor.DeepLab3Portrait
import com.nkming.nc_photos.plugin.image_processor.Esrgan
import com.nkming.nc_photos.plugin.image_processor.ZeroDce
import java.io.File
import java.net.HttpURLConnection
@ -28,6 +29,7 @@ class ImageProcessorService : Service() {
const val EXTRA_METHOD = "method"
const val METHOD_ZERO_DCE = "zero-dce"
const val METHOD_DEEP_LAP_PORTRAIT = "DeepLab3Portrait"
const val METHOD_ESRGAN = "Esrgan"
const val EXTRA_FILE_URL = "fileUrl"
const val EXTRA_HEADERS = "headers"
const val EXTRA_FILENAME = "filename"
@ -99,6 +101,7 @@ class ImageProcessorService : Service() {
METHOD_DEEP_LAP_PORTRAIT -> onDeepLapPortrait(
startId, intent.extras!!
)
METHOD_ESRGAN -> onEsrgan(startId, intent.extras!!)
else -> {
logE(TAG, "Unknown method: $method")
// we can't call stopSelf here as it'll stop the service even if
@ -126,6 +129,10 @@ class ImageProcessorService : Service() {
)
}
// Start an ESRGAN job; ESRGAN needs no extra arguments, so the shared
// no-argument handler is used
private fun onEsrgan(startId: Int, extras: Bundle) {
return onMethod(startId, extras, METHOD_ESRGAN)
}
/**
* Handle methods without arguments
*
@ -449,6 +456,10 @@ private open class ImageProcessorCommandTask(context: Context) :
cmd.args["radius"] as? Int ?: 16
).infer(fileUri)
ImageProcessorService.METHOD_ESRGAN -> Esrgan(
context, cmd.maxWidth, cmd.maxHeight
).infer(fileUri)
else -> throw IllegalArgumentException(
"Unknown method: ${cmd.method}"
)

View file

@ -0,0 +1,39 @@
package com.nkming.nc_photos.plugin.image_processor
import android.content.Context
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.net.Uri
import com.nkming.nc_photos.plugin.BitmapResizeMethod
import com.nkming.nc_photos.plugin.BitmapUtil
/**
 * ESRGAN 4x super-resolution.
 *
 * The source is loaded at no more than (maxWidth / 4) x (maxHeight / 4)
 * pixels so that the 4x output of the native model stays within
 * maxWidth x maxHeight.
 */
class Esrgan(context: Context, maxWidth: Int, maxHeight: Int) {
	/**
	 * Upscale the image at [imageUri] and return a [Bitmap] at 4x the
	 * loaded width/height
	 */
	fun infer(imageUri: Uri): Bitmap {
		val src = BitmapUtil.loadImage(
			context, imageUri, maxWidth / 4, maxHeight / 4,
			BitmapResizeMethod.FIT, isAllowSwapSide = true,
			shouldUpscale = false
		)
		val srcW = src.width
		val srcH = src.height
		val rgb8 = TfLiteHelper.bitmapToRgb8Array(src)
		// the bitmap is no longer needed once converted to a raw RGB8 array
		src.recycle()
		val upscaled = inferNative(context.assets, rgb8, srcW, srcH)
		return TfLiteHelper.rgb8ArrayToBitmap(upscaled, srcW * 4, srcH * 4)
	}

	private external fun inferNative(
		am: AssetManager, image: ByteArray, width: Int, height: Int
	): ByteArray

	private val context = context
	private val maxWidth = maxWidth
	private val maxHeight = maxHeight
}

View file

@ -38,6 +38,21 @@ class ImageProcessor {
"radius": radius,
});
/// Run the ESRGAN super-resolution enhancement on the image at [fileUrl]
/// via the "esrgan" platform method-channel call.
///
/// [filename] names the file being processed; [maxWidth]/[maxHeight] bound
/// the output size; [headers] are forwarded to the platform side --
/// presumably used when downloading [fileUrl], confirm with the native
/// implementation.
static Future<void> esrgan(
String fileUrl,
String filename,
int maxWidth,
int maxHeight, {
Map<String, String>? headers,
}) =>
_methodChannel.invokeMethod("esrgan", <String, dynamic>{
"fileUrl": fileUrl,
"headers": headers,
"filename": filename,
"maxWidth": maxWidth,
"maxHeight": maxHeight,
});
static const _methodChannel =
MethodChannel("${k.libId}/image_processor_method");
}