diff --git a/app/lib/help_utils.dart b/app/lib/help_utils.dart index 74034e98..0218dfcf 100644 --- a/app/lib/help_utils.dart +++ b/app/lib/help_utils.dart @@ -7,3 +7,4 @@ const enhanceUrl = "https://bit.ly/3lF5OiT"; const enhanceZeroDceUrl = "https://bit.ly/3wKJcm9"; const enhanceDeepLabPortraitBlurUrl = "https://bit.ly/3wIuXy6"; const enhanceEsrganUrl = "https://bit.ly/3wO0NJP"; +const enhanceStyleTransferUrl = "https://bit.ly/3agpTcF"; diff --git a/app/lib/l10n/app_en.arb b/app/lib/l10n/app_en.arb index 34e608f7..2b804626 100644 --- a/app/lib/l10n/app_en.arb +++ b/app/lib/l10n/app_en.arb @@ -1213,6 +1213,10 @@ "@enhanceSuperResolution4xTitle": { "description": "Upscale an image. The algorithm implemented in the app will upscale to 4x the original resolution (eg, 100x100 to 400x400)" }, + "enhanceStyleTransferTitle": "Style transfer", + "@enhanceStyleTransferTitle": { + "description": "Transfer the image style from a reference image to a photo" + }, "errorUnauthenticated": "Unauthenticated access. Please sign-in again if the problem continues", "@errorUnauthenticated": { diff --git a/app/lib/l10n/untranslated-messages.txt b/app/lib/l10n/untranslated-messages.txt index 4559e80f..c65b6556 100644 --- a/app/lib/l10n/untranslated-messages.txt +++ b/app/lib/l10n/untranslated-messages.txt @@ -99,6 +99,7 @@ "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle", "errorAlbumDowngrade" ], @@ -216,6 +217,7 @@ "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle", "errorAlbumDowngrade" ], @@ -388,6 +390,7 @@ "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle", "errorAlbumDowngrade" ], @@ -410,11 +413,13 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "fi": [ - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "fr": [ @@ -436,7 +441,8 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "pl": [ @@ -475,7 +481,8 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "pt": [ @@ -493,7 +500,8 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "ru": [ @@ -511,7 +519,8 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "zh": [ @@ -529,7 +538,8 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ], "zh_Hant": [ @@ -547,6 +557,7 @@ "deletePermanentlyLocalConfirmationDialogContent", "enhancePortraitBlurTitle", "enhancePortraitBlurParamBlurLabel", - "enhanceSuperResolution4xTitle" + "enhanceSuperResolution4xTitle", + "enhanceStyleTransferTitle" ] } diff --git a/app/lib/mobile/android/k.dart b/app/lib/mobile/android/k.dart index bc96c556..e8cafa51 100644 --- a/app/lib/mobile/android/k.dart +++ b/app/lib/mobile/android/k.dart @@ -1,3 +1,5 @@ +// ignore_for_file: constant_identifier_names + /// Standard activity result: operation canceled. const resultCanceled = 0; @@ -6,3 +8,7 @@ const resultOk = -1; /// Start of user-defined activity results. const resultFirstUser = 1; + +const ACTION_GET_CONTENT = "android.intent.action.GET_CONTENT"; +const CATEGORY_OPENABLE = "android.intent.category.OPENABLE"; +const EXTRA_LOCAL_ONLY = "android.intent.extra.LOCAL_ONLY"; diff --git a/app/lib/widget/handler/enhance_handler.dart b/app/lib/widget/handler/enhance_handler.dart index 05e493ff..bf804204 100644 --- a/app/lib/widget/handler/enhance_handler.dart +++ b/app/lib/widget/handler/enhance_handler.dart @@ -1,3 +1,6 @@ +import 'dart:math' as math; + +import 'package:android_intent_plus/android_intent.dart'; import 'package:flutter/material.dart'; import 'package:logging/logging.dart'; import 'package:nc_photos/account.dart'; @@ -6,14 +9,18 @@ import 'package:nc_photos/app_localizations.dart'; import 'package:nc_photos/entity/file.dart'; import 'package:nc_photos/entity/file_util.dart' as file_util; import 'package:nc_photos/help_utils.dart'; +import 'package:nc_photos/iterable_extension.dart'; import 'package:nc_photos/k.dart' as k; import 'package:nc_photos/mobile/android/android_info.dart'; +import 'package:nc_photos/mobile/android/content_uri_image_provider.dart'; +import 'package:nc_photos/mobile/android/k.dart' as android; import 'package:nc_photos/mobile/android/permission_util.dart'; import 'package:nc_photos/object_extension.dart'; import 'package:nc_photos/platform/k.dart' as platform_k; import 'package:nc_photos/pref.dart'; import 'package:nc_photos/snack_bar_manager.dart'; import 'package:nc_photos/theme.dart'; +import 'package:nc_photos/widget/selectable.dart'; import 'package:nc_photos/widget/settings.dart'; import 'package:nc_photos/widget/stateful_slider.dart'; import 'package:nc_photos_plugin/nc_photos_plugin.dart'; @@ -86,6 +93,22 @@ class EnhanceHandler { }, ); break; + + case _Algorithm.arbitraryStyleTransfer: + await ImageProcessor.arbitraryStyleTransfer( + "${account.url}/${file.path}", + file.filename, + math.min( + Pref().getEnhanceMaxWidthOr(), isAtLeast5GbRam() ? 1600 : 1280), + math.min( + Pref().getEnhanceMaxHeightOr(), isAtLeast5GbRam() ? 1200 : 960), + args["styleUri"], + args["weight"], + headers: { + "Authorization": Api.getAuthorizationHeaderValue(account), + }, + ); + break; } } @@ -194,6 +217,12 @@ class EnhanceHandler { link: enhanceEsrganUrl, algorithm: _Algorithm.esrgan, ), + if (platform_k.isAndroid && isAtLeast4GbRam()) + _Option( + title: L10n.global().enhanceStyleTransferTitle, + link: enhanceStyleTransferUrl, + algorithm: _Algorithm.arbitraryStyleTransfer, + ), ]; Future?> _getArgs( @@ -207,6 +236,9 @@ class EnhanceHandler { case _Algorithm.esrgan: return {}; + + case _Algorithm.arbitraryStyleTransfer: + return _getArbitraryStyleTransferArgs(context); } } @@ -314,6 +346,32 @@ class EnhanceHandler { return radius?.run((it) => {"radius": it}); } + Future?> _getArbitraryStyleTransferArgs( + BuildContext context) async { + final result = await showDialog<_StylePickerResult>( + context: context, + builder: (_) => const _StylePicker(), + ); + if (result == null) { + // user canceled + return null; + } else { + return { + "styleUri": result.styleUri, + "weight": result.weight, + }; + } + } + + bool isAtLeast4GbRam() { + // We can't compare with 4096 directly as some RAM are preserved + return AndroidInfo().totalMemMb > 3584; + } + + bool isAtLeast5GbRam() { + return AndroidInfo().totalMemMb > 4608; + } + final Account account; final File file; @@ -324,6 +382,7 @@ enum _Algorithm { zeroDce, deepLab3Portrait, esrgan, + arbitraryStyleTransfer, } class _Option { @@ -339,3 +398,186 @@ class _Option { final String? link; final _Algorithm algorithm; } + +class _StylePickerResult { + const _StylePickerResult(this.styleUri, this.weight); + + final String styleUri; + final double weight; +} + +class _StylePicker extends StatefulWidget { + const _StylePicker({ + Key? key, + }) : super(key: key); + + @override + createState() => _StylePickerState(); +} + +class _StylePickerState extends State<_StylePicker> { + @override + build(BuildContext context) { + return AppTheme( + child: AlertDialog( + title: Text("Pick a style"), + contentPadding: const EdgeInsets.fromLTRB(24.0, 20.0, 24.0, 0), + content: Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + if (_selected != null) ...[ + Align( + alignment: Alignment.center, + child: SizedBox( + width: 128, + height: 128, + child: Image( + image: ResizeImage.resizeIfNeeded( + 128, null, ContentUriImage(_getSelectedUri())), + fit: BoxFit.cover, + ), + ), + ), + const SizedBox(height: 16), + ], + Wrap( + runSpacing: 8, + spacing: 8, + children: [ + ..._bundledStyles.mapWithIndex((i, e) => _buildItem( + i, + Image( + image: ResizeImage.resizeIfNeeded( + _thumbSize, null, ContentUriImage(e)), + fit: BoxFit.cover, + ), + )), + if (_customUri != null) + _buildItem( + _bundledStyles.length, + Image( + image: ResizeImage.resizeIfNeeded( + _thumbSize, null, ContentUriImage(_customUri!)), + fit: BoxFit.cover, + ), + ), + InkWell( + onTap: _onCustomTap, + child: SizedBox( + width: _thumbSize.toDouble(), + height: _thumbSize.toDouble(), + child: const Icon( + Icons.file_open_outlined, + size: 24, + ), + ), + ), + ], + ), + const SizedBox(height: 16), + Row( + mainAxisSize: MainAxisSize.max, + children: [ + Icon( + Icons.auto_fix_normal, + color: AppTheme.getSecondaryTextColor(context), + ), + Expanded( + child: StatefulSlider( + initialValue: _weight, + min: .01, + onChangeEnd: (value) { + _weight = value; + }, + ), + ), + Icon( + Icons.auto_fix_high, + color: AppTheme.getSecondaryTextColor(context), + ), + ], + ), + ], + ), + actions: [ + TextButton( + onPressed: () { + if (_selected == null) { + SnackBarManager().showSnackBar(const SnackBar( + content: Text("Please pick a style"), + duration: k.snackBarDurationNormal, + )); + } else { + final result = _StylePickerResult(_getSelectedUri(), _weight); + Navigator.of(context).pop(result); + } + }, + child: Text(L10n.global().enhanceButtonLabel), + ), + ], + ), + ); + } + + Widget _buildItem(int index, Widget child) { + return SizedBox( + width: _thumbSize.toDouble(), + height: _thumbSize.toDouble(), + child: Selectable( + isSelected: _selected == index, + iconSize: 24, + child: child, + onTap: () { + setState(() { + _selected = index; + }); + }, + ), + ); + } + + Future _onCustomTap() async { + const intent = AndroidIntent( + action: android.ACTION_GET_CONTENT, + type: "image/*", + category: android.CATEGORY_OPENABLE, + arguments: { + android.EXTRA_LOCAL_ONLY: true, + }, + ); + final result = await intent.launchForResult(); + _log.info("[onCustomTap] Intent result: $result"); + if (result?.resultCode == android.resultOk) { + if (mounted) { + setState(() { + _customUri = result!.data; + _selected = _bundledStyles.length; + }); + } + } + } + + String _getSelectedUri() { + return _selected! < _bundledStyles.length + ? _bundledStyles[_selected!] + : _customUri!; + } + + int? _selected; + String? _customUri; + double _weight = .85; + + static const _thumbSize = 56; + static const _bundledStyles = [ + "file:///android_asset/tf/arbitrary-style-transfer/1.jpg", + "file:///android_asset/tf/arbitrary-style-transfer/2.jpg", + "file:///android_asset/tf/arbitrary-style-transfer/3.jpg", + "file:///android_asset/tf/arbitrary-style-transfer/4.jpg", + "file:///android_asset/tf/arbitrary-style-transfer/5.jpg", + "file:///android_asset/tf/arbitrary-style-transfer/6.jpg", + ]; + + static final _log = + Logger("widget.handler.enhance_handler._StylePickerState"); +} diff --git a/app/pubspec.lock b/app/pubspec.lock index 796740c8..0ac69c12 100644 --- a/app/pubspec.lock +++ b/app/pubspec.lock @@ -18,10 +18,12 @@ packages: android_intent_plus: dependency: "direct main" description: - name: android_intent_plus - url: "https://pub.dartlang.org" - source: hosted - version: "3.1.0" + path: "packages/android_intent_plus" + ref: "android_intent_plus-v3.1.1-nc-photos-1" + resolved-ref: f257f10641e907a94b83ff0e060e80590bf1dae5 + url: "https://gitlab.com/nc-photos/plus_plugins" + source: git + version: "3.1.1" args: dependency: transitive description: diff --git a/app/pubspec.yaml b/app/pubspec.yaml index c896acf9..5a550631 100644 --- a/app/pubspec.yaml +++ b/app/pubspec.yaml @@ -28,7 +28,11 @@ dependencies: sdk: flutter # android only - android_intent_plus: ^3.0.1 + android_intent_plus: + git: + url: https://gitlab.com/nc-photos/plus_plugins + ref: android_intent_plus-v3.1.1-nc-photos-1 + path: packages/android_intent_plus battery_plus: ^2.1.3 bloc: ^7.0.0 cached_network_image: ^3.0.0 diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_predict_1.tflite b/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_predict_1.tflite new file mode 100644 index 00000000..fbc1465f Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_predict_1.tflite differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_transfer_1.tflite b/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_transfer_1.tflite new file mode 100644 index 00000000..5ad99118 Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer-inceptionv3_dr_transfer_1.tflite differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/1.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/1.jpg new file mode 100644 index 00000000..84b53364 Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/1.jpg differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/2.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/2.jpg new file mode 100644 index 00000000..663616ef Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/2.jpg differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/3.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/3.jpg new file mode 100644 index 00000000..d282b053 Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/3.jpg differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/4.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/4.jpg new file mode 100644 index 00000000..e4394dc6 Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/4.jpg differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/5.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/5.jpg new file mode 100644 index 00000000..cea68163 Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/5.jpg differ diff --git a/plugin/android/src/main/assets/tf/arbitrary-style-transfer/6.jpg b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/6.jpg new file mode 100644 index 00000000..bc07be4f Binary files /dev/null and b/plugin/android/src/main/assets/tf/arbitrary-style-transfer/6.jpg differ diff --git a/plugin/android/src/main/cpp/CMakeLists.txt b/plugin/android/src/main/cpp/CMakeLists.txt index 32673387..6964cac1 100644 --- a/plugin/android/src/main/cpp/CMakeLists.txt +++ b/plugin/android/src/main/cpp/CMakeLists.txt @@ -33,6 +33,7 @@ add_library( # Sets the name of the library. SHARED # Provides a relative path to your source file(s). + arbitrary_style_transfer.cpp deep_lap_3.cpp esrgan.cpp exception.cpp diff --git a/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp b/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp new file mode 100644 index 00000000..059e53cf --- /dev/null +++ b/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp @@ -0,0 +1,223 @@ +#include "base_resample.h" +#include "exception.h" +#include "log.h" +#include "stopwatch.h" +#include "tflite_wrapper.h" +#include "util.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace plugin; +using namespace std; +using namespace tflite; + +namespace { + +constexpr const char *PREDICT_MODEL = + "tf/arbitrary-style-transfer-inceptionv3_dr_predict_1.tflite"; +constexpr const char *TRANSFER_MODEL = + "tf/arbitrary-style-transfer-inceptionv3_dr_transfer_1.tflite"; + +class ArbitraryStyleTransfer { +public: + explicit ArbitraryStyleTransfer(AAssetManager *const aam); + + std::vector infer(const uint8_t *image, const size_t width, + const size_t height, const uint8_t *style, + const float weight); + +private: + std::vector predict(const uint8_t *image, const size_t width, + const size_t height, const uint8_t *style, + const float weight); + std::vector transfer(const uint8_t *image, const size_t width, + const size_t height, + const std::vector &bottleneck); + + /** + * @param style The style image MUST be 256*256 + */ + std::vector predictStyle(const uint8_t *style); + + std::vector blendBottleneck(const std::vector &style, + const std::vector &image, + const float styleWeight); + + Model predictModel; + Model transferModel; + + static constexpr const char *TAG = "ArbitraryStyleTransfer"; +}; + +} // namespace + +extern "C" JNIEXPORT jbyteArray JNICALL +Java_com_nkming_nc_1photos_plugin_image_1processor_ArbitraryStyleTransfer_inferNative( + JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image, + jint width, jint height, jbyteArray style, jfloat weight) { + try { + initOpenMp(); + auto aam = AAssetManager_fromJava(env, assetManager); + ArbitraryStyleTransfer model(aam); + RaiiContainer cImage( + [&]() { return env->GetByteArrayElements(image, nullptr); }, + [&](jbyte *obj) { + env->ReleaseByteArrayElements(image, obj, JNI_ABORT); + }); + RaiiContainer cStyle( + [&]() { return env->GetByteArrayElements(style, nullptr); }, + [&](jbyte *obj) { + env->ReleaseByteArrayElements(style, obj, JNI_ABORT); + }); + const auto result = + model.infer(reinterpret_cast(cImage.get()), width, height, + reinterpret_cast(cStyle.get()), weight); + auto resultAry = env->NewByteArray(result.size()); + env->SetByteArrayRegion(resultAry, 0, result.size(), + reinterpret_cast(result.data())); + return resultAry; + } catch (const exception &e) { + throwJavaException(env, e.what()); + return nullptr; + } +} + +namespace { + +ArbitraryStyleTransfer::ArbitraryStyleTransfer(AAssetManager *const aam) + : predictModel(Asset(aam, PREDICT_MODEL)), + transferModel(Asset(aam, TRANSFER_MODEL)) {} + +vector ArbitraryStyleTransfer::infer(const uint8_t *image, + const size_t width, + const size_t height, + const uint8_t *style, + const float weight) { + const auto bottleneck = predict(image, width, height, style, weight); + return transfer(image, width, height, bottleneck); +} + +vector ArbitraryStyleTransfer::predict(const uint8_t *image, + const size_t width, + const size_t height, + const uint8_t *style, + const float weight) { + auto style_bottleneck = predictStyle(style); + vector imageStyleBitmap(256 * 256 * 3); + base::ResampleImage24(image, width, height, imageStyleBitmap.data(), 256, 256, + base::KernelTypeLanczos3); + auto image_bottleneck = predictStyle(imageStyleBitmap.data()); + return blendBottleneck(style_bottleneck, image_bottleneck, weight); +} + +vector +ArbitraryStyleTransfer::transfer(const uint8_t *image, const size_t width, + const size_t height, + const vector &bottleneck) { + vector resizedImage; + auto inputWidth = width; + auto inputHeight = height; + const uint8_t *inputImage = image; + if (width % 4 != 0 || height % 4 != 0) { + LOGI(TAG, "[transfer] Resize bitmap to multiple of 4"); + inputWidth = width - width % 4; + inputHeight = height - height % 4; + resizedImage.resize(inputWidth * inputHeight * 3); + base::ResampleImage24(image, width, height, resizedImage.data(), inputWidth, + inputHeight, base::KernelTypeLanczos3); + inputImage = resizedImage.data(); + } + + InterpreterOptions options; + options.setNumThreads(getNumberOfProcessors()); + Interpreter interpreter(transferModel, options); + const int dims[] = {1, static_cast(inputHeight), + static_cast(inputWidth), 3}; + interpreter.resizeInputTensor(1, dims, 4); + interpreter.allocateTensors(); + + LOGI(TAG, "[transfer] Copy bias"); + auto inputTensor0 = interpreter.getInputTensor(0); + assert(TfLiteTensorByteSize(inputTensor0) == + bottleneck.size() * sizeof(float)); + TfLiteTensorCopyFromBuffer(inputTensor0, bottleneck.data(), + bottleneck.size() * sizeof(float)); + + LOGI(TAG, "[transfer] Convert bitmap to input"); + auto input = rgb8ToRgbFloat(inputImage, inputWidth * inputHeight * 3, true); + auto inputTensor1 = interpreter.getInputTensor(1); + assert(TfLiteTensorByteSize(inputTensor1) == input.size() * sizeof(float)); + TfLiteTensorCopyFromBuffer(inputTensor1, input.data(), + input.size() * sizeof(float)); + input.clear(); + + LOGI(TAG, "[transfer] Inferring"); + Stopwatch stopwatch; + interpreter.invoke(); + LOGI(TAG, "[transfer] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f); + + auto outputTensor = interpreter.getOutputTensor(0); + vector output(inputWidth * inputHeight * 3); + assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float)); + TfLiteTensorCopyToBuffer(outputTensor, output.data(), + output.size() * sizeof(float)); + auto outputRgb8 = rgbFloatToRgb8(output.data(), output.size(), true); + output.clear(); + if (!resizedImage.empty()) { + // resize it back to the original resolution + vector temp(width * height * 3); + base::ResampleImage24(outputRgb8.data(), inputWidth, inputHeight, + temp.data(), width, height, base::KernelTypeBicubic); + return temp; + } else { + return outputRgb8; + } +} + +vector ArbitraryStyleTransfer::predictStyle(const uint8_t *style) { + InterpreterOptions options; + options.setNumThreads(getNumberOfProcessors()); + Interpreter interpreter(predictModel, options); + interpreter.allocateTensors(); + + LOGI(TAG, "[predictStyle] Convert bitmap to input"); + const auto input = rgb8ToRgbFloat(style, 256 * 256 * 3, true); + auto inputTensor = interpreter.getInputTensor(0); + assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float)); + TfLiteTensorCopyFromBuffer(inputTensor, input.data(), + input.size() * sizeof(float)); + + LOGI(TAG, "[predictStyle] Inferring"); + Stopwatch stopwatch; + interpreter.invoke(); + LOGI(TAG, "[predictStyle] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f); + + auto outputTensor = interpreter.getOutputTensor(0); + vector output(100); + assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float)); + TfLiteTensorCopyToBuffer(outputTensor, output.data(), + output.size() * sizeof(float)); + return output; +} + +vector +ArbitraryStyleTransfer::blendBottleneck(const vector &style, + const vector &image, + const float styleWeight) { + assert(style.size() == 100); + assert(image.size() == 100); + vector product(100); + for (int i = 0; i < 100; ++i) { + product[i] = styleWeight * style[i] + (1 - styleWeight) * image[i]; + } + return product; +} + +} // namespace diff --git a/plugin/android/src/main/cpp/arbitrary_style_transfer.h b/plugin/android/src/main/cpp/arbitrary_style_transfer.h new file mode 100644 index 00000000..0b1d1457 --- /dev/null +++ b/plugin/android/src/main/cpp/arbitrary_style_transfer.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +JNIEXPORT jbyteArray JNICALL +Java_com_nkming_nc_1photos_plugin_image_1processor_ArbitraryStyleTransfer_inferNative( + JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image, + jint width, jint height, jbyteArray style, jfloat weight); + +#ifdef __cplusplus +} +#endif diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/BitmapUtil.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/BitmapUtil.kt index 9edfddf1..65de88ed 100644 --- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/BitmapUtil.kt +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/BitmapUtil.kt @@ -5,9 +5,30 @@ import android.graphics.Bitmap import android.graphics.BitmapFactory import android.net.Uri import java.io.InputStream +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract fun Bitmap.aspectRatio() = width / height.toFloat() +/** + * Recycle the bitmap after @c block returns. + * + * @param block + * @return + */ +@OptIn(ExperimentalContracts::class) +inline fun Bitmap.use(block: (Bitmap) -> T): T { + contract { + callsInPlace(block, InvocationKind.EXACTLY_ONCE) + } + try { + return block(this) + } finally { + recycle() + } +} + enum class BitmapResizeMethod { FIT, FILL, diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt index b9b7e61d..6022e9ea 100644 --- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorChannelHandler.kt @@ -2,6 +2,7 @@ package com.nkming.nc_photos.plugin import android.content.Context import android.content.Intent +import android.net.Uri import androidx.core.content.ContextCompat import io.flutter.plugin.common.EventChannel import io.flutter.plugin.common.MethodCall @@ -64,6 +65,23 @@ class ImageProcessorChannelHandler(context: Context) : } } + "arbitraryStyleTransfer" -> { + try { + arbitraryStyleTransfer( + call.argument("fileUrl")!!, + call.argument("headers"), + call.argument("filename")!!, + call.argument("maxWidth")!!, + call.argument("maxHeight")!!, + call.argument("styleUri")!!, + call.argument("weight")!!, + result + ) + } catch (e: Throwable) { + result.error("systemException", e.toString(), null) + } + } + else -> result.notImplemented() } } @@ -105,6 +123,21 @@ class ImageProcessorChannelHandler(context: Context) : ImageProcessorService.METHOD_ESRGAN, result ) + private fun arbitraryStyleTransfer( + fileUrl: String, headers: Map?, filename: String, + maxWidth: Int, maxHeight: Int, styleUri: String, weight: Float, + result: MethodChannel.Result + ) = method( + fileUrl, headers, filename, maxWidth, maxHeight, + ImageProcessorService.METHOD_ARBITRARY_STYLE_TRANSFER, result, + onIntent = { + it.putExtra( + ImageProcessorService.EXTRA_STYLE_URI, Uri.parse(styleUri) + ) + it.putExtra(ImageProcessorService.EXTRA_WEIGHT, weight) + } + ) + private fun method( fileUrl: String, headers: Map?, filename: String, maxWidth: Int, maxHeight: Int, method: String, diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt index 54c432ae..ae8462e7 100644 --- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/ImageProcessorService.kt @@ -17,6 +17,7 @@ import androidx.core.app.NotificationChannelCompat import androidx.core.app.NotificationCompat import androidx.core.app.NotificationManagerCompat import androidx.exifinterface.media.ExifInterface +import com.nkming.nc_photos.plugin.image_processor.ArbitraryStyleTransfer import com.nkming.nc_photos.plugin.image_processor.DeepLab3Portrait import com.nkming.nc_photos.plugin.image_processor.Esrgan import com.nkming.nc_photos.plugin.image_processor.ZeroDce @@ -30,6 +31,7 @@ class ImageProcessorService : Service() { const val METHOD_ZERO_DCE = "zero-dce" const val METHOD_DEEP_LAP_PORTRAIT = "DeepLab3Portrait" const val METHOD_ESRGAN = "Esrgan" + const val METHOD_ARBITRARY_STYLE_TRANSFER = "ArbitraryStyleTransfer" const val EXTRA_FILE_URL = "fileUrl" const val EXTRA_HEADERS = "headers" const val EXTRA_FILENAME = "filename" @@ -37,6 +39,8 @@ class ImageProcessorService : Service() { const val EXTRA_MAX_HEIGHT = "maxHeight" const val EXTRA_RADIUS = "radius" const val EXTRA_ITERATION = "iteration" + const val EXTRA_STYLE_URI = "styleUri" + const val EXTRA_WEIGHT = "weight" private const val ACTION_CANCEL = "cancel" @@ -109,6 +113,9 @@ class ImageProcessorService : Service() { startId, intent.extras!! ) METHOD_ESRGAN -> onEsrgan(startId, intent.extras!!) + METHOD_ARBITRARY_STYLE_TRANSFER -> onArbitraryStyleTransfer( + startId, intent.extras!! + ) else -> { logE(TAG, "Unknown method: $method") // we can't call stopSelf here as it'll stop the service even if @@ -142,6 +149,16 @@ class ImageProcessorService : Service() { return onMethod(startId, extras, METHOD_ESRGAN) } + private fun onArbitraryStyleTransfer(startId: Int, extras: Bundle) { + return onMethod( + startId, extras, METHOD_ARBITRARY_STYLE_TRANSFER, + args = mapOf( + "styleUri" to extras.getParcelable(EXTRA_STYLE_URI), + "weight" to extras.getFloat(EXTRA_WEIGHT), + ) + ) + } + /** * Handle methods without arguments * @@ -530,6 +547,12 @@ private open class ImageProcessorCommandTask(context: Context) : context, cmd.maxWidth, cmd.maxHeight ).infer(fileUri) + ImageProcessorService.METHOD_ARBITRARY_STYLE_TRANSFER -> ArbitraryStyleTransfer( + context, cmd.maxWidth, cmd.maxHeight, + cmd.args["styleUri"] as Uri, + cmd.args["weight"] as Float + ).infer(fileUri) + else -> throw IllegalArgumentException( "Unknown method: ${cmd.method}" ) diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ArbitraryStyleTransfer.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ArbitraryStyleTransfer.kt new file mode 100644 index 00000000..4fd8b545 --- /dev/null +++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ArbitraryStyleTransfer.kt @@ -0,0 +1,70 @@ +package com.nkming.nc_photos.plugin.image_processor + +import android.content.Context +import android.content.res.AssetManager +import android.graphics.Bitmap +import android.net.Uri +import com.nkming.nc_photos.plugin.BitmapResizeMethod +import com.nkming.nc_photos.plugin.BitmapUtil +import com.nkming.nc_photos.plugin.logI +import com.nkming.nc_photos.plugin.use + +class ArbitraryStyleTransfer( + context: Context, maxWidth: Int, maxHeight: Int, styleUri: Uri, + weight: Float +) { + companion object { + const val TAG = "ArbitraryStyleTransfer" + } + + fun infer(imageUri: Uri): Bitmap { + val width: Int + val height: Int + val rgb8Image = BitmapUtil.loadImage( + context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT, + isAllowSwapSide = true, shouldUpscale = false + ).use { + width = it.width + height = it.height + TfLiteHelper.bitmapToRgb8Array(it) + } + val rgb8Style = BitmapUtil.loadImage( + context, styleUri, 256, 256, BitmapResizeMethod.FILL, + isAllowSwapSide = false, shouldUpscale = true + ).use { + val styleBitmap = if (it.width != 256 || it.height != 256) { + val x = (it.width - 256) / 2 + val y = (it.height - 256) / 2 + logI( + TAG, + "[infer] Resize and crop style image: ${it.width}x${it.height} -> 256x256 ($x, $y)" + ) + // crop + Bitmap.createBitmap(it, x, y, 256, 256) + } else { + it + } + styleBitmap.use { + TfLiteHelper.bitmapToRgb8Array(styleBitmap) + } + } + val am = context.assets + + return inferNative( + am, rgb8Image, width, height, rgb8Style, weight + ).let { + TfLiteHelper.rgb8ArrayToBitmap(it, width, height) + } + } + + private external fun inferNative( + am: AssetManager, image: ByteArray, width: Int, height: Int, + style: ByteArray, weight: Float + ): ByteArray + + private val context = context + private val maxWidth = maxWidth + private val maxHeight = maxHeight + private val styleUri = styleUri + private val weight = weight +} diff --git a/plugin/lib/src/image_processor.dart b/plugin/lib/src/image_processor.dart index cd720145..1a61392a 100644 --- a/plugin/lib/src/image_processor.dart +++ b/plugin/lib/src/image_processor.dart @@ -53,6 +53,25 @@ class ImageProcessor { "maxHeight": maxHeight, }); + static Future arbitraryStyleTransfer( + String fileUrl, + String filename, + int maxWidth, + int maxHeight, + String styleUri, + double weight, { + Map? headers, + }) => + _methodChannel.invokeMethod("arbitraryStyleTransfer", { + "fileUrl": fileUrl, + "headers": headers, + "filename": filename, + "maxWidth": maxWidth, + "maxHeight": maxHeight, + "styleUri": styleUri, + "weight": weight, + }); + static const _methodChannel = MethodChannel("${k.libId}/image_processor_method"); }