diff --git a/app/lib/help_utils.dart b/app/lib/help_utils.dart index ffe483c9..f8d06576 100644 --- a/app/lib/help_utils.dart +++ b/app/lib/help_utils.dart @@ -12,3 +12,5 @@ const enhanceZeroDceUrl = "https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/zero-dce"; const enhanceDeepLabPortraitBlurUrl = "https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/portrait-blur-(deeplab)"; +const enhanceEsrganUrl = + "https://gitlab.com/nkming2/nc-photos/-/wikis/help/enhance/esrgan"; diff --git a/app/lib/l10n/app_en.arb b/app/lib/l10n/app_en.arb index e7a67256..34e608f7 100644 --- a/app/lib/l10n/app_en.arb +++ b/app/lib/l10n/app_en.arb @@ -1209,6 +1209,10 @@ "@enhancePortraitBlurParamBlurLabel": { "description": "This parameter sets the radius of the blur filter" }, + "enhanceSuperResolution4xTitle": "Super-resolution (4x)", + "@enhanceSuperResolution4xTitle": { + "description": "Upscale an image. The algorithm implemented in the app will upscale to 4x the original resolution (eg, 100x100 to 400x400)" + }, "errorUnauthenticated": "Unauthenticated access. Please sign-in again if the problem continues", "@errorUnauthenticated": { diff --git a/app/lib/l10n/untranslated-messages.txt b/app/lib/l10n/untranslated-messages.txt index 8a13f82c..4559e80f 100644 --- a/app/lib/l10n/untranslated-messages.txt +++ b/app/lib/l10n/untranslated-messages.txt @@ -98,6 +98,7 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle", "errorAlbumDowngrade" ], @@ -214,6 +215,7 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle", "errorAlbumDowngrade" ], @@ -385,6 +387,7 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle", "errorAlbumDowngrade" ], @@ -406,7 +409,12 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" + ], + + "fi": [ + "enhanceSuperResolution4xTitle" ], "fr": [ @@ -427,7 +435,8 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ], "pl": [ @@ -465,7 +474,8 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ], "pt": [ @@ -482,7 +492,8 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ], "ru": [ @@ -499,7 +510,8 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ], "zh": [ @@ -516,7 +528,8 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ], "zh_Hant": [ @@ -533,6 +546,7 @@ "collectionEnhancedPhotosLabel", "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", - "enhancePortraitBlurParamBlurLabel" + "enhancePortraitBlurParamBlurLabel", + "enhanceSuperResolution4xTitle" ] } diff --git a/app/lib/widget/handler/enhance_handler.dart b/app/lib/widget/handler/enhance_handler.dart index 1836e411..05e493ff 100644 --- a/app/lib/widget/handler/enhance_handler.dart +++ b/app/lib/widget/handler/enhance_handler.dart @@ -74,6 +74,18 @@ class EnhanceHandler { }, ); break; + + case _Algorithm.esrgan: + await ImageProcessor.esrgan( + "${account.url}/${file.path}", + file.filename, + Pref().getEnhanceMaxWidthOr(), + Pref().getEnhanceMaxHeightOr(), + headers: { + "Authorization": Api.getAuthorizationHeaderValue(account), + }, + ); + break; } } @@ -175,6 +187,13 @@ class EnhanceHandler { link: enhanceDeepLabPortraitBlurUrl, algorithm: _Algorithm.deepLab3Portrait, ), + if (platform_k.isAndroid) + _Option( + title: L10n.global().enhanceSuperResolution4xTitle, + subtitle: "ESRGAN", + link: enhanceEsrganUrl, + algorithm: _Algorithm.esrgan, + ), ]; Future?> _getArgs( @@ -185,6 +204,9 @@ class EnhanceHandler { case _Algorithm.deepLab3Portrait: return _getDeepLab3PortraitArgs(context); + + case _Algorithm.esrgan: + return {}; } } @@ -301,6 +323,7 @@ class EnhanceHandler { enum _Algorithm { zeroDce, deepLab3Portrait, + esrgan, } class _Option { diff --git a/plugin/android/src/main/assets/esrgan-tf2_1-dr.tflite b/plugin/android/src/main/assets/esrgan-tf2_1-dr.tflite new file mode 100644 index 00000000..e1ed5656 Binary files /dev/null and b/plugin/android/src/main/assets/esrgan-tf2_1-dr.tflite differ diff --git a/plugin/android/src/main/cpp/CMakeLists.txt b/plugin/android/src/main/cpp/CMakeLists.txt index 3103a58b..32673387 100644 --- a/plugin/android/src/main/cpp/CMakeLists.txt +++ b/plugin/android/src/main/cpp/CMakeLists.txt @@ -34,7 +34,9 @@ add_library( # Sets the name of the library. # Provides a relative path to your source file(s). deep_lap_3.cpp + esrgan.cpp exception.cpp + image_splitter.cpp stopwatch.cpp tflite_wrapper.cpp util.cpp diff --git a/plugin/android/src/main/cpp/esrgan.cpp b/plugin/android/src/main/cpp/esrgan.cpp new file mode 100644 index 00000000..41e20a82 --- /dev/null +++ b/plugin/android/src/main/cpp/esrgan.cpp @@ -0,0 +1,153 @@ +#include "exception.h" +#include "image_splitter.h" +#include "log.h" +#include "stopwatch.h" +#include "tflite_wrapper.h" +#include "util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace plugin; +using namespace std; +using namespace tflite; + +namespace { + +constexpr const char *MODEL = "esrgan-tf2_1-dr.tflite"; +constexpr const size_t TILE_SIZE = 118; +constexpr const size_t TILE_PADDING = 10; + +class Esrgan { +public: + explicit Esrgan(AAssetManager *const aam); + + std::vector infer(const uint8_t *image, const size_t width, + const size_t height); + +private: + std::vector inferSingle(const uint8_t *image, const size_t width, + const size_t height); + std::vector + joinTiles(const std::vector> &tiles); + + Model model; + + static constexpr const char *TAG = "Esrgan"; +}; + +} // namespace + +extern "C" JNIEXPORT jbyteArray JNICALL +Java_com_nkming_nc_1photos_plugin_image_1processor_Esrgan_inferNative( + JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image, + jint width, jint height) { + try { + initOpenMp(); + auto aam = AAssetManager_fromJava(env, assetManager); + Esrgan model(Esrgan{aam}); + RaiiContainer cImage( + [&]() { return env->GetByteArrayElements(image, nullptr); }, + [&](jbyte *obj) { + env->ReleaseByteArrayElements(image, obj, JNI_ABORT); + }); + const auto result = + model.infer(reinterpret_cast(cImage.get()), width, height); + auto resultAry = env->NewByteArray(result.size()); + env->SetByteArrayRegion(resultAry, 0, result.size(), + reinterpret_cast(result.data())); + return resultAry; + } catch (const exception &e) { + throwJavaException(env, e.what()); + return nullptr; + } +} + +namespace { + +Esrgan::Esrgan(AAssetManager *const aam) : model(Asset(aam, MODEL)) {} + +vector Esrgan::infer(const uint8_t *image, const size_t width, + const size_t height) { + // doing ESRGAN in one pass requires loads of memory, so we split the image + // into smaller tiles + const ImageSplitter splitter(TILE_SIZE, TILE_SIZE, TILE_PADDING); + auto tiles = splitter(image, width, height, 3); + const size_t tileCount = tiles.size() * (tiles.empty() ? 0 : tiles[0].size()); + auto i = 0; + for (auto &row : tiles) { + for (auto &t : row) { + LOGI(TAG, "[infer] Tile#%d/%zu", i++, tileCount); + auto result = inferSingle(t.data().data(), t.width(), t.height()); + t = ImageTile(move(result), t.width() * 4, t.height() * 4, 3); + } + } + + // when joining tiles, we use half of paddings from next tile to cover the + // prev tile + vector output(width * 4 * height * 4 * 3); +#pragma omp parallel for + for (size_t ty = 0; ty < tiles.size(); ++ty) { + const auto thisTilePaddingH = ty == 0 ? 0 : TILE_PADDING * 4; + const auto thisTileBegY = thisTilePaddingH / 2; + for (size_t tx = 0; tx < tiles[ty].size(); ++tx) { + const auto thisTilePaddingW = tx == 0 ? 0 : TILE_PADDING * 4; + const auto thisTileBegX = thisTilePaddingW / 2; + const ImageTile &tile = tiles[ty][tx]; + const auto thisTileEndY = + tile.height() - (ty == tiles.size() - 1 ? 0 : TILE_PADDING * 4 / 2); + const auto thisTileEndX = + tile.width() - + (tx == tiles[ty].size() - 1 ? 0 : TILE_PADDING * 4 / 2); + for (size_t dy = thisTileBegY; dy < thisTileEndY; ++dy) { + const auto srcOffset = (dy * tile.width() + thisTileBegX) * 3; + const auto dstOffset = + ((ty * TILE_SIZE * 4 - thisTilePaddingH + dy) * width * 4 + + tx * TILE_SIZE * 4 - thisTilePaddingW + thisTileBegX) * + 3; + memcpy(output.data() + dstOffset, tile.data().data() + srcOffset, + (thisTileEndX - thisTileBegX) * 3); + } + } + } + return output; +} + +vector Esrgan::inferSingle(const uint8_t *image, const size_t width, + const size_t height) { + InterpreterOptions options; + options.setNumThreads(getNumberOfProcessors()); + Interpreter interpreter(model, options); + const int dims[] = {1, static_cast(height), static_cast(width), 3}; + interpreter.resizeInputTensor(0, dims, 4); + interpreter.allocateTensors(); + + LOGI(TAG, "[inferSingle] Convert bitmap to input"); + const auto input = rgb8ToRgbFloat(image, width * height * 3, false); + auto inputTensor = interpreter.getInputTensor(0); + assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float)); + TfLiteTensorCopyFromBuffer(inputTensor, input.data(), + input.size() * sizeof(float)); + + LOGI(TAG, "[inferSingle] Inferring"); + Stopwatch stopwatch; + interpreter.invoke(); + LOGI(TAG, "[inferSingle] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f); + + auto outputTensor = interpreter.getOutputTensor(0); + vector output(width * 4 * height * 4 * 3); + assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float)); + TfLiteTensorCopyToBuffer(outputTensor, output.data(), + output.size() * sizeof(float)); + return rgbFloatToRgb8(output.data(), output.size(), false); +} + +} // namespace diff --git a/plugin/android/src/main/cpp/esrgan.h b/plugin/android/src/main/cpp/esrgan.h new file mode 100644 index 00000000..759f43a4 --- /dev/null +++ b/plugin/android/src/main/cpp/esrgan.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jbyteArray JNICALL +Java_com_nkming_nc_1photos_plugin_image_1processor_Esrgan_inferNative( + JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image, + jint width, jint height); + +#ifdef __cplusplus +} +#endif diff --git a/plugin/android/src/main/cpp/image_splitter.cpp b/plugin/android/src/main/cpp/image_splitter.cpp new file mode 100644 index 00000000..2754fc62 --- /dev/null +++ b/plugin/android/src/main/cpp/image_splitter.cpp @@ -0,0 +1,75 @@ +#include "image_splitter.h" +#include "log.h" +#include "util.h" +#include +#include +#include +#include +#include + +using namespace std; + +namespace plugin { + +ImageTile::ImageTile() : _width(0), _height(0), _channel(3) {} + +ImageTile::ImageTile(vector &&data, const size_t width, + const size_t height, const unsigned channel) + : _data(move(data)), _width(width), _height(height), _channel(channel) {} + +ImageTile &ImageTile::operator=(ImageTile &&rhs) { + if (this != &rhs) { + _data = move(rhs._data); + _width = rhs._width; + rhs._width = 0; + _height = rhs._height; + rhs._height = 0; + _channel = rhs._channel; + } + return *this; +} + +ImageSplitter::ImageSplitter(const size_t tileWidth, const size_t tileHeight, + const size_t padding) + : _tileWidth(tileWidth), _tileHeight(tileHeight), _padding(padding) {} + +deque> +ImageSplitter::operator()(const uint8_t *image, const size_t width, + const size_t height, const unsigned channel) const { + const size_t tileHoriz = ceil(static_cast(width) / _tileWidth); + const size_t tileVert = ceil(static_cast(height) / _tileHeight); + deque> product(tileVert); + for (auto &r : product) { + for (int i = 0; i < tileHoriz; ++i) { + r.emplace_back(); + } + } + LOGD(TAG, "[split] Spliting %zux%zu into %zux%zu tiles, each %zux%zu in size", + width, height, tileHoriz, tileVert, _tileWidth, _tileHeight); + + for (size_t ty = 0; ty < tileVert; ++ty) { + const size_t thisTilePaddingY = ty == 0 ? 0 : _padding; + const size_t thisTileH = + std::min(_tileHeight, height - _tileHeight * ty) + thisTilePaddingY; + for (size_t tx = 0; tx < tileHoriz; ++tx) { + const size_t thisTilePaddingX = tx == 0 ? 0 : _padding; + const size_t thisTileW = + std::min(_tileWidth, width - _tileWidth * tx) + thisTilePaddingX; + LOGI(TAG, "[split] Tile[%zu][%zu]: %zux%zu", ty, tx, thisTileW, + thisTileH); + vector tile(thisTileW * thisTileH * channel); + for (size_t dy = 0; dy < thisTileH; ++dy) { + const auto srcOffset = + ((ty * _tileHeight + dy - thisTilePaddingY) * width + + (tx * _tileWidth - thisTilePaddingX)) * + channel; + const auto dstOffset = dy * thisTileW * channel; + memcpy(tile.data() + dstOffset, image + srcOffset, thisTileW * channel); + } + product[ty][tx] = ImageTile(move(tile), thisTileW, thisTileH, channel); + } + } + return product; +} + +} // namespace plugin diff --git a/plugin/android/src/main/cpp/image_splitter.h b/plugin/android/src/main/cpp/image_splitter.h new file mode 100644 index 00000000..44e90b07 --- /dev/null +++ b/plugin/android/src/main/cpp/image_splitter.h @@ -0,0 +1,54 @@ +#include +#include +#include + +namespace plugin { + +class ImageTile { +public: + ImageTile(); + ImageTile(std::vector &&data, const size_t width, + const size_t height, const unsigned channel); + + ImageTile &operator=(const ImageTile &) = delete; + ImageTile &operator=(ImageTile &&rhs); + + const std::vector &data() const { return _data; } + size_t width() const { return _width; } + size_t height() const { return _height; } + unsigned channel() const { return _channel; } + +private: + std::vector _data; + size_t _width; + size_t _height; + unsigned _channel; +}; + +/** + * Split an image to smaller tiles with optional padding + * + * If padding > 0, each tile will have extra pixels belonging to the previous + * tiles in both axis + */ +class ImageSplitter { +public: + ImageSplitter(const size_t tileWidth, const size_t tileHeight) + : ImageSplitter(tileWidth, tileHeight, 0) {} + ImageSplitter(const size_t tileWidth, const size_t tileHeight, + const size_t padding); + + std::deque> operator()(const uint8_t *image, + const size_t width, + const size_t height, + const unsigned channel) const; + +private: + const size_t _tileWidth; + const size_t _tileHeight; + const size_t _padding; + + static constexpr const char *TAG = "ImageSplitter"; +}; + +} // namespace plugin diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt index 9eb45338..b9b7e61d 100644 --- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt @@ -49,6 +49,21 @@ class ImageProcessorChannelHandler(context: Context) : } } + "esrgan" -> { + try { + esrgan( + call.argument("fileUrl")!!, + call.argument("headers"), + call.argument("filename")!!, + call.argument("maxWidth")!!, + call.argument("maxHeight")!!, + result + ) + } catch (e: Throwable) { + result.error("systemException", e.toString(), null) + } + } + else -> result.notImplemented() } } @@ -82,6 +97,14 @@ class ImageProcessorChannelHandler(context: Context) : } ) + private fun esrgan( + fileUrl: String, headers: Map?, filename: String, + maxWidth: Int, maxHeight: Int, result: MethodChannel.Result + ) = method( + fileUrl, headers, filename, maxWidth, maxHeight, + ImageProcessorService.METHOD_ESRGAN, result + ) + private fun method( fileUrl: String, headers: Map?, filename: String, maxWidth: Int, maxHeight: Int, method: String, diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt index a1c88aff..f25cd30c 100644 --- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt @@ -18,6 +18,7 @@ import androidx.core.app.NotificationCompat import androidx.core.app.NotificationManagerCompat import androidx.exifinterface.media.ExifInterface import com.nkming.nc_photos.plugin.image_processor.DeepLab3Portrait +import com.nkming.nc_photos.plugin.image_processor.Esrgan import com.nkming.nc_photos.plugin.image_processor.ZeroDce import java.io.File import java.net.HttpURLConnection @@ -28,6 +29,7 @@ class ImageProcessorService : Service() { const val EXTRA_METHOD = "method" const val METHOD_ZERO_DCE = "zero-dce" const val METHOD_DEEP_LAP_PORTRAIT = "DeepLab3Portrait" + const val METHOD_ESRGAN = "Esrgan" const val EXTRA_FILE_URL = "fileUrl" const val EXTRA_HEADERS = "headers" const val EXTRA_FILENAME = "filename" @@ -99,6 +101,7 @@ class ImageProcessorService : Service() { METHOD_DEEP_LAP_PORTRAIT -> onDeepLapPortrait( startId, intent.extras!! ) + METHOD_ESRGAN -> onEsrgan(startId, intent.extras!!) else -> { logE(TAG, "Unknown method: $method") // we can't call stopSelf here as it'll stop the service even if @@ -126,6 +129,10 @@ class ImageProcessorService : Service() { ) } + private fun onEsrgan(startId: Int, extras: Bundle) { + return onMethod(startId, extras, METHOD_ESRGAN) + } + /** * Handle methods without arguments * @@ -449,6 +456,10 @@ private open class ImageProcessorCommandTask(context: Context) : cmd.args["radius"] as? Int ?: 16 ).infer(fileUri) + ImageProcessorService.METHOD_ESRGAN -> Esrgan( + context, cmd.maxWidth, cmd.maxHeight + ).infer(fileUri) + else -> throw IllegalArgumentException( "Unknown method: ${cmd.method}" ) diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/Esrgan.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/Esrgan.kt new file mode 100644 index 00000000..87b6b242 --- /dev/null +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/Esrgan.kt @@ -0,0 +1,39 @@ +package com.nkming.nc_photos.plugin.image_processor + +import android.content.Context +import android.content.res.AssetManager +import android.graphics.Bitmap +import android.net.Uri +import com.nkming.nc_photos.plugin.BitmapResizeMethod +import com.nkming.nc_photos.plugin.BitmapUtil + +class Esrgan(context: Context, maxWidth: Int, maxHeight: Int) { + fun infer(imageUri: Uri): Bitmap { + val width: Int + val height: Int + val rgb8Image = BitmapUtil.loadImage( + context, imageUri, maxWidth / 4, maxHeight / 4, + BitmapResizeMethod.FIT, isAllowSwapSide = true, + shouldUpscale = false + ).let { + width = it.width + height = it.height + val rgb8 = TfLiteHelper.bitmapToRgb8Array(it) + it.recycle() + rgb8 + } + val am = context.assets + + return inferNative(am, rgb8Image, width, height).let { + TfLiteHelper.rgb8ArrayToBitmap(it, width * 4, height * 4) + } + } + + private external fun inferNative( + am: AssetManager, image: ByteArray, width: Int, height: Int + ): ByteArray + + private val context = context + private val maxWidth = maxWidth + private val maxHeight = maxHeight +} diff --git a/plugin/lib/src/image_processor.dart b/plugin/lib/src/image_processor.dart index cef8afdd..cd720145 100644 --- a/plugin/lib/src/image_processor.dart +++ b/plugin/lib/src/image_processor.dart @@ -38,6 +38,21 @@ class ImageProcessor { "radius": radius, }); + static Future esrgan( + String fileUrl, + String filename, + int maxWidth, + int maxHeight, { + Map? headers, + }) => + _methodChannel.invokeMethod("esrgan", { + "fileUrl": fileUrl, + "headers": headers, + "filename": filename, + "maxWidth": maxWidth, + "maxHeight": maxHeight, + }); + static const _methodChannel = MethodChannel("${k.libId}/image_processor_method"); }