Reduce memory usage of ZeroDCE

This commit is contained in:
Ming Ming 2022-05-15 13:32:26 +08:00
parent 9ecdb53e69
commit d5151f077f

View file

@ -8,6 +8,7 @@ import com.nkming.nc_photos.plugin.BitmapResizeMethod
import com.nkming.nc_photos.plugin.BitmapUtil
import org.tensorflow.lite.Interpreter
import java.nio.FloatBuffer
import java.nio.IntBuffer
import kotlin.math.pow
class ZeroDce(context: Context, maxWidth: Int, maxHeight: Int) {
@ -49,28 +50,56 @@ class ZeroDce(context: Context, maxWidth: Int, maxHeight: Int) {
imageUri: Uri, alphaMaps: Bitmap, iteration: Int
): Bitmap {
Log.i(TAG, "Enhancing image, iteration: $iteration")
// we can't work with FloatBuffer directly here as a FloatBuffer is way
// too large to fit in Android's heap limit
// downscale original to prevent OOM
val resized = BitmapUtil.loadImage(
val width: Int
val height: Int
val imgBuf: IntBuffer
BitmapUtil.loadImage(
context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT,
isAllowSwapSide = true, shouldUpscale = false
)
// resize aMaps
val resizedFilter = Bitmap.createScaledBitmap(
alphaMaps, resized.width, resized.height, true
)
val imgBuf = TfLiteHelper.bitmapToRgbFloatArray(resized)
val filterBuf = TfLiteHelper.bitmapToRgbFloatArray(resizedFilter)
for (i in 0 until iteration) {
val src = imgBuf.array()
val filter = filterBuf.array()
for (j in src.indices) {
src[j] = src[j] + -filter[j] * (src[j].pow(2f) - src[j])
}
).apply {
width = this.width
height = this.height
imgBuf = IntBuffer.allocate(width * height)
copyPixelsToBuffer(imgBuf)
recycle()
}
return TfLiteHelper.rgbFloatArrayToBitmap(
imgBuf, resized.width, resized.height
)
imgBuf.rewind()
// resize aMaps
val filterBuf: IntBuffer
Bitmap.createScaledBitmap(alphaMaps, width, height, true).apply {
filterBuf = IntBuffer.allocate(width * height)
copyPixelsToBuffer(filterBuf)
recycle()
}
filterBuf.rewind()
val src = imgBuf.array()
val filter = filterBuf.array()
for (i in src.indices) {
var sr = (src[i] and 0xFF) / 255f
var sg = (src[i] shr 8 and 0xFF) / 255f
var sb = (src[i] shr 16 and 0xFF) / 255f
val fr = (filter[i] and 0xFF) / 255f
val fg = (filter[i] shr 8 and 0xFF) / 255f
val fb = (filter[i] shr 16 and 0xFF) / 255f
for (j in 0 until iteration) {
sr += -fr * (sr.pow(2f) - sr)
sg += -fg * (sg.pow(2f) - sg)
sb += -fb * (sb.pow(2f) - sb)
}
src[i] = (0xFF shl 24) or
((sr * 255).toInt().coerceIn(0, 255)) or
((sg * 255).toInt().coerceIn(0, 255) shl 8) or
((sb * 255).toInt().coerceIn(0, 255) shl 16)
}
return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
.apply {
copyPixelsFromBuffer(imgBuf)
}
}
private val context = context