From d5151f077f5c53bffabda8d4b4fc5ce237d28393 Mon Sep 17 00:00:00 2001
From: Ming Ming <nkming2@gmail.com>
Date: Sun, 15 May 2022 13:32:26 +0800
Subject: [PATCH] Reduce memory usage of ZeroDCE

---
 .../plugin/image_processor/ZeroDce.kt         | 65 ++++++++++++++-----
 1 file changed, 47 insertions(+), 18 deletions(-)

diff --git a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ZeroDce.kt b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ZeroDce.kt
index 17cc0a2a..bd6494e2 100644
--- a/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ZeroDce.kt
+++ b/plugin/android/src/main/kotlin/com/nkming/nc_photos/plugin/image_processor/ZeroDce.kt
@@ -8,6 +8,7 @@ import com.nkming.nc_photos.plugin.BitmapResizeMethod
 import com.nkming.nc_photos.plugin.BitmapUtil
 import org.tensorflow.lite.Interpreter
 import java.nio.FloatBuffer
+import java.nio.IntBuffer
 import kotlin.math.pow
 
 class ZeroDce(context: Context, maxWidth: Int, maxHeight: Int) {
@@ -49,28 +50,56 @@ class ZeroDce(context: Context, maxWidth: Int, maxHeight: Int) {
 		imageUri: Uri, alphaMaps: Bitmap, iteration: Int
 	): Bitmap {
 		Log.i(TAG, "Enhancing image, iteration: $iteration")
+		// we can't work with FloatBuffer directly here as a FloatBuffer is way
+		// too large to fit in Android's heap limit
 		// downscale original to prevent OOM
-		val resized = BitmapUtil.loadImage(
+		val width: Int
+		val height: Int
+		val imgBuf: IntBuffer
+		BitmapUtil.loadImage(
 			context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT,
 			isAllowSwapSide = true, shouldUpscale = false
-		)
-		// resize aMaps
-		val resizedFilter = Bitmap.createScaledBitmap(
-			alphaMaps, resized.width, resized.height, true
-		)
-
-		val imgBuf = TfLiteHelper.bitmapToRgbFloatArray(resized)
-		val filterBuf = TfLiteHelper.bitmapToRgbFloatArray(resizedFilter)
-		for (i in 0 until iteration) {
-			val src = imgBuf.array()
-			val filter = filterBuf.array()
-			for (j in src.indices) {
-				src[j] = src[j] + -filter[j] * (src[j].pow(2f) - src[j])
-			}
+		).apply {
+			width = this.width
+			height = this.height
+			imgBuf = IntBuffer.allocate(width * height)
+			copyPixelsToBuffer(imgBuf)
+			recycle()
 		}
-		return TfLiteHelper.rgbFloatArrayToBitmap(
-			imgBuf, resized.width, resized.height
-		)
+		imgBuf.rewind()
+
+		// resize aMaps
+		val filterBuf: IntBuffer
+		Bitmap.createScaledBitmap(alphaMaps, width, height, true).apply {
+			filterBuf = IntBuffer.allocate(width * height)
+			copyPixelsToBuffer(filterBuf)
+			recycle()
+		}
+		filterBuf.rewind()
+
+		val src = imgBuf.array()
+		val filter = filterBuf.array()
+		for (i in src.indices) {
+			var sr = (src[i] and 0xFF) / 255f
+			var sg = (src[i] shr 8 and 0xFF) / 255f
+			var sb = (src[i] shr 16 and 0xFF) / 255f
+			val fr = (filter[i] and 0xFF) / 255f
+			val fg = (filter[i] shr 8 and 0xFF) / 255f
+			val fb = (filter[i] shr 16 and 0xFF) / 255f
+			for (j in 0 until iteration) {
+				sr += -fr * (sr.pow(2f) - sr)
+				sg += -fg * (sg.pow(2f) - sg)
+				sb += -fb * (sb.pow(2f) - sb)
+			}
+			src[i] = (0xFF shl 24) or
+					((sr * 255).toInt().coerceIn(0, 255)) or
+					((sg * 255).toInt().coerceIn(0, 255) shl 8) or
+					((sb * 255).toInt().coerceIn(0, 255) shl 16)
+		}
+		return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
+			.apply {
+				copyPixelsFromBuffer(imgBuf)
+			}
 	}
 
 	private val context = context