Merge branch 'photo-enhancement-style-transfer' into dev

This commit is contained in:
Ming Ming 2022-05-25 21:43:43 +08:00
commit 3caa2ceecb
26 changed files with 760 additions and 17 deletions

View file

@ -7,3 +7,4 @@ const enhanceUrl = "https://bit.ly/3lF5OiT";
const enhanceZeroDceUrl = "https://bit.ly/3wKJcm9";
const enhanceDeepLabPortraitBlurUrl = "https://bit.ly/3wIuXy6";
const enhanceEsrganUrl = "https://bit.ly/3wO0NJP";
const enhanceStyleTransferUrl = "https://bit.ly/3agpTcF";

View file

@ -1213,6 +1213,10 @@
"@enhanceSuperResolution4xTitle": {
"description": "Upscale an image. The algorithm implemented in the app will upscale to 4x the original resolution (eg, 100x100 to 400x400)"
},
"enhanceStyleTransferTitle": "Style transfer",
"@enhanceStyleTransferTitle": {
"description": "Transfer the image style from a reference image to a photo"
},
"errorUnauthenticated": "Unauthenticated access. Please sign-in again if the problem continues",
"@errorUnauthenticated": {

View file

@ -99,6 +99,7 @@
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle",
"errorAlbumDowngrade"
],
@ -216,6 +217,7 @@
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle",
"errorAlbumDowngrade"
],
@ -388,6 +390,7 @@
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle",
"errorAlbumDowngrade"
],
@ -410,11 +413,13 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"fi": [
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"fr": [
@ -436,7 +441,8 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"pl": [
@ -475,7 +481,8 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"pt": [
@ -493,7 +500,8 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"ru": [
@ -511,7 +519,8 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"zh": [
@ -529,7 +538,8 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
],
"zh_Hant": [
@ -547,6 +557,7 @@
"deletePermanentlyLocalConfirmationDialogContent",
"enhancePortraitBlurTitle",
"enhancePortraitBlurParamBlurLabel",
"enhanceSuperResolution4xTitle"
"enhanceSuperResolution4xTitle",
"enhanceStyleTransferTitle"
]
}

View file

@ -1,4 +1,7 @@
import 'package:device_info_plus/device_info_plus.dart';
import 'package:logging/logging.dart';
import 'package:memory_info/memory_info.dart';
import 'package:nc_photos/double_extension.dart';
/// System info for Android
///
@ -7,23 +10,38 @@ import 'package:device_info_plus/device_info_plus.dart';
class AndroidInfo {
factory AndroidInfo() => _inst;
AndroidInfo._({
const AndroidInfo._({
required this.sdkInt,
required this.totalMemMb,
});
static Future<void> init() async {
final info = await DeviceInfoPlugin().androidInfo;
final sdkInt = info.version.sdkInt!;
final memInfo = await MemoryInfoPlugin().memoryInfo;
final totalMemMb = memInfo.totalMem!.toDouble();
_inst = AndroidInfo._(
sdkInt: sdkInt,
totalMemMb: totalMemMb,
);
_log.info("[init] $_inst");
}
@override
toString() => "$runtimeType {"
"sdkInt: $sdkInt, "
"totalMemMb: ${totalMemMb.toStringAsFixedTruncated(2)}, "
"}";
static late final AndroidInfo _inst;
/// Corresponding to Build.VERSION.SDK_INT
final int sdkInt;
final double totalMemMb;
static final _log = Logger("mobile.android.android_info.AndroidInfo");
}
abstract class AndroidVersion {

View file

@ -1,3 +1,5 @@
// ignore_for_file: constant_identifier_names
/// Standard activity result: operation canceled.
const resultCanceled = 0;
@ -6,3 +8,7 @@ const resultOk = -1;
/// Start of user-defined activity results.
const resultFirstUser = 1;
const ACTION_GET_CONTENT = "android.intent.action.GET_CONTENT";
const CATEGORY_OPENABLE = "android.intent.category.OPENABLE";
const EXTRA_LOCAL_ONLY = "android.intent.extra.LOCAL_ONLY";

View file

@ -1,3 +1,6 @@
import 'dart:math' as math;
import 'package:android_intent_plus/android_intent.dart';
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import 'package:nc_photos/account.dart';
@ -6,14 +9,18 @@ import 'package:nc_photos/app_localizations.dart';
import 'package:nc_photos/entity/file.dart';
import 'package:nc_photos/entity/file_util.dart' as file_util;
import 'package:nc_photos/help_utils.dart';
import 'package:nc_photos/iterable_extension.dart';
import 'package:nc_photos/k.dart' as k;
import 'package:nc_photos/mobile/android/android_info.dart';
import 'package:nc_photos/mobile/android/content_uri_image_provider.dart';
import 'package:nc_photos/mobile/android/k.dart' as android;
import 'package:nc_photos/mobile/android/permission_util.dart';
import 'package:nc_photos/object_extension.dart';
import 'package:nc_photos/platform/k.dart' as platform_k;
import 'package:nc_photos/pref.dart';
import 'package:nc_photos/snack_bar_manager.dart';
import 'package:nc_photos/theme.dart';
import 'package:nc_photos/widget/selectable.dart';
import 'package:nc_photos/widget/settings.dart';
import 'package:nc_photos/widget/stateful_slider.dart';
import 'package:nc_photos_plugin/nc_photos_plugin.dart';
@ -86,6 +93,22 @@ class EnhanceHandler {
},
);
break;
case _Algorithm.arbitraryStyleTransfer:
await ImageProcessor.arbitraryStyleTransfer(
"${account.url}/${file.path}",
file.filename,
math.min(
Pref().getEnhanceMaxWidthOr(), isAtLeast5GbRam() ? 1600 : 1280),
math.min(
Pref().getEnhanceMaxHeightOr(), isAtLeast5GbRam() ? 1200 : 960),
args["styleUri"],
args["weight"],
headers: {
"Authorization": Api.getAuthorizationHeaderValue(account),
},
);
break;
}
}
@ -194,6 +217,12 @@ class EnhanceHandler {
link: enhanceEsrganUrl,
algorithm: _Algorithm.esrgan,
),
if (platform_k.isAndroid && isAtLeast4GbRam())
_Option(
title: L10n.global().enhanceStyleTransferTitle,
link: enhanceStyleTransferUrl,
algorithm: _Algorithm.arbitraryStyleTransfer,
),
];
Future<Map<String, dynamic>?> _getArgs(
@ -207,6 +236,9 @@ class EnhanceHandler {
case _Algorithm.esrgan:
return {};
case _Algorithm.arbitraryStyleTransfer:
return _getArbitraryStyleTransferArgs(context);
}
}
@ -314,6 +346,32 @@ class EnhanceHandler {
return radius?.run((it) => {"radius": it});
}
/// Ask the user for the arbitrary style transfer arguments.
///
/// Shows the style picker dialog and converts its result into the argument
/// map consumed by the enhance call. Returns null when the user dismisses
/// the dialog without confirming.
Future<Map<String, dynamic>?> _getArbitraryStyleTransferArgs(
    BuildContext context) async {
  final picked = await showDialog<_StylePickerResult>(
    context: context,
    builder: (_) => const _StylePicker(),
  );
  // a null dialog result means the user canceled
  return picked == null
      ? null
      : <String, dynamic>{
          "styleUri": picked.styleUri,
          "weight": picked.weight,
        };
}
/// Whether this device has (roughly) at least 4GB of RAM.
bool isAtLeast4GbRam() {
  // We can't compare with 4096 directly as some RAM are preserved
  return AndroidInfo().totalMemMb > 3584;
}

/// Whether this device has (roughly) at least 5GB of RAM.
///
/// Threshold is below 5120 for the same reason as [isAtLeast4GbRam]: part
/// of the physical RAM is reserved and not reported.
bool isAtLeast5GbRam() {
  return AndroidInfo().totalMemMb > 4608;
}
final Account account;
final File file;
@ -324,6 +382,7 @@ enum _Algorithm {
zeroDce,
deepLab3Portrait,
esrgan,
arbitraryStyleTransfer,
}
class _Option {
@ -339,3 +398,186 @@ class _Option {
final String? link;
final _Algorithm algorithm;
}
/// Value returned by [_StylePicker] when the user confirms a choice.
class _StylePickerResult {
  const _StylePickerResult(this.styleUri, this.weight);

  // URI of the chosen style image: either one of the bundled asset file
  // URIs or a content URI picked by the user
  final String styleUri;
  // style blending weight; higher means a stronger stylization effect
  final double weight;
}
/// Dialog that lets the user pick a style image (bundled or custom) and a
/// style weight for arbitrary style transfer.
///
/// Pops with a [_StylePickerResult] on confirm, or null when dismissed.
class _StylePicker extends StatefulWidget {
  const _StylePicker({
    Key? key,
  }) : super(key: key);

  @override
  createState() => _StylePickerState();
}
/// State for [_StylePicker]: shows the bundled style thumbnails, an
/// optional user-picked custom style, and a weight slider.
class _StylePickerState extends State<_StylePicker> {
  @override
  build(BuildContext context) {
    return AppTheme(
      child: AlertDialog(
        // NOTE(review): hardcoded English string; consider adding an L10n
        // entry like the other enhance strings
        title: Text("Pick a style"),
        contentPadding: const EdgeInsets.fromLTRB(24.0, 20.0, 24.0, 0),
        content: Column(
          crossAxisAlignment: CrossAxisAlignment.start,
          mainAxisSize: MainAxisSize.min,
          children: [
            // preview of the currently selected style, shown only after
            // the user has picked something
            if (_selected != null) ...[
              Align(
                alignment: Alignment.center,
                child: SizedBox(
                  width: 128,
                  height: 128,
                  child: Image(
                    image: ResizeImage.resizeIfNeeded(
                        128, null, ContentUriImage(_getSelectedUri())),
                    fit: BoxFit.cover,
                  ),
                ),
              ),
              const SizedBox(height: 16),
            ],
            // selectable grid: bundled styles, then the custom pick (if
            // any), then the "open file" tile
            Wrap(
              runSpacing: 8,
              spacing: 8,
              children: [
                ..._bundledStyles.mapWithIndex((i, e) => _buildItem(
                      i,
                      Image(
                        image: ResizeImage.resizeIfNeeded(
                            _thumbSize, null, ContentUriImage(e)),
                        fit: BoxFit.cover,
                      ),
                    )),
                // the custom style occupies the index right after the
                // bundled ones
                if (_customUri != null)
                  _buildItem(
                    _bundledStyles.length,
                    Image(
                      image: ResizeImage.resizeIfNeeded(
                          _thumbSize, null, ContentUriImage(_customUri!)),
                      fit: BoxFit.cover,
                    ),
                  ),
                InkWell(
                  onTap: _onCustomTap,
                  child: SizedBox(
                    width: _thumbSize.toDouble(),
                    height: _thumbSize.toDouble(),
                    child: const Icon(
                      Icons.file_open_outlined,
                      size: 24,
                    ),
                  ),
                ),
              ],
            ),
            const SizedBox(height: 16),
            // weight slider between "light" and "strong" effect icons
            Row(
              mainAxisSize: MainAxisSize.max,
              children: [
                Icon(
                  Icons.auto_fix_normal,
                  color: AppTheme.getSecondaryTextColor(context),
                ),
                Expanded(
                  child: StatefulSlider(
                    initialValue: _weight,
                    // a weight of 0 would keep the original image; keep it
                    // strictly positive
                    min: .01,
                    onChangeEnd: (value) {
                      _weight = value;
                    },
                  ),
                ),
                Icon(
                  Icons.auto_fix_high,
                  color: AppTheme.getSecondaryTextColor(context),
                ),
              ],
            ),
          ],
        ),
        actions: [
          TextButton(
            onPressed: () {
              if (_selected == null) {
                // NOTE(review): hardcoded English string; consider an L10n
                // entry
                SnackBarManager().showSnackBar(const SnackBar(
                  content: Text("Please pick a style"),
                  duration: k.snackBarDurationNormal,
                ));
              } else {
                final result = _StylePickerResult(_getSelectedUri(), _weight);
                Navigator.of(context).pop(result);
              }
            },
            child: Text(L10n.global().enhanceButtonLabel),
          ),
        ],
      ),
    );
  }

  /// Build one selectable thumbnail tile for style [index].
  Widget _buildItem(int index, Widget child) {
    return SizedBox(
      width: _thumbSize.toDouble(),
      height: _thumbSize.toDouble(),
      child: Selectable(
        isSelected: _selected == index,
        iconSize: 24,
        child: child,
        onTap: () {
          setState(() {
            _selected = index;
          });
        },
      ),
    );
  }

  /// Let the user pick a custom style image via the system file picker
  /// (ACTION_GET_CONTENT, local files only); on success, remember and
  /// select the returned URI.
  Future<void> _onCustomTap() async {
    const intent = AndroidIntent(
      action: android.ACTION_GET_CONTENT,
      type: "image/*",
      category: android.CATEGORY_OPENABLE,
      arguments: {
        android.EXTRA_LOCAL_ONLY: true,
      },
    );
    final result = await intent.launchForResult();
    _log.info("[onCustomTap] Intent result: $result");
    if (result?.resultCode == android.resultOk) {
      if (mounted) {
        setState(() {
          _customUri = result!.data;
          _selected = _bundledStyles.length;
        });
      }
    }
  }

  /// URI of the currently selected style; only valid when [_selected] is
  /// non-null.
  String _getSelectedUri() {
    return _selected! < _bundledStyles.length
        ? _bundledStyles[_selected!]
        : _customUri!;
  }

  // index into _bundledStyles, or _bundledStyles.length for the custom
  // style; null until the user picks one
  int? _selected;
  // content URI of the user-picked custom style image, if any
  String? _customUri;
  double _weight = .85;

  static const _thumbSize = 56;
  // non-standard asset file URIs understood by the platform side (see
  // UriUtil.isAssetUri)
  static const _bundledStyles = [
    "file:///android_asset/tf/arbitrary-style-transfer/1.jpg",
    "file:///android_asset/tf/arbitrary-style-transfer/2.jpg",
    "file:///android_asset/tf/arbitrary-style-transfer/3.jpg",
    "file:///android_asset/tf/arbitrary-style-transfer/4.jpg",
    "file:///android_asset/tf/arbitrary-style-transfer/5.jpg",
    "file:///android_asset/tf/arbitrary-style-transfer/6.jpg",
  ];

  static final _log =
      Logger("widget.handler.enhance_handler._StylePickerState");
}

View file

@ -18,10 +18,12 @@ packages:
android_intent_plus:
dependency: "direct main"
description:
name: android_intent_plus
url: "https://pub.dartlang.org"
source: hosted
version: "3.1.0"
path: "packages/android_intent_plus"
ref: "android_intent_plus-v3.1.1-nc-photos-1"
resolved-ref: f257f10641e907a94b83ff0e060e80590bf1dae5
url: "https://gitlab.com/nc-photos/plus_plugins"
source: git
version: "3.1.1"
args:
dependency: transitive
description:
@ -605,6 +607,20 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "0.1.3"
memory_info:
dependency: "direct main"
description:
name: memory_info
url: "https://pub.dartlang.org"
source: hosted
version: "0.0.2"
memory_info_platform_interface:
dependency: transitive
description:
name: memory_info_platform_interface
url: "https://pub.dartlang.org"
source: hosted
version: "0.0.1"
meta:
dependency: transitive
description:

View file

@ -28,7 +28,11 @@ dependencies:
sdk: flutter
# android only
android_intent_plus: ^3.0.1
android_intent_plus:
git:
url: https://gitlab.com/nc-photos/plus_plugins
ref: android_intent_plus-v3.1.1-nc-photos-1
path: packages/android_intent_plus
battery_plus: ^2.1.3
bloc: ^7.0.0
cached_network_image: ^3.0.0
@ -71,6 +75,7 @@ dependencies:
intl: ^0.17.0
kiwi: ^4.0.1
logging: ^1.0.1
memory_info: ^0.0.2
mime: ^1.0.1
mutex: ^3.0.0
native_device_orientation: ^1.0.0

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

View file

@ -33,6 +33,7 @@ add_library( # Sets the name of the library.
SHARED
# Provides a relative path to your source file(s).
arbitrary_style_transfer.cpp
deep_lap_3.cpp
esrgan.cpp
exception.cpp

View file

@ -0,0 +1,223 @@
#include "base_resample.h"
#include "exception.h"
#include "log.h"
#include "stopwatch.h"
#include "tflite_wrapper.h"
#include "util.h"
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <jni.h>
#include <tensorflow/lite/c/c_api.h>
#include <vector>
using namespace plugin;
using namespace std;
using namespace tflite;
namespace {
constexpr const char *PREDICT_MODEL =
    "tf/arbitrary-style-transfer-inceptionv3_dr_predict_1.tflite";
constexpr const char *TRANSFER_MODEL =
    "tf/arbitrary-style-transfer-inceptionv3_dr_transfer_1.tflite";

/**
 * Arbitrary style transfer built from two TFLite models: a predict model
 * that maps a 256x256 style image to a style bottleneck vector, and a
 * transfer model that applies a bottleneck to a content image.
 */
class ArbitraryStyleTransfer {
public:
  explicit ArbitraryStyleTransfer(AAssetManager *const aam);

  /**
   * Stylize an RGB8 image.
   *
   * @param image RGB8 content image, width*height*3 bytes
   * @param width
   * @param height
   * @param style RGB8 style image; MUST be 256*256 (see predictStyle)
   * @param weight Style blending weight in [0, 1]; 1 = full style
   * @return Stylized RGB8 image, width*height*3 bytes
   */
  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height, const uint8_t *style,
                             const float weight);

private:
  // Compute the blended style bottleneck from the style image and the
  // content image's own style
  std::vector<float> predict(const uint8_t *image, const size_t width,
                             const size_t height, const uint8_t *style,
                             const float weight);

  // Run the transfer model; the image may be temporarily resized to a
  // multiple of 4 in both dimensions and resized back afterwards
  std::vector<uint8_t> transfer(const uint8_t *image, const size_t width,
                                const size_t height,
                                const std::vector<float> &bottleneck);

  /**
   * @param style The style image MUST be 256*256
   */
  std::vector<float> predictStyle(const uint8_t *style);

  // Linear interpolation between the two bottleneck vectors
  std::vector<float> blendBottleneck(const std::vector<float> &style,
                                     const std::vector<float> &image,
                                     const float styleWeight);

  Model predictModel;
  Model transferModel;

  static constexpr const char *TAG = "ArbitraryStyleTransfer";
};
} // namespace
// JNI entry point for ArbitraryStyleTransfer.inferNative on the Kotlin
// side: runs style transfer on an RGB8 image and returns the stylized RGB8
// bytes, or null after raising a Java exception on failure.
//
// NOTE(review): thiz is declared as jobject* — JNI convention for an
// instance method is a plain jobject. Both are pointer-sized so this links,
// but it should be fixed together with the header declaration.
extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_ArbitraryStyleTransfer_inferNative(
    JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jbyteArray style, jfloat weight) {
  try {
    initOpenMp();
    auto aam = AAssetManager_fromJava(env, assetManager);
    ArbitraryStyleTransfer model(aam);
    // Borrow the Java byte arrays; released with JNI_ABORT since we never
    // write back into them
    RaiiContainer<jbyte> cImage(
        [&]() { return env->GetByteArrayElements(image, nullptr); },
        [&](jbyte *obj) {
          env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
        });
    RaiiContainer<jbyte> cStyle(
        [&]() { return env->GetByteArrayElements(style, nullptr); },
        [&](jbyte *obj) {
          env->ReleaseByteArrayElements(style, obj, JNI_ABORT);
        });
    const auto result =
        model.infer(reinterpret_cast<uint8_t *>(cImage.get()), width, height,
                    reinterpret_cast<uint8_t *>(cStyle.get()), weight);
    // Copy the native result into a fresh Java byte[]
    auto resultAry = env->NewByteArray(result.size());
    env->SetByteArrayRegion(resultAry, 0, result.size(),
                            reinterpret_cast<const int8_t *>(result.data()));
    return resultAry;
  } catch (const exception &e) {
    // Surface native failures to Java instead of crashing the process
    throwJavaException(env, e.what());
    return nullptr;
  }
}
namespace {
// Load both bundled TFLite models (style predictor + style transferrer)
// from the app assets
ArbitraryStyleTransfer::ArbitraryStyleTransfer(AAssetManager *const aam)
    : predictModel(Asset(aam, PREDICT_MODEL)),
      transferModel(Asset(aam, TRANSFER_MODEL)) {}
// Two-stage pipeline: first compute the blended style bottleneck, then
// apply it to the content image
vector<uint8_t> ArbitraryStyleTransfer::infer(const uint8_t *image,
                                              const size_t width,
                                              const size_t height,
                                              const uint8_t *style,
                                              const float weight) {
  const auto bottleneck = predict(image, width, height, style, weight);
  return transfer(image, width, height, bottleneck);
}
/**
 * Compute the style bottleneck to feed to the transfer model.
 *
 * The bottleneck of the style image is blended with the bottleneck of the
 * content image itself, so a lower @a weight preserves more of the
 * original look.
 *
 * Locals renamed to camelCase for consistency with the rest of this file
 * (imageStyleBitmap, inputWidth, ...).
 */
vector<float> ArbitraryStyleTransfer::predict(const uint8_t *image,
                                              const size_t width,
                                              const size_t height,
                                              const uint8_t *style,
                                              const float weight) {
  const auto styleBottleneck = predictStyle(style);
  // The predict model expects a 256x256 input; downscale the content image
  // to extract its own style bottleneck
  vector<uint8_t> imageStyleBitmap(256 * 256 * 3);
  base::ResampleImage24(image, width, height, imageStyleBitmap.data(), 256, 256,
                        base::KernelTypeLanczos3);
  const auto imageBottleneck = predictStyle(imageStyleBitmap.data());
  return blendBottleneck(styleBottleneck, imageBottleneck, weight);
}
// Run the transfer model: apply @a bottleneck to @a image and return the
// stylized RGB8 pixels at the original width*height.
vector<uint8_t>
ArbitraryStyleTransfer::transfer(const uint8_t *image, const size_t width,
                                 const size_t height,
                                 const vector<float> &bottleneck) {
  vector<uint8_t> resizedImage;
  auto inputWidth = width;
  auto inputHeight = height;
  const uint8_t *inputImage = image;
  // The transfer model needs dimensions divisible by 4; shrink to the
  // nearest multiple and resize the output back at the end
  if (width % 4 != 0 || height % 4 != 0) {
    LOGI(TAG, "[transfer] Resize bitmap to multiple of 4");
    inputWidth = width - width % 4;
    inputHeight = height - height % 4;
    resizedImage.resize(inputWidth * inputHeight * 3);
    base::ResampleImage24(image, width, height, resizedImage.data(), inputWidth,
                          inputHeight, base::KernelTypeLanczos3);
    inputImage = resizedImage.data();
  }

  InterpreterOptions options;
  options.setNumThreads(getNumberOfProcessors());
  Interpreter interpreter(transferModel, options);
  // Input 1 is the image tensor: NHWC, batch of 1
  const int dims[] = {1, static_cast<int>(inputHeight),
                      static_cast<int>(inputWidth), 3};
  interpreter.resizeInputTensor(1, dims, 4);
  interpreter.allocateTensors();

  // Input 0 is the style bottleneck
  LOGI(TAG, "[transfer] Copy bias");
  auto inputTensor0 = interpreter.getInputTensor(0);
  assert(TfLiteTensorByteSize(inputTensor0) ==
         bottleneck.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor0, bottleneck.data(),
                             bottleneck.size() * sizeof(float));

  LOGI(TAG, "[transfer] Convert bitmap to input");
  auto input = rgb8ToRgbFloat(inputImage, inputWidth * inputHeight * 3, true);
  auto inputTensor1 = interpreter.getInputTensor(1);
  assert(TfLiteTensorByteSize(inputTensor1) == input.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor1, input.data(),
                             input.size() * sizeof(float));
  // free the float copy before inference to reduce peak memory
  input.clear();

  LOGI(TAG, "[transfer] Inferring");
  Stopwatch stopwatch;
  interpreter.invoke();
  LOGI(TAG, "[transfer] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

  auto outputTensor = interpreter.getOutputTensor(0);
  vector<float> output(inputWidth * inputHeight * 3);
  assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
  TfLiteTensorCopyToBuffer(outputTensor, output.data(),
                           output.size() * sizeof(float));
  auto outputRgb8 = rgbFloatToRgb8(output.data(), output.size(), true);
  output.clear();

  if (!resizedImage.empty()) {
    // resize it back to the original resolution
    vector<uint8_t> temp(width * height * 3);
    base::ResampleImage24(outputRgb8.data(), inputWidth, inputHeight,
                          temp.data(), width, height, base::KernelTypeBicubic);
    return temp;
  } else {
    return outputRgb8;
  }
}
// Run the predict model on a 256x256 RGB8 image and return its style
// bottleneck (100 floats for this model).
vector<float> ArbitraryStyleTransfer::predictStyle(const uint8_t *style) {
  InterpreterOptions options;
  options.setNumThreads(getNumberOfProcessors());
  Interpreter interpreter(predictModel, options);
  interpreter.allocateTensors();

  LOGI(TAG, "[predictStyle] Convert bitmap to input");
  // convert the fixed-size 256*256*3 RGB8 buffer to the model's float input
  const auto input = rgb8ToRgbFloat(style, 256 * 256 * 3, true);
  auto inputTensor = interpreter.getInputTensor(0);
  assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
                             input.size() * sizeof(float));

  LOGI(TAG, "[predictStyle] Inferring");
  Stopwatch stopwatch;
  interpreter.invoke();
  LOGI(TAG, "[predictStyle] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

  auto outputTensor = interpreter.getOutputTensor(0);
  // bottleneck size is fixed by the model
  vector<float> output(100);
  assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
  TfLiteTensorCopyToBuffer(outputTensor, output.data(),
                           output.size() * sizeof(float));
  return output;
}
/**
 * Linearly blend two style bottleneck vectors.
 *
 * product[i] = styleWeight * style[i] + (1 - styleWeight) * image[i]
 *
 * Generalized from the previous hard-coded length of 100 so a future model
 * revision with a different bottleneck size keeps working; the only
 * requirement is that both vectors agree in size.
 *
 * @param style Bottleneck of the style image
 * @param image Bottleneck of the content image
 * @param styleWeight Blend factor; 1 = pure style, 0 = pure image
 */
vector<float>
ArbitraryStyleTransfer::blendBottleneck(const vector<float> &style,
                                        const vector<float> &image,
                                        const float styleWeight) {
  assert(style.size() == image.size());
  vector<float> product(style.size());
  for (size_t i = 0; i < product.size(); ++i) {
    product[i] = styleWeight * style[i] + (1 - styleWeight) * image[i];
  }
  return product;
}
} // namespace

View file

@ -0,0 +1,16 @@
#pragma once

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

// JNI binding for ArbitraryStyleTransfer.inferNative (Kotlin): stylizes an
// RGB8 image using the referenced style image and returns the RGB8 result.
//
// NOTE(review): thiz is declared as jobject* — JNI convention is a plain
// jobject; fix together with the definition in the .cpp file.
JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_ArbitraryStyleTransfer_inferNative(
    JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jbyteArray style, jfloat weight);

#ifdef __cplusplus
}
#endif

View file

@ -4,9 +4,31 @@ import android.content.Context
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.net.Uri
import java.io.InputStream
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
fun Bitmap.aspectRatio() = width / height.toFloat()
/**
 * Run [block] with this bitmap, then recycle the bitmap after [block]
 * returns, whether it returns normally or throws.
 *
 * Mirrors [java.io.Closeable.use] for [Bitmap].
 *
 * @param block Function receiving this bitmap; invoked exactly once.
 * @return Whatever [block] returns.
 */
@OptIn(ExperimentalContracts::class)
inline fun <T> Bitmap.use(block: (Bitmap) -> T): T {
    // tell the compiler block runs exactly once, so vals assigned inside
    // it count as definitely assigned at the call site
    contract {
        callsInPlace(block, InvocationKind.EXACTLY_ONCE)
    }
    try {
        return block(this)
    } finally {
        recycle()
    }
}
enum class BitmapResizeMethod {
FIT,
FILL,
@ -113,10 +135,20 @@ interface BitmapUtil {
}
}
/**
 * Open an InputStream for [uri], transparently routing our non-standard
 * asset URIs (file:///android_asset/...) through the AssetManager instead
 * of the ContentResolver.
 */
private fun openUriInputStream(
    context: Context, uri: Uri
): InputStream? = when {
    UriUtil.isAssetUri(uri) ->
        context.assets.open(UriUtil.getAssetUriPath(uri))
    else -> context.contentResolver.openInputStream(uri)
}
private fun loadImageBounds(
context: Context, uri: Uri
): BitmapFactory.Options {
context.contentResolver.openInputStream(uri)!!.use {
openUriInputStream(context, uri)!!.use {
val opt = BitmapFactory.Options().apply {
inJustDecodeBounds = true
}
@ -128,7 +160,7 @@ interface BitmapUtil {
/**
 * Decode the image at [uri] with the supplied decode options.
 *
 * Throws if the stream can't be opened or the data can't be decoded.
 */
private fun loadImage(
    context: Context, uri: Uri, opt: BitmapFactory.Options
): Bitmap = openUriInputStream(context, uri)!!.use { stream ->
    BitmapFactory.decodeStream(stream, null, opt)!!
}

View file

@ -31,10 +31,15 @@ class ContentUriChannelHandler(context: Context) :
private fun readUri(uri: String, result: MethodChannel.Result) {
val uriTyped = Uri.parse(uri)
try {
val bytes =
val bytes = if (UriUtil.isAssetUri(uriTyped)) {
context.assets.open(UriUtil.getAssetUriPath(uriTyped)).use {
it.readBytes()
}
} else {
context.contentResolver.openInputStream(uriTyped)!!.use {
it.readBytes()
}
}
result.success(bytes)
} catch (e: FileNotFoundException) {
result.error("fileNotFoundException", e.toString(), null)

View file

@ -2,6 +2,7 @@ package com.nkming.nc_photos.plugin
import android.content.Context
import android.content.Intent
import android.net.Uri
import androidx.core.content.ContextCompat
import io.flutter.plugin.common.EventChannel
import io.flutter.plugin.common.MethodCall
@ -64,6 +65,23 @@ class ImageProcessorChannelHandler(context: Context) :
}
}
"arbitraryStyleTransfer" -> {
try {
arbitraryStyleTransfer(
call.argument("fileUrl")!!,
call.argument("headers"),
call.argument("filename")!!,
call.argument("maxWidth")!!,
call.argument("maxHeight")!!,
call.argument("styleUri")!!,
call.argument("weight")!!,
result
)
} catch (e: Throwable) {
result.error("systemException", e.toString(), null)
}
}
else -> result.notImplemented()
}
}
@ -105,6 +123,21 @@ class ImageProcessorChannelHandler(context: Context) :
ImageProcessorService.METHOD_ESRGAN, result
)
/**
 * Start the image processor service to run arbitrary style transfer on a
 * remote file.
 *
 * Delegates to the shared method launcher, additionally packing the style
 * image URI and the blending weight into the service intent extras.
 */
private fun arbitraryStyleTransfer(
	fileUrl: String, headers: Map<String, String>?, filename: String,
	maxWidth: Int, maxHeight: Int, styleUri: String, weight: Float,
	result: MethodChannel.Result
) = method(
	fileUrl, headers, filename, maxWidth, maxHeight,
	ImageProcessorService.METHOD_ARBITRARY_STYLE_TRANSFER, result,
	onIntent = {
		it.putExtra(
			ImageProcessorService.EXTRA_STYLE_URI, Uri.parse(styleUri)
		)
		it.putExtra(ImageProcessorService.EXTRA_WEIGHT, weight)
	}
)
private fun method(
fileUrl: String, headers: Map<String, String>?, filename: String,
maxWidth: Int, maxHeight: Int, method: String,

View file

@ -17,6 +17,7 @@ import androidx.core.app.NotificationChannelCompat
import androidx.core.app.NotificationCompat
import androidx.core.app.NotificationManagerCompat
import androidx.exifinterface.media.ExifInterface
import com.nkming.nc_photos.plugin.image_processor.ArbitraryStyleTransfer
import com.nkming.nc_photos.plugin.image_processor.DeepLab3Portrait
import com.nkming.nc_photos.plugin.image_processor.Esrgan
import com.nkming.nc_photos.plugin.image_processor.ZeroDce
@ -30,6 +31,7 @@ class ImageProcessorService : Service() {
const val METHOD_ZERO_DCE = "zero-dce"
const val METHOD_DEEP_LAP_PORTRAIT = "DeepLab3Portrait"
const val METHOD_ESRGAN = "Esrgan"
const val METHOD_ARBITRARY_STYLE_TRANSFER = "ArbitraryStyleTransfer"
const val EXTRA_FILE_URL = "fileUrl"
const val EXTRA_HEADERS = "headers"
const val EXTRA_FILENAME = "filename"
@ -37,6 +39,8 @@ class ImageProcessorService : Service() {
const val EXTRA_MAX_HEIGHT = "maxHeight"
const val EXTRA_RADIUS = "radius"
const val EXTRA_ITERATION = "iteration"
const val EXTRA_STYLE_URI = "styleUri"
const val EXTRA_WEIGHT = "weight"
private const val ACTION_CANCEL = "cancel"
@ -109,6 +113,9 @@ class ImageProcessorService : Service() {
startId, intent.extras!!
)
METHOD_ESRGAN -> onEsrgan(startId, intent.extras!!)
METHOD_ARBITRARY_STYLE_TRANSFER -> onArbitraryStyleTransfer(
startId, intent.extras!!
)
else -> {
logE(TAG, "Unknown method: $method")
// we can't call stopSelf here as it'll stop the service even if
@ -142,6 +149,16 @@ class ImageProcessorService : Service() {
return onMethod(startId, extras, METHOD_ESRGAN)
}
/**
 * Dispatch an ArbitraryStyleTransfer request to the generic method
 * handler, unpacking the style URI and blending weight from the intent
 * extras.
 */
private fun onArbitraryStyleTransfer(startId: Int, extras: Bundle) =
	onMethod(
		startId, extras, METHOD_ARBITRARY_STYLE_TRANSFER,
		args = mapOf(
			"styleUri" to extras.getParcelable<Uri>(EXTRA_STYLE_URI),
			"weight" to extras.getFloat(EXTRA_WEIGHT),
		)
	)
/**
* Handle methods without arguments
*
@ -530,6 +547,12 @@ private open class ImageProcessorCommandTask(context: Context) :
context, cmd.maxWidth, cmd.maxHeight
).infer(fileUri)
ImageProcessorService.METHOD_ARBITRARY_STYLE_TRANSFER -> ArbitraryStyleTransfer(
context, cmd.maxWidth, cmd.maxHeight,
cmd.args["styleUri"] as Uri,
cmd.args["weight"] as Float
).infer(fileUri)
else -> throw IllegalArgumentException(
"Unknown method: ${cmd.method}"
)

View file

@ -28,6 +28,24 @@ interface UriUtil {
}
}
/**
 * Whether [uri] is a non-standard asset URI pointing at a file bundled in
 * the app assets, i.e., of the form file:///android_asset/path/to/file
 */
fun isAssetUri(uri: Uri): Boolean =
	uri.scheme == "file" && uri.path?.startsWith("/android_asset/") == true
/**
 * Extract the asset-relative path (the part after /android_asset/) from an
 * asset URI. Only meaningful when [isAssetUri] is true for [uri].
 */
fun getAssetUriPath(uri: Uri): String =
	uri.path!!.substring("/android_asset/".length)
private const val TAG = "UriUtil"
}
}

View file

@ -0,0 +1,70 @@
package com.nkming.nc_photos.plugin.image_processor
import android.content.Context
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.net.Uri
import com.nkming.nc_photos.plugin.BitmapResizeMethod
import com.nkming.nc_photos.plugin.BitmapUtil
import com.nkming.nc_photos.plugin.logI
import com.nkming.nc_photos.plugin.use
/**
 * Run the arbitrary style transfer model on an image.
 *
 * The content image is decoded at up to [maxWidth]x[maxHeight]; the style
 * image referenced by [styleUri] is resized and center-cropped to the
 * 256x256 input expected by the model. [weight] controls how strongly the
 * style is applied.
 */
class ArbitraryStyleTransfer(
	private val context: Context,
	private val maxWidth: Int,
	private val maxHeight: Int,
	private val styleUri: Uri,
	private val weight: Float
) {
	fun infer(imageUri: Uri): Bitmap {
		val width: Int
		val height: Int
		// decode the content image, remembering its final dimensions for
		// the native call and the output bitmap
		val rgb8Image = BitmapUtil.loadImage(
			context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT,
			isAllowSwapSide = true, shouldUpscale = false
		).use {
			width = it.width
			height = it.height
			TfLiteHelper.bitmapToRgb8Array(it)
		}

		// decode the style image and center-crop it to exactly 256x256
		val rgb8Style = BitmapUtil.loadImage(
			context, styleUri, 256, 256, BitmapResizeMethod.FILL,
			isAllowSwapSide = false, shouldUpscale = true
		).use { src ->
			val cropped = if (src.width != 256 || src.height != 256) {
				val x = (src.width - 256) / 2
				val y = (src.height - 256) / 2
				logI(
					TAG,
					"[infer] Resize and crop style image: ${src.width}x${src.height} -> 256x256 ($x, $y)"
				)
				Bitmap.createBitmap(src, x, y, 256, 256)
			} else {
				src
			}
			cropped.use {
				TfLiteHelper.bitmapToRgb8Array(cropped)
			}
		}

		val result = inferNative(
			context.assets, rgb8Image, width, height, rgb8Style, weight
		)
		return TfLiteHelper.rgb8ArrayToBitmap(result, width, height)
	}

	private external fun inferNative(
		am: AssetManager, image: ByteArray, width: Int, height: Int,
		style: ByteArray, weight: Float
	): ByteArray

	companion object {
		const val TAG = "ArbitraryStyleTransfer"
	}
}

View file

@ -53,6 +53,25 @@ class ImageProcessor {
"maxHeight": maxHeight,
});
/// Ask the platform image processor to transfer the style of the image at
/// [styleUri] onto the photo at [fileUrl].
///
/// [weight] controls how strongly the style is applied. [maxWidth] and
/// [maxHeight] bound the working resolution. [headers] are the HTTP
/// headers (e.g., authorization) used when downloading [fileUrl].
static Future<void> arbitraryStyleTransfer(
  String fileUrl,
  String filename,
  int maxWidth,
  int maxHeight,
  String styleUri,
  double weight, {
  Map<String, String>? headers,
}) =>
    _methodChannel.invokeMethod("arbitraryStyleTransfer", <String, dynamic>{
      "fileUrl": fileUrl,
      "headers": headers,
      "filename": filename,
      "maxWidth": maxWidth,
      "maxHeight": maxHeight,
      "styleUri": styleUri,
      "weight": weight,
    });
static const _methodChannel =
MethodChannel("${k.libId}/image_processor_method");
}