mirror of https://gitlab.com/nkming2/nc-photos.git
synced 2025-03-04 22:38:51 +01:00
Migrate photo enhancements to NDK
This commit is contained in:
parent
477e694133
commit
eeb95258c0
34 changed files with 4366 additions and 259 deletions
1 plugin/android/.gitignore vendored
@@ -6,3 +6,4 @@
.DS_Store
/build
/captures
/.cxx
@@ -48,6 +48,20 @@ android {
    defaultConfig {
        minSdkVersion 21
        externalNativeBuild {
            cmake {
                cppFlags ''
            }
        }
        ndk {
            abiFilters "armeabi-v7a", "arm64-v8a", "x86_64"
        }
    }
    externalNativeBuild {
        cmake {
            path file('src/main/cpp/CMakeLists.txt')
            version '3.18.1'
        }
    }
}
@@ -56,7 +70,5 @@ dependencies {
    implementation "androidx.annotation:annotation:1.3.0"
    implementation "androidx.core:core-ktx:1.7.0"
    implementation "androidx.exifinterface:exifinterface:1.3.3"
    implementation 'com.github.android:renderscript-intrinsics-replacement-toolkit:b6363490c3'
    implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version"
    implementation 'org.tensorflow:tensorflow-lite:2.8.0'
}
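The externalNativeBuild block above points Gradle at src/main/cpp/CMakeLists.txt, which this commit adds. As a rough, hypothetical sketch of the JNI surface such an NDK migration implies (the package, class, and method names below are illustrative, not taken from this commit), a native entry point could look like:

// Hypothetical JNI bridge sketch; not part of this commit.
#include <jni.h>

#include <cstdint>
#include <vector>

extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_example_plugin_ImageProcessor_enhance(JNIEnv* env, jclass,
                                               jbyteArray rgba8, jint width,
                                               jint height) {
    // Copy the RGBA pixel data out of the JVM array.
    const jsize len = env->GetArrayLength(rgba8);
    std::vector<uint8_t> pixels(static_cast<size_t>(len));
    env->GetByteArrayRegion(rgba8, 0, len,
                            reinterpret_cast<jbyte*>(pixels.data()));

    // ... run a native enhancement pass here, e.g. the Toolkit functions
    // declared in the header below ...

    // Copy the result back into a new JVM array.
    jbyteArray out = env->NewByteArray(len);
    env->SetByteArrayRegion(out, 0, len,
                            reinterpret_cast<const jbyte*>(pixels.data()));
    return out;
}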
@@ -0,0 +1,538 @@
/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ANDROID_RENDERSCRIPT_TOOLKIT_TOOLKIT_H
#define ANDROID_RENDERSCRIPT_TOOLKIT_TOOLKIT_H

#include <cstdint>
#include <memory>

namespace renderscript {

class TaskProcessor;

/**
 * Define a range of data to process.
 *
 * This class is used to restrict a Toolkit operation to a rectangular subset of the input
 * tensor.
 *
 * @property startX The index of the first value to be included on the X axis.
 * @property endX The index after the last value to be included on the X axis.
 * @property startY The index of the first value to be included on the Y axis.
 * @property endY The index after the last value to be included on the Y axis.
 */
struct Restriction {
    size_t startX;
    size_t endX;
    size_t startY;
    size_t endY;
};

/**
 * A collection of high-performance graphic utility functions like blur and blend.
 *
 * This toolkit provides ten image manipulation functions: blend, blur, color matrix, convolve,
 * histogram, histogramDot, lut, lut3d, resize, and YUV to RGB. These functions execute
 * multithreaded on the CPU.
 *
 * These functions work over raw byte arrays. You'll need to specify the width and height of
 * the data to be processed, as well as the number of bytes per pixel. For most use cases,
 * this will be 4.
 *
 * You should instantiate the Toolkit once and reuse it throughout your application.
 * On instantiation, the Toolkit creates a thread pool that's used for processing all the functions.
 * You can limit the number of pool threads used by the Toolkit via the constructor. The pool
 * threads are destroyed once the Toolkit is destroyed, after any pending work is done.
 *
 * This library is thread safe. You can call methods from different pool threads. The functions will
 * execute sequentially.
 *
 * A Java/Kotlin Toolkit is available. It calls this library through JNI.
 *
 * This toolkit can be used as a replacement for most RenderScript Intrinsic functions. Compared
 * to RenderScript, it's simpler to use and more than twice as fast on the CPU. However,
 * RenderScript Intrinsics allow more flexibility for the type of allocation supported. In
 * particular, this toolkit does not support allocations of floats.
 */
class RenderScriptToolkit {
    /** Each Toolkit method call is converted to a Task. The processor owns the thread pool. It
     * tiles the tasks and schedules them over the pool threads.
     */
    std::unique_ptr<TaskProcessor> processor;

  public:
    /**
     * Creates the pool threads that are used for processing the method calls.
     */
    RenderScriptToolkit(int numberOfThreads = 0);
    /**
     * Destroys the thread pool. This stops any in-progress work; the Toolkit methods called from
     * other pool threads will return without having completed the work. Because of the undefined
     * state of the output buffers, an application should avoid destroying the Toolkit if other pool
     * threads are executing Toolkit methods.
     */
    ~RenderScriptToolkit();

    /**
     * Determines how a source buffer is blended into a destination buffer.
     *
     * See {@link RenderScriptToolkit::blend}.
     *
     * blend only works on 4 byte RGBA data. In the descriptions below, ".a" represents
     * the alpha channel.
     */
    enum class BlendingMode {
        /**
         * dest = 0
         *
         * The destination is cleared, i.e. each pixel is set to (0, 0, 0, 0).
         */
        CLEAR = 0,
        /**
         * dest = src
         *
         * Sets each pixel of the destination to the corresponding one in the source.
         */
        SRC = 1,
        /**
         * dest = dest
         *
         * Leaves the destination untouched. This is a no-op.
         */
        DST = 2,
        /**
         * dest = src + dest * (1.0 - src.a)
         */
        SRC_OVER = 3,
        /**
         * dest = dest + src * (1.0 - dest.a)
         */
        DST_OVER = 4,
        /**
         * dest = src * dest.a
         */
        SRC_IN = 5,
        /**
         * dest = dest * src.a
         */
        DST_IN = 6,
        /**
         * dest = src * (1.0 - dest.a)
         */
        SRC_OUT = 7,
        /**
         * dest = dest * (1.0 - src.a)
         */
        DST_OUT = 8,
        /**
         * dest.rgb = src.rgb * dest.a + (1.0 - src.a) * dest.rgb, dest.a = dest.a
         */
        SRC_ATOP = 9,
        /**
         * dest.rgb = dest.rgb * src.a + (1.0 - dest.a) * src.rgb, dest.a = src.a
         */
        DST_ATOP = 10,
        /**
         * dest = {src.r ^ dest.r, src.g ^ dest.g, src.b ^ dest.b, src.a ^ dest.a}
         *
         * Note: this is NOT the Porter/Duff XOR mode; this is a bitwise xor.
         */
        XOR = 11,
        /**
         * dest = src * dest
         */
        MULTIPLY = 12,
        /**
         * dest = min(src + dest, 1.0)
         */
        ADD = 13,
        /**
         * dest = max(dest - src, 0.0)
         */
        SUBTRACT = 14
    };

    /**
     * Blend a source buffer with the destination buffer.
     *
     * Blends a source buffer and a destination buffer, placing the result in the destination
     * buffer. The blending is done pairwise between two corresponding RGBA values found in
     * each buffer. The mode parameter specifies one of fifteen blending operations.
     * See {@link BlendingMode}.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The source and destination buffers must have the same dimensions. Both buffers should be
     * large enough for sizeX * sizeY * 4 bytes. The buffers have a row-major layout.
     *
     * @param mode The specific blending operation to do.
     * @param source The RGBA input buffer.
     * @param dst The destination buffer. Used for input and output.
     * @param sizeX The width of both buffers, as a number of RGBA values.
     * @param sizeY The height of both buffers, as a number of RGBA values.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void blend(BlendingMode mode, const uint8_t* _Nonnull source, uint8_t* _Nonnull dst,
               size_t sizeX, size_t sizeY, const Restriction* _Nullable restriction = nullptr);

    /**
     * Blur an image.
     *
     * Performs a Gaussian blur of the input image and stores the result in the out buffer.
     *
     * The radius determines which pixels are used to compute each blurred pixel. This Toolkit
     * accepts values between 1 and 25. Larger values create a more blurred effect but also
     * take longer to compute. When the radius extends past the edge, the edge pixel will
     * be used as replacement for the pixel that's out of boundary.
     *
     * Each input pixel can either be represented by four bytes (RGBA format) or one byte
     * for the less common blurring of an alpha-channel-only image.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The input and output buffers must have the same dimensions. Both buffers should be
     * large enough for sizeX * sizeY * vectorSize bytes. The buffers have a row-major layout.
     *
     * @param in The buffer of the image to be blurred.
     * @param out The buffer that receives the blurred image.
     * @param sizeX The width of both buffers, as a number of 1 or 4 byte cells.
     * @param sizeY The height of both buffers, as a number of 1 or 4 byte cells.
     * @param vectorSize Either 1 or 4, the number of bytes in each cell, i.e. A vs. RGBA.
     * @param radius The radius of the pixels used to blur.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void blur(const uint8_t* _Nonnull in, uint8_t* _Nonnull out, size_t sizeX, size_t sizeY,
              size_t vectorSize, int radius, const Restriction* _Nullable restriction = nullptr);

    /**
     * Identity matrix that can be passed to the {@link RenderScriptToolkit::colorMatrix} method.
     *
     * Using this matrix will result in no change to the pixel through multiplication although
     * the pixel value can still be modified by the add vector, or transformed to a different
     * format.
     */
    static constexpr float kIdentityMatrix[] = {
        1.0f, 0.0f, 0.0f, 0.0f,
        0.0f, 1.0f, 0.0f, 0.0f,
        0.0f, 0.0f, 1.0f, 0.0f,
        0.0f, 0.0f, 0.0f, 1.0f
    };

    /**
     * Matrix to turn color pixels to a grey scale.
     *
     * Use this matrix with the {@link RenderScriptToolkit::colorMatrix} method to convert an
     * image from color to greyscale.
     */
    static constexpr float kGreyScaleColorMatrix[] = {
        0.299f, 0.299f, 0.299f, 0.0f,
        0.587f, 0.587f, 0.587f, 0.0f,
        0.114f, 0.114f, 0.114f, 0.0f,
        0.0f,   0.0f,   0.0f,   1.0f
    };

    /**
     * Matrix to convert RGB to YUV.
     *
     * Use this matrix with the {@link RenderScriptToolkit::colorMatrix} method to convert the
     * first three bytes of each pixel from RGB to YUV. This leaves the last byte (the alpha
     * channel) untouched.
     *
     * This is a simplistic conversion. Most YUV buffers have a more complicated format, not
     * supported by this method.
     */
    static constexpr float kRgbToYuvMatrix[] = {
        0.299f, -0.14713f,  0.615f,   0.0f,
        0.587f, -0.28886f, -0.51499f, 0.0f,
        0.114f,  0.436f,   -0.10001f, 0.0f,
        0.0f,    0.0f,      0.0f,     1.0f
    };

    /**
     * Matrix to convert YUV to RGB.
     *
     * Use this matrix with the {@link RenderScriptToolkit::colorMatrix} method to convert the
     * first three bytes of each pixel from YUV to RGB. This leaves the last byte (the alpha
     * channel) untouched.
     *
     * This is a simplistic conversion. Most YUV buffers have a more complicated format, not
     * supported by this method. Use {@link RenderScriptToolkit::yuvToRgb} to convert these buffers.
     */
    static constexpr float kYuvToRgbMatrix[] = {
        1.0f,      1.0f,      1.0f,     0.0f,
        0.0f,     -0.39465f,  2.03211f, 0.0f,
        1.13983f, -0.5806f,   0.0f,     0.0f,
        0.0f,      0.0f,      0.0f,     1.0f
    };

    /**
     * Transform an image using a color matrix.
     *
     * Converts a 2D array of vectors of unsigned bytes, multiplying each vector by a 4x4 matrix
     * and adding an optional vector.
     *
     * Each input vector is composed of 1-4 unsigned bytes. If less than 4 bytes, it's extended to
     * 4, padding with zeroes. The unsigned bytes are converted from 0-255 to 0.0-1.0 floats
     * before the multiplication is done.
     *
     * The resulting value is normalized from 0.0-1.0 to a 0-255 value and stored in the output.
     * If the output vector size is less than four, the unused channels are discarded.
     *
     * If addVector is null, a vector of zeroes is added, i.e. a noop.
     *
     * Check kIdentityMatrix, kGreyScaleColorMatrix, kRgbToYuvMatrix, and kYuvToRgbMatrix for sample
     * matrices. The YUV conversion may not work for all color spaces.
     *
     * @param in The buffer of the image to be converted.
     * @param out The buffer that receives the converted image.
     * @param inputVectorSize The number of bytes in each input cell, a value from 1 to 4.
     * @param outputVectorSize The number of bytes in each output cell, a value from 1 to 4.
     * @param sizeX The width of both buffers, as a number of 1 to 4 byte cells.
     * @param sizeY The height of both buffers, as a number of 1 to 4 byte cells.
     * @param matrix The 4x4 matrix to multiply, in row major format.
     * @param addVector A vector of four floats that's added to the result of the multiplication.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void colorMatrix(const void* _Nonnull in, void* _Nonnull out, size_t inputVectorSize,
                     size_t outputVectorSize, size_t sizeX, size_t sizeY,
                     const float* _Nonnull matrix, const float* _Nullable addVector = nullptr,
                     const Restriction* _Nullable restriction = nullptr);

    /**
     * Convolve a ByteArray.
     *
     * Applies a 3x3 or 5x5 convolution to the input array using the provided coefficients.
     *
     * For 3x3 convolutions, 9 coefficients must be provided. For 5x5, 25 coefficients are needed.
     * The coefficients should be provided in row-major format.
     *
     * When the square extends past the edge, the edge values will be used as replacement for the
     * values that are out of boundary.
     *
     * Each input cell can be represented by one to four bytes. Each byte is multiplied
     * and accumulated independently of the other bytes of the cell.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The input and output buffers must have the same dimensions. Both buffers should be
     * large enough for sizeX * sizeY * vectorSize bytes. The buffers have a row-major layout.
     *
     * @param in The buffer of the image to be convolved.
     * @param out The buffer that receives the convolved image.
     * @param vectorSize The number of bytes in each cell, a value from 1 to 4.
     * @param sizeX The width of both buffers, as a number of 1 or 4 byte cells.
     * @param sizeY The height of both buffers, as a number of 1 or 4 byte cells.
     * @param coefficients 9 or 25 multipliers.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void convolve3x3(const void* _Nonnull in, void* _Nonnull out, size_t vectorSize, size_t sizeX,
                     size_t sizeY, const float* _Nonnull coefficients,
                     const Restriction* _Nullable restriction = nullptr);

    void convolve5x5(const void* _Nonnull in, void* _Nonnull out, size_t vectorSize, size_t sizeX,
                     size_t sizeY, const float* _Nonnull coefficients,
                     const Restriction* _Nullable restriction = nullptr);

    /**
     * Compute the histogram of an image.
     *
     * Tallies how many times each of the 256 possible values of a byte is found in the input.
     *
     * An input cell can be represented by one to four bytes. The tally is done independently
     * for each of the bytes of the cell. Correspondingly, the out array will have
     * 256 * vectorSize entries. The counts for value 0 are consecutive, followed by those for
     * value 1, etc.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The source buffers should be large enough for sizeX * sizeY * vectorSize bytes. The buffers
     * have a row-major layout. The out buffer should be large enough for 256 * vectorSize ints.
     *
     * @param in The buffer of the image to be analyzed.
     * @param out The resulting vector of counts.
     * @param sizeX The width of the input buffers, as a number of 1 or 4 byte cells.
     * @param sizeY The height of the input buffers, as a number of 1 or 4 byte cells.
     * @param vectorSize The number of bytes in each cell, a value from 1 to 4.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void histogram(const uint8_t* _Nonnull in, int32_t* _Nonnull out, size_t sizeX, size_t sizeY,
                   size_t vectorSize, const Restriction* _Nullable restriction = nullptr);

    /**
     * Compute the histogram of the dot product of an image.
     *
     * This method supports cells of 1 to 4 bytes in length. For each cell of the array,
     * the dot product of its bytes with the provided coefficients is computed. The resulting
     * floating point value is converted to an unsigned byte and tallied in the histogram.
     *
     * If coefficients is null, the coefficients used for RGBA luminosity calculation will be used,
     * i.e. the values [0.299f, 0.587f, 0.114f, 0.f].
     *
     * Each coefficient must be >= 0 and their sum must be 1.0 or less. There must be the same
     * number of coefficients as vectorSize.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The source buffers should be large enough for sizeX * sizeY * vectorSize bytes. The buffers
     * have a row-major layout. The out array should be large enough for 256 ints.
     *
     * @param in The buffer of the image to be analyzed.
     * @param out The resulting vector of counts.
     * @param sizeX The width of the input buffers, as a number of 1 or 4 byte cells.
     * @param sizeY The height of the input buffers, as a number of 1 or 4 byte cells.
     * @param vectorSize The number of bytes in each cell, a value from 1 to 4.
     * @param coefficients The values used for the dot product. Can be nullptr.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void histogramDot(const uint8_t* _Nonnull in, int32_t* _Nonnull out, size_t sizeX, size_t sizeY,
                      size_t vectorSize, const float* _Nullable coefficients,
                      const Restriction* _Nullable restriction = nullptr);

    /**
     * Transform an image using a look up table.
     *
     * Transforms an image by using a per-channel lookup table. Each channel of the input has an
     * independent lookup table. The tables are 256 entries in size and can cover the full value
     * range of a byte.
     *
     * The input array should be in RGBA format, where four consecutive bytes form a cell.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The input and output buffers must have the same dimensions. Both buffers should be
     * large enough for sizeX * sizeY * vectorSize bytes. The buffers have a row-major layout.
     *
     * @param in The buffer of the image to be transformed.
     * @param out The buffer that receives the transformed image.
     * @param sizeX The width of both buffers, as a number of 4 byte cells.
     * @param sizeY The height of both buffers, as a number of 4 byte cells.
     * @param red An array of 256 values that's used to convert the R channel.
     * @param green An array of 256 values that's used to convert the G channel.
     * @param blue An array of 256 values that's used to convert the B channel.
     * @param alpha An array of 256 values that's used to convert the A channel.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void lut(const uint8_t* _Nonnull in, uint8_t* _Nonnull out, size_t sizeX, size_t sizeY,
             const uint8_t* _Nonnull red, const uint8_t* _Nonnull green,
             const uint8_t* _Nonnull blue, const uint8_t* _Nonnull alpha,
             const Restriction* _Nullable restriction = nullptr);

    /**
     * Transform an image using a 3D look up table.
     *
     * Transforms an image, converting RGB to RGBA by using a 3D lookup table. The incoming R, G,
     * and B values are normalized to the dimensions of the provided 3D buffer. The eight nearest
     * values in that 3D buffer are sampled and linearly interpolated. The resulting RGBA entry
     * is stored in the output.
     *
     * The input array should be in RGBA format, where four consecutive bytes form a cell.
     * The fourth byte of each input cell is ignored.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of each buffer. If provided, the range must be wholly contained within the dimensions
     * described by sizeX and sizeY.
     *
     * The input and output buffers must have the same dimensions. Both buffers should be
     * large enough for sizeX * sizeY * vectorSize bytes. The buffers have a row-major layout.
     *
     * @param in The buffer of the image to be transformed.
     * @param out The buffer that receives the transformed image.
     * @param sizeX The width of both buffers, as a number of 4 byte cells.
     * @param sizeY The height of both buffers, as a number of 4 byte cells.
     * @param cube The translation cube, in row-major format.
     * @param cubeSizeX The number of RGBA entries in the cube in the X direction.
     * @param cubeSizeY The number of RGBA entries in the cube in the Y direction.
     * @param cubeSizeZ The number of RGBA entries in the cube in the Z direction.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void lut3d(const uint8_t* _Nonnull in, uint8_t* _Nonnull out, size_t sizeX, size_t sizeY,
               const uint8_t* _Nonnull cube, size_t cubeSizeX, size_t cubeSizeY, size_t cubeSizeZ,
               const Restriction* _Nullable restriction = nullptr);

    /**
     * Resize an image.
     *
     * Resizes an image using bicubic interpolation.
     *
     * This method supports cells of 1 to 4 bytes in length. Each byte of the cell is
     * interpolated independently from the others.
     *
     * An optional range parameter can be set to restrict the operation to a rectangular subset
     * of the output buffer. The corresponding scaled range of the input will be used. If provided,
     * the range must be wholly contained within the dimensions described by outputSizeX and
     * outputSizeY.
     *
     * The input and output buffers have a row-major layout. Both buffers should be
     * large enough for their respective sizeX * sizeY * vectorSize bytes.
     *
     * @param in The buffer of the image to be resized.
     * @param out The buffer that receives the resized image.
     * @param inputSizeX The width of the input buffer, as a number of 1-4 byte cells.
     * @param inputSizeY The height of the input buffer, as a number of 1-4 byte cells.
     * @param vectorSize The number of bytes in each cell of both buffers. A value from 1 to 4.
     * @param outputSizeX The width of the output buffer, as a number of 1-4 byte cells.
     * @param outputSizeY The height of the output buffer, as a number of 1-4 byte cells.
     * @param restriction When not null, restricts the operation to a 2D range of pixels.
     */
    void resize(const uint8_t* _Nonnull in, uint8_t* _Nonnull out, size_t inputSizeX,
                size_t inputSizeY, size_t vectorSize, size_t outputSizeX, size_t outputSizeY,
                const Restriction* _Nullable restriction = nullptr);

    /**
     * The YUV formats supported by yuvToRgb.
     */
    enum class YuvFormat {
        NV21 = 0x11,
        YV12 = 0x32315659,
    };

    /**
     * Convert an image from YUV to RGB.
     *
     * Converts an Android YUV buffer to RGB. The input allocation should be
     * supplied in a supported YUV format as a YUV cell Allocation.
     * The output is RGBA; the alpha channel will be set to 255.
     *
     * Note that for YV12 and a sizeX that's not a multiple of 32, the
     * RenderScript Intrinsic may not have converted the image correctly.
     * This Toolkit method should.
     *
     * @param in The buffer of the image to be converted.
     * @param out The buffer that receives the converted image.
     * @param sizeX The width in pixels of the image. Must be even.
     * @param sizeY The height in pixels of the image.
     * @param format Either YV12 or NV21.
     */
    void yuvToRgb(const uint8_t* _Nonnull in, uint8_t* _Nonnull out, size_t sizeX, size_t sizeY,
                  YuvFormat format);
};

}  // namespace renderscript

#endif  // ANDROID_RENDERSCRIPT_TOOLKIT_TOOLKIT_H
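For orientation, here is a minimal sketch of how the header above is driven from native code. It is not taken from this commit; the buffer handling, radius, and function name are illustrative, but the Toolkit signatures match the declarations above.

// Minimal usage sketch for RenderScriptToolkit; names are illustrative.
#include <cstdint>
#include <vector>

#include "Toolkit.h"  // the header above

void enhanceExample(const std::vector<uint8_t>& rgba, size_t width, size_t height) {
    // Instantiate once and reuse, per the class documentation.
    renderscript::RenderScriptToolkit toolkit;

    // Gaussian blur with a radius of 10, on 4-byte RGBA cells.
    std::vector<uint8_t> blurred(rgba.size());
    toolkit.blur(rgba.data(), blurred.data(), width, height,
                 /*vectorSize=*/4, /*radius=*/10);

    // Convert the blurred image to greyscale using the bundled matrix.
    std::vector<uint8_t> grey(rgba.size());
    toolkit.colorMatrix(blurred.data(), grey.data(), /*inputVectorSize=*/4,
                        /*outputVectorSize=*/4, width, height,
                        renderscript::RenderScriptToolkit::kGreyScaleColorMatrix);
}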
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,186 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_BUILTIN_OPS_H_
#define TENSORFLOW_LITE_BUILTIN_OPS_H_

// DO NOT EDIT MANUALLY: This file is automatically generated by
// `schema/builtin_ops_header/generator.cc`.

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// The enum for builtin operators.
// Note: CUSTOM, DELEGATE, and PLACEHOLDER_FOR_GREATER_OP_CODES are 3 special
// ops which are not real built-in ops.
typedef enum {
  kTfLiteBuiltinAdd = 0,
  kTfLiteBuiltinAveragePool2d = 1,
  kTfLiteBuiltinConcatenation = 2,
  kTfLiteBuiltinConv2d = 3,
  kTfLiteBuiltinDepthwiseConv2d = 4,
  kTfLiteBuiltinDepthToSpace = 5,
  kTfLiteBuiltinDequantize = 6,
  kTfLiteBuiltinEmbeddingLookup = 7,
  kTfLiteBuiltinFloor = 8,
  kTfLiteBuiltinFullyConnected = 9,
  kTfLiteBuiltinHashtableLookup = 10,
  kTfLiteBuiltinL2Normalization = 11,
  kTfLiteBuiltinL2Pool2d = 12,
  kTfLiteBuiltinLocalResponseNormalization = 13,
  kTfLiteBuiltinLogistic = 14,
  kTfLiteBuiltinLshProjection = 15,
  kTfLiteBuiltinLstm = 16,
  kTfLiteBuiltinMaxPool2d = 17,
  kTfLiteBuiltinMul = 18,
  kTfLiteBuiltinRelu = 19,
  kTfLiteBuiltinReluN1To1 = 20,
  kTfLiteBuiltinRelu6 = 21,
  kTfLiteBuiltinReshape = 22,
  kTfLiteBuiltinResizeBilinear = 23,
  kTfLiteBuiltinRnn = 24,
  kTfLiteBuiltinSoftmax = 25,
  kTfLiteBuiltinSpaceToDepth = 26,
  kTfLiteBuiltinSvdf = 27,
  kTfLiteBuiltinTanh = 28,
  kTfLiteBuiltinConcatEmbeddings = 29,
  kTfLiteBuiltinSkipGram = 30,
  kTfLiteBuiltinCall = 31,
  kTfLiteBuiltinCustom = 32,
  kTfLiteBuiltinEmbeddingLookupSparse = 33,
  kTfLiteBuiltinPad = 34,
  kTfLiteBuiltinUnidirectionalSequenceRnn = 35,
  kTfLiteBuiltinGather = 36,
  kTfLiteBuiltinBatchToSpaceNd = 37,
  kTfLiteBuiltinSpaceToBatchNd = 38,
  kTfLiteBuiltinTranspose = 39,
  kTfLiteBuiltinMean = 40,
  kTfLiteBuiltinSub = 41,
  kTfLiteBuiltinDiv = 42,
  kTfLiteBuiltinSqueeze = 43,
  kTfLiteBuiltinUnidirectionalSequenceLstm = 44,
  kTfLiteBuiltinStridedSlice = 45,
  kTfLiteBuiltinBidirectionalSequenceRnn = 46,
  kTfLiteBuiltinExp = 47,
  kTfLiteBuiltinTopkV2 = 48,
  kTfLiteBuiltinSplit = 49,
  kTfLiteBuiltinLogSoftmax = 50,
  kTfLiteBuiltinDelegate = 51,
  kTfLiteBuiltinBidirectionalSequenceLstm = 52,
  kTfLiteBuiltinCast = 53,
  kTfLiteBuiltinPrelu = 54,
  kTfLiteBuiltinMaximum = 55,
  kTfLiteBuiltinArgMax = 56,
  kTfLiteBuiltinMinimum = 57,
  kTfLiteBuiltinLess = 58,
  kTfLiteBuiltinNeg = 59,
  kTfLiteBuiltinPadv2 = 60,
  kTfLiteBuiltinGreater = 61,
  kTfLiteBuiltinGreaterEqual = 62,
  kTfLiteBuiltinLessEqual = 63,
  kTfLiteBuiltinSelect = 64,
  kTfLiteBuiltinSlice = 65,
  kTfLiteBuiltinSin = 66,
  kTfLiteBuiltinTransposeConv = 67,
  kTfLiteBuiltinSparseToDense = 68,
  kTfLiteBuiltinTile = 69,
  kTfLiteBuiltinExpandDims = 70,
  kTfLiteBuiltinEqual = 71,
  kTfLiteBuiltinNotEqual = 72,
  kTfLiteBuiltinLog = 73,
  kTfLiteBuiltinSum = 74,
  kTfLiteBuiltinSqrt = 75,
  kTfLiteBuiltinRsqrt = 76,
  kTfLiteBuiltinShape = 77,
  kTfLiteBuiltinPow = 78,
  kTfLiteBuiltinArgMin = 79,
  kTfLiteBuiltinFakeQuant = 80,
  kTfLiteBuiltinReduceProd = 81,
  kTfLiteBuiltinReduceMax = 82,
  kTfLiteBuiltinPack = 83,
  kTfLiteBuiltinLogicalOr = 84,
  kTfLiteBuiltinOneHot = 85,
  kTfLiteBuiltinLogicalAnd = 86,
  kTfLiteBuiltinLogicalNot = 87,
  kTfLiteBuiltinUnpack = 88,
  kTfLiteBuiltinReduceMin = 89,
  kTfLiteBuiltinFloorDiv = 90,
  kTfLiteBuiltinReduceAny = 91,
  kTfLiteBuiltinSquare = 92,
  kTfLiteBuiltinZerosLike = 93,
  kTfLiteBuiltinFill = 94,
  kTfLiteBuiltinFloorMod = 95,
  kTfLiteBuiltinRange = 96,
  kTfLiteBuiltinResizeNearestNeighbor = 97,
  kTfLiteBuiltinLeakyRelu = 98,
  kTfLiteBuiltinSquaredDifference = 99,
  kTfLiteBuiltinMirrorPad = 100,
  kTfLiteBuiltinAbs = 101,
  kTfLiteBuiltinSplitV = 102,
  kTfLiteBuiltinUnique = 103,
  kTfLiteBuiltinCeil = 104,
  kTfLiteBuiltinReverseV2 = 105,
  kTfLiteBuiltinAddN = 106,
  kTfLiteBuiltinGatherNd = 107,
  kTfLiteBuiltinCos = 108,
  kTfLiteBuiltinWhere = 109,
  kTfLiteBuiltinRank = 110,
  kTfLiteBuiltinElu = 111,
  kTfLiteBuiltinReverseSequence = 112,
  kTfLiteBuiltinMatrixDiag = 113,
  kTfLiteBuiltinQuantize = 114,
  kTfLiteBuiltinMatrixSetDiag = 115,
  kTfLiteBuiltinRound = 116,
  kTfLiteBuiltinHardSwish = 117,
  kTfLiteBuiltinIf = 118,
  kTfLiteBuiltinWhile = 119,
  kTfLiteBuiltinNonMaxSuppressionV4 = 120,
  kTfLiteBuiltinNonMaxSuppressionV5 = 121,
  kTfLiteBuiltinScatterNd = 122,
  kTfLiteBuiltinSelectV2 = 123,
  kTfLiteBuiltinDensify = 124,
  kTfLiteBuiltinSegmentSum = 125,
  kTfLiteBuiltinBatchMatmul = 126,
  kTfLiteBuiltinPlaceholderForGreaterOpCodes = 127,
  kTfLiteBuiltinCumsum = 128,
  kTfLiteBuiltinCallOnce = 129,
  kTfLiteBuiltinBroadcastTo = 130,
  kTfLiteBuiltinRfft2d = 131,
  kTfLiteBuiltinConv3d = 132,
  kTfLiteBuiltinImag = 133,
  kTfLiteBuiltinReal = 134,
  kTfLiteBuiltinComplexAbs = 135,
  kTfLiteBuiltinHashtable = 136,
  kTfLiteBuiltinHashtableFind = 137,
  kTfLiteBuiltinHashtableImport = 138,
  kTfLiteBuiltinHashtableSize = 139,
  kTfLiteBuiltinReduceAll = 140,
  kTfLiteBuiltinConv3dTranspose = 141,
  kTfLiteBuiltinVarHandle = 142,
  kTfLiteBuiltinReadVariable = 143,
  kTfLiteBuiltinAssignVariable = 144,
  kTfLiteBuiltinBroadcastArgs = 145,
  kTfLiteBuiltinRandomStandardNormal = 146,
  kTfLiteBuiltinBucketize = 147,
  kTfLiteBuiltinRandomUniform = 148,
  kTfLiteBuiltinMultinomial = 149,
  kTfLiteBuiltinGelu = 150,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus
#endif  // TENSORFLOW_LITE_BUILTIN_OPS_H_
@@ -0,0 +1,296 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_C_C_API_H_
#define TENSORFLOW_LITE_C_C_API_H_

#include <stdarg.h>
#include <stdint.h>
#include <stdlib.h>

#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export

// --------------------------------------------------------------------------
/// C API for TensorFlow Lite.
///
/// The API leans towards simplicity and uniformity instead of convenience, as
/// most usage will be by language-specific wrappers. It provides largely the
/// same set of functionality as that of the C++ TensorFlow Lite `Interpreter`
/// API, but is useful for shared libraries where having a stable ABI boundary
/// is important.
///
/// Conventions:
/// * We use the prefix TfLite for everything in the API.
/// * size_t is used to represent byte sizes of objects that are
///   materialized in the address space of the calling process.
/// * int is used as an index into arrays.
///
/// Usage:
/// <pre><code>
/// // Create the model and interpreter options.
/// TfLiteModel* model = TfLiteModelCreateFromFile("/path/to/model.tflite");
/// TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
/// TfLiteInterpreterOptionsSetNumThreads(options, 2);
///
/// // Create the interpreter.
/// TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
///
/// // Allocate tensors and populate the input tensor data.
/// TfLiteInterpreterAllocateTensors(interpreter);
/// TfLiteTensor* input_tensor =
///     TfLiteInterpreterGetInputTensor(interpreter, 0);
/// TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
///                            input.size() * sizeof(float));
///
/// // Execute inference.
/// TfLiteInterpreterInvoke(interpreter);
///
/// // Extract the output tensor data.
/// const TfLiteTensor* output_tensor =
///     TfLiteInterpreterGetOutputTensor(interpreter, 0);
/// TfLiteTensorCopyToBuffer(output_tensor, output.data(),
///                          output.size() * sizeof(float));
///
/// // Dispose of the model and interpreter objects.
/// TfLiteInterpreterDelete(interpreter);
/// TfLiteInterpreterOptionsDelete(options);
/// TfLiteModelDelete(model);
/// </code></pre>

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// --------------------------------------------------------------------------
// Opaque types used by the C API.

// TfLiteModel wraps a loaded TensorFlow Lite model.
typedef struct TfLiteModel TfLiteModel;

// TfLiteInterpreterOptions allows customized interpreter configuration.
typedef struct TfLiteInterpreterOptions TfLiteInterpreterOptions;

// Allows delegation of nodes to alternative backends.
typedef struct TfLiteDelegate TfLiteDelegate;

// TfLiteInterpreter provides inference from a provided model.
typedef struct TfLiteInterpreter TfLiteInterpreter;

// A tensor in the interpreter system which is a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).
typedef struct TfLiteTensor TfLiteTensor;

// --------------------------------------------------------------------------
// TfLiteVersion returns a string describing version information of the
// TensorFlow Lite library. TensorFlow Lite uses semantic versioning.
TFL_CAPI_EXPORT extern const char* TfLiteVersion(void);

// Returns a model from the provided buffer, or null on failure.
//
// NOTE: The caller retains ownership of the `model_data` buffer and should
// ensure that its lifetime is at least as long as the lifetime of the
// `TfLiteModel`.
TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreate(const void* model_data,
                                                      size_t model_size);

// Returns a model from the provided file, or null on failure.
TFL_CAPI_EXPORT extern TfLiteModel* TfLiteModelCreateFromFile(
    const char* model_path);

// Destroys the model instance.
TFL_CAPI_EXPORT extern void TfLiteModelDelete(TfLiteModel* model);

// Returns a new interpreter options instance.
TFL_CAPI_EXPORT extern TfLiteInterpreterOptions*
TfLiteInterpreterOptionsCreate();

// Destroys the interpreter options instance.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsDelete(
    TfLiteInterpreterOptions* options);

// Sets the number of CPU threads to use for the interpreter.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetNumThreads(
    TfLiteInterpreterOptions* options, int32_t num_threads);

// Adds a delegate to be applied during `TfLiteInterpreter` creation.
//
// If delegate application fails, interpreter creation will also fail with an
// associated error logged.
//
// NOTE: The caller retains ownership of the delegate and should ensure that it
// remains valid for the duration of any created interpreter's lifetime.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsAddDelegate(
    TfLiteInterpreterOptions* options, TfLiteDelegate* delegate);

// Sets a custom error reporter for interpreter execution.
//
// * `reporter` takes the provided `user_data` object, as well as a C-style
//   format string and arg list (see also vprintf).
// * `user_data` is optional. If non-null, it is owned by the client and must
//   remain valid for the duration of the interpreter lifetime.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetErrorReporter(
    TfLiteInterpreterOptions* options,
    void (*reporter)(void* user_data, const char* format, va_list args),
    void* user_data);

// Returns a new interpreter using the provided model and options, or null on
// failure.
//
// * `model` must be a valid model instance. The caller retains ownership of
//   the object, and can destroy it immediately after creating the interpreter;
//   the interpreter will maintain its own reference to the underlying model
//   data.
// * `optional_options` may be null. The caller retains ownership of the
//   object, and can safely destroy it immediately after creating the
//   interpreter.
//
// NOTE: The client *must* explicitly allocate tensors before attempting to
// access input tensor data or invoke the interpreter.
TFL_CAPI_EXPORT extern TfLiteInterpreter* TfLiteInterpreterCreate(
    const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options);

// Destroys the interpreter.
TFL_CAPI_EXPORT extern void TfLiteInterpreterDelete(
    TfLiteInterpreter* interpreter);

// Returns the number of input tensors associated with the model.
TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetInputTensorCount(
    const TfLiteInterpreter* interpreter);

// Returns the tensor associated with the input index.
// REQUIRES: 0 <= input_index < TfLiteInterpreterGetInputTensorCount(tensor)
TFL_CAPI_EXPORT extern TfLiteTensor* TfLiteInterpreterGetInputTensor(
    const TfLiteInterpreter* interpreter, int32_t input_index);

// Resizes the specified input tensor.
//
// NOTE: After a resize, the client *must* explicitly allocate tensors before
// attempting to access the resized tensor data or invoke the interpreter.
//
// REQUIRES: 0 <= input_index < TfLiteInterpreterGetInputTensorCount(tensor)
//
// This function makes a copy of the input dimensions, so the client can safely
// deallocate `input_dims` immediately after this function returns.
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResizeInputTensor(
    TfLiteInterpreter* interpreter, int32_t input_index, const int* input_dims,
    int32_t input_dims_size);

// Updates allocations for all tensors, resizing dependent tensors using the
// specified input tensor dimensionality.
//
// This is a relatively expensive operation, and need only be called after
// creating the graph and/or resizing any inputs.
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterAllocateTensors(
    TfLiteInterpreter* interpreter);

// Runs inference for the loaded graph.
//
// Before calling this function, the caller should first invoke
// TfLiteInterpreterAllocateTensors() and should also set the values for the
// input tensors. After successfully calling this function, the values for the
// output tensors will be set.
//
// NOTE: It is possible that the interpreter is not in a ready state to
// evaluate (e.g., if AllocateTensors() hasn't been called, or if a
// ResizeInputTensor() has been performed without a subsequent call to
// AllocateTensors()).
//
// If the (experimental!) delegate fallback option was enabled in the
// interpreter options, then the interpreter will automatically fall back to
// not using any delegates if execution with delegates fails. For details, see
// TfLiteInterpreterOptionsSetEnableDelegateFallback in c_api_experimental.h.
//
// Returns one of the following status codes:
//  - kTfLiteOk: Success. Output is valid.
//  - kTfLiteDelegateError: Execution with delegates failed, due to a problem
//    with the delegate(s). If fallback was not enabled, output is invalid.
//    If fallback was enabled, this return value indicates that fallback
//    succeeded, the output is valid, and all delegates previously applied to
//    the interpreter have been undone.
//  - kTfLiteApplicationError: Same as for kTfLiteDelegateError, except that
//    the problem was not with the delegate itself, but rather was due to an
//    incompatibility between the delegate(s) and the interpreter or model.
//  - kTfLiteError: Unexpected/runtime failure. Output is invalid.
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterInvoke(
    TfLiteInterpreter* interpreter);

// Returns the number of output tensors associated with the model.
TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetOutputTensorCount(
    const TfLiteInterpreter* interpreter);

// Returns the tensor associated with the output index.
// REQUIRES: 0 <= output_index < TfLiteInterpreterGetOutputTensorCount(tensor)
//
// NOTE: The shape and underlying data buffer for output tensors may not
// be available until after the output tensor has been both sized and
// allocated. In general, best practice is to interact with the output tensor
// *after* calling TfLiteInterpreterInvoke().
TFL_CAPI_EXPORT extern const TfLiteTensor* TfLiteInterpreterGetOutputTensor(
    const TfLiteInterpreter* interpreter, int32_t output_index);

// --------------------------------------------------------------------------
// TfLiteTensor wraps data associated with a graph tensor.
//
// Note that, while the TfLiteTensor struct is not currently opaque, and its
// fields can be accessed directly, these methods are still convenient for
// language bindings. In the future the tensor struct will likely be made
// opaque in the public API.

// Returns the type of a tensor element.
TFL_CAPI_EXPORT extern TfLiteType TfLiteTensorType(const TfLiteTensor* tensor);

// Returns the number of dimensions that the tensor has.
TFL_CAPI_EXPORT extern int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor);

// Returns the length of the tensor in the "dim_index" dimension.
// REQUIRES: 0 <= dim_index < TFLiteTensorNumDims(tensor)
TFL_CAPI_EXPORT extern int32_t TfLiteTensorDim(const TfLiteTensor* tensor,
                                               int32_t dim_index);

// Returns the size of the underlying data in bytes.
TFL_CAPI_EXPORT extern size_t TfLiteTensorByteSize(const TfLiteTensor* tensor);

// Returns a pointer to the underlying data buffer.
//
// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
// if the Tensor has just been created or resized and `TfLiteAllocateTensors()`
// has yet to be called, or if the output tensor is dynamically sized and the
// interpreter hasn't been invoked.
TFL_CAPI_EXPORT extern void* TfLiteTensorData(const TfLiteTensor* tensor);

// Returns the (null-terminated) name of the tensor.
TFL_CAPI_EXPORT extern const char* TfLiteTensorName(const TfLiteTensor* tensor);

// Returns the parameters for asymmetric quantization. The quantization
// parameters are only valid when the tensor type is `kTfLiteUInt8` and the
// `scale != 0`. Quantized values can be converted back to float using:
//    real_value = scale * (quantized_value - zero_point);
TFL_CAPI_EXPORT extern TfLiteQuantizationParams TfLiteTensorQuantizationParams(
    const TfLiteTensor* tensor);

// Copies from the provided input buffer into the tensor's buffer.
// REQUIRES: input_data_size == TfLiteTensorByteSize(tensor)
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteTensorCopyFromBuffer(
    TfLiteTensor* tensor, const void* input_data, size_t input_data_size);

// Copies to the provided output buffer from the tensor's buffer.
// REQUIRES: output_data_size == TfLiteTensorByteSize(tensor)
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteTensorCopyToBuffer(
    const TfLiteTensor* output_tensor, void* output_data,
    size_t output_data_size);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_LITE_C_C_API_H_
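The usage outline in the header's doc comment can be fleshed out into a self-contained program. A sketch follows, assuming a model with a single float input and a single float output; the model path and input shape are placeholders, but every call matches a declaration in the header above.

// Self-contained TFLite C API inference sketch; path and shapes are
// illustrative placeholders.
#include <cstdio>
#include <vector>

#include "tensorflow/lite/c/c_api.h"

int main() {
    // Load the model and configure a two-thread interpreter.
    TfLiteModel* model = TfLiteModelCreateFromFile("/path/to/model.tflite");
    if (!model) return 1;
    TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
    TfLiteInterpreterOptionsSetNumThreads(options, 2);
    TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);

    // Tensors must be allocated before touching input data.
    TfLiteInterpreterAllocateTensors(interpreter);

    // Fill the (assumed) float input; 224x224x3 is a placeholder shape.
    std::vector<float> input(224 * 224 * 3, 0.f);
    TfLiteTensor* input_tensor =
        TfLiteInterpreterGetInputTensor(interpreter, 0);
    TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
                               input.size() * sizeof(float));

    if (TfLiteInterpreterInvoke(interpreter) != kTfLiteOk) return 1;

    // Read the output after invoking, per the note on output tensors above.
    const TfLiteTensor* output_tensor =
        TfLiteInterpreterGetOutputTensor(interpreter, 0);
    std::vector<float> output(TfLiteTensorByteSize(output_tensor) /
                              sizeof(float));
    TfLiteTensorCopyToBuffer(output_tensor, output.data(),
                             output.size() * sizeof(float));
    std::printf("first output: %f\n", output.empty() ? 0.f : output[0]);

    TfLiteInterpreterDelete(interpreter);
    TfLiteInterpreterOptionsDelete(options);
    TfLiteModelDelete(model);
    return 0;
}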
@ -0,0 +1,198 @@
|
|||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
|
||||
#define TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_
|
||||
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/c/c_api.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif // __cplusplus
|
||||
|
||||
/// Resets all variable tensors to zero.
|
||||
///
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterResetVariableTensors(
|
||||
TfLiteInterpreter* interpreter);
|
||||
|
||||
/// Adds an op registration for a builtin operator.
|
||||
///
|
||||
/// Op registrations are used to map ops referenced in the flatbuffer model
|
||||
/// to executable function pointers (`TfLiteRegistration`s).
|
||||
///
|
||||
/// NOTE: The interpreter will make a shallow copy of `registration` internally,
|
||||
/// so the caller should ensure that its contents (function pointers, etc...)
|
||||
/// remain valid for the duration of the interpreter's lifetime. A common
|
||||
/// practice is making the provided `TfLiteRegistration` instance static.
|
||||
///
|
||||
/// Code that uses this function should NOT call
|
||||
/// `TfLiteInterpreterOptionsSetOpResolver` on the same options object.
|
||||
///
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddBuiltinOp(
|
||||
TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
|
||||
const TfLiteRegistration* registration, int32_t min_version,
|
||||
int32_t max_version);
|
||||
|
||||
/// Adds an op registration for a custom operator.
|
||||
///
|
||||
/// Op registrations are used to map ops referenced in the flatbuffer model
|
||||
/// to executable function pointers (`TfLiteRegistration`s).
|
||||
///
|
||||
/// NOTE: The interpreter will make a shallow copy of `registration` internally,
|
||||
/// so the caller should ensure that its contents (function pointers, etc...)
|
||||
/// remain valid for the duration of any created interpreter's lifetime. A
|
||||
/// common practice is making the provided `TfLiteRegistration` instance static.
|
||||
///
|
||||
/// The lifetime of the string pointed to by `name` must be at least as long
|
||||
/// as the lifetime of the `TfLiteInterpreterOptions`.
|
||||
///
|
||||
/// Code that uses this function should NOT call
|
||||
/// `TfLiteInterpreterOptionsSetOpResolver` on the same options object.
|
||||
///
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
TFL_CAPI_EXPORT void TfLiteInterpreterOptionsAddCustomOp(
|
||||
TfLiteInterpreterOptions* options, const char* name,
|
||||
const TfLiteRegistration* registration, int32_t min_version,
|
||||
int32_t max_version);
|
||||
|
||||
/// Registers callbacks for resolving builtin or custom operators.
|
||||
///
|
||||
/// The `TfLiteInterpreterOptionsSetOpResolver` function provides an alternative
|
||||
/// method for registering builtin ops and/or custom ops, by providing operator
|
||||
/// resolver callbacks. Unlike using `TfLiteInterpreterOptionsAddBuiltinOp`
|
||||
/// and/or `TfLiteInterpreterOptionsAddAddCustomOp`, these let you register all
|
||||
/// the operators in a single call.
|
||||
///
|
||||
/// Code that uses this function should NOT call
|
||||
/// `TfLiteInterpreterOptionsAddBuiltin` or
|
||||
/// `TfLiteInterpreterOptionsAddCustomOp` on the same options object.
|
||||
///
|
||||
/// If `op_resolver_user_data` is non-null, its lifetime must be at least as
|
||||
/// long as the lifetime of the `TfLiteInterpreterOptions`.
|
||||
///
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
void TfLiteInterpreterOptionsSetOpResolver(
|
||||
TfLiteInterpreterOptions* options,
|
||||
const TfLiteRegistration* (*find_builtin_op)(void* user_data,
|
||||
TfLiteBuiltinOperator op,
|
||||
int version),
|
||||
const TfLiteRegistration* (*find_custom_op)(void* user_data,
|
||||
const char* custom_op,
|
||||
int version),
|
||||
void* op_resolver_user_data);
|
||||
|
||||
/// Returns a new interpreter using the provided model and options, or null on
|
||||
/// failure, where the model uses only the operators explicitly added to the
|
||||
/// options. This is the same as `TFLiteInterpreterCreate` from `c_api.h`,
|
||||
/// except that the only operators that are supported are the ones registered
|
||||
/// in `options` via calls to `TfLiteInterpreterOptionsSetOpResolver`,
|
||||
/// `TfLiteInterpreterOptionsAddBuiltinOp`, and/or
|
||||
/// `TfLiteInterpreterOptionsAddCustomOp`.
|
||||
///
|
||||
/// * `model` must be a valid model instance. The caller retains ownership of
|
||||
/// the object, and can destroy it immediately after creating the interpreter;
|
||||
/// the interpreter will maintain its own reference to the underlying model
|
||||
/// data.
|
||||
/// * `options` should not be null. The caller retains ownership of the object,
|
||||
/// and can safely destroy it immediately after creating the interpreter.
|
||||
///
|
||||
/// NOTE: The client *must* explicitly allocate tensors before attempting to
|
||||
/// access input tensor data or invoke the interpreter.
|
||||
///
|
||||
/// WARNING: This is an experimental API and subject to change.
|
||||
TFL_CAPI_EXPORT extern TfLiteInterpreter*
|
||||
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel* model,
|
||||
const TfLiteInterpreterOptions* options);

/// Enable or disable the NN API delegate for the interpreter (true to enable).
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetUseNNAPI(
    TfLiteInterpreterOptions* options, bool enable);

/// Enable or disable CPU fallback for the interpreter (true to enable).
/// If enabled, TfLiteInterpreterInvoke will do automatic fallback from
/// executing with delegate(s) to regular execution without delegates
/// (i.e. on CPU).
///
/// Allowing the fallback is suitable only if both of the following hold:
/// - The caller is known not to cache pointers to tensor data across
///   TfLiteInterpreterInvoke calls.
/// - The model is not stateful (no variables, no LSTMs) or the state isn't
///   needed between batches.
///
/// When delegate fallback is enabled, TfLiteInterpreterInvoke will
/// behave as follows:
///   If one or more delegates were set in the interpreter options
///   (see TfLiteInterpreterOptionsAddDelegate),
///   AND inference fails,
///   then the interpreter will fall back to not using any delegates.
///   In that case, the previously applied delegate(s) will be automatically
///   undone, and an attempt will be made to return the interpreter to an
///   invokable state, which may invalidate previous tensor addresses,
///   and the inference will be attempted again, using input tensors with
///   the same value as previously set.
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern void TfLiteInterpreterOptionsSetEnableDelegateFallback(
    TfLiteInterpreterOptions* options, bool enable);

/// Set if buffer handle output is allowed.
///
/// When using hardware delegation, Interpreter will make the data of output
/// tensors available in `tensor->data` by default. If the application can
/// consume the buffer handle directly (e.g. reading output from an OpenGL
/// texture), it can set this flag to false, so Interpreter won't copy the
/// data from the buffer handle to CPU memory.
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern void TfLiteSetAllowBufferHandleOutput(
    const TfLiteInterpreter* interpreter, bool allow_buffer_handle_output);

/// Allow a delegate to look at the graph and modify the graph to handle
/// parts of the graph itself. After this is called, the graph may
/// contain new nodes that replace one or more nodes.
/// 'delegate' must outlive the interpreter.
/// Use `TfLiteInterpreterOptionsAddDelegate` instead of this unless
/// absolutely required.
/// Returns one of the following three status codes:
/// 1. kTfLiteOk: Success.
/// 2. kTfLiteDelegateError: Delegation failed due to an error in the
///    delegate. The Interpreter has been restored to its pre-delegation
///    state.
///    NOTE: This undoes all delegates previously applied to the Interpreter.
/// 3. kTfLiteError: Unexpected/runtime failure.
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern TfLiteStatus TfLiteInterpreterModifyGraphWithDelegate(
    const TfLiteInterpreter* interpreter, TfLiteDelegate* delegate);

/// Returns the tensor index corresponding to the input tensor.
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetInputTensorIndex(
    const TfLiteInterpreter* interpreter, int32_t input_index);

/// Returns the tensor index corresponding to the output tensor.
///
/// WARNING: This is an experimental API and subject to change.
TFL_CAPI_EXPORT extern int32_t TfLiteInterpreterGetOutputTensorIndex(
    const TfLiteInterpreter* interpreter, int32_t output_index);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

#endif  // TENSORFLOW_LITE_C_C_API_EXPERIMENTAL_H_

@ -0,0 +1,117 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file declares types used by the pure C inference API defined in
// c_api.h, some of which are also used in the C++ and C kernel and
// interpreter APIs.

#ifndef TENSORFLOW_LITE_C_C_API_TYPES_H_
#define TENSORFLOW_LITE_C_C_API_TYPES_H_

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Define TFL_CAPI_EXPORT macro to export a function properly with a shared
// library.
#ifdef SWIG
#define TFL_CAPI_EXPORT
#elif defined(TFL_STATIC_LIBRARY_BUILD)
#define TFL_CAPI_EXPORT
#else  // not defined TFL_STATIC_LIBRARY_BUILD
#if defined(_WIN32)
#ifdef TFL_COMPILE_LIBRARY
#define TFL_CAPI_EXPORT __declspec(dllexport)
#else
#define TFL_CAPI_EXPORT __declspec(dllimport)
#endif  // TFL_COMPILE_LIBRARY
#else
#define TFL_CAPI_EXPORT __attribute__((visibility("default")))
#endif  // _WIN32
#endif  // SWIG

// Note that new error status values may be added in future in order to
// indicate more fine-grained internal states; therefore, applications should
// not rely on status values being members of the enum.
typedef enum TfLiteStatus {
  kTfLiteOk = 0,

  // Generally referring to an error in the runtime (i.e. interpreter)
  kTfLiteError = 1,

  // Generally referring to an error from a TfLiteDelegate itself.
  kTfLiteDelegateError = 2,

  // Generally referring to an error in applying a delegate due to
  // incompatibility between runtime and delegate, e.g., this error is
  // returned when trying to apply a TF Lite delegate onto a model graph
  // that's already immutable.
  kTfLiteApplicationError = 3,

  // Generally referring to serialized delegate data not being found.
  // See tflite::delegates::Serialization.
  kTfLiteDelegateDataNotFound = 4,

  // Generally referring to data-writing issues in delegate serialization.
  // See tflite::delegates::Serialization.
  kTfLiteDelegateDataWriteError = 5,

  // Generally referring to data-reading issues in delegate serialization.
  // See tflite::delegates::Serialization.
  kTfLiteDelegateDataReadError = 6,

  // Generally referring to issues when the TF Lite model has ops that cannot
  // be resolved at runtime. This could happen when the specific op is not
  // registered or built with the TF Lite framework.
  kTfLiteUnresolvedOps = 7,
} TfLiteStatus;

// Types supported by tensor
typedef enum {
  kTfLiteNoType = 0,
  kTfLiteFloat32 = 1,
  kTfLiteInt32 = 2,
  kTfLiteUInt8 = 3,
  kTfLiteInt64 = 4,
  kTfLiteString = 5,
  kTfLiteBool = 6,
  kTfLiteInt16 = 7,
  kTfLiteComplex64 = 8,
  kTfLiteInt8 = 9,
  kTfLiteFloat16 = 10,
  kTfLiteFloat64 = 11,
  kTfLiteComplex128 = 12,
  kTfLiteUInt64 = 13,
  kTfLiteResource = 14,
  kTfLiteVariant = 15,
  kTfLiteUInt32 = 16,
} TfLiteType;

// Legacy. Will be deprecated in favor of TfLiteAffineQuantization.
// If per-layer quantization is specified this field will still be populated
// in addition to TfLiteAffineQuantization.
// Parameters for asymmetric quantization. Quantized values can be converted
// back to float using:
//     real_value = scale * (quantized_value - zero_point)
typedef struct TfLiteQuantizationParams {
  float scale;
  int32_t zero_point;
} TfLiteQuantizationParams;
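
// Worked example (illustrative, not part of the original header):
//   TfLiteQuantizationParams qp = {/*scale=*/0.5f, /*zero_point=*/10};
//   // real_value = 0.5f * (14 - 10) = 2.0f; the zero point maps to 0.0f.
//   float real_value = qp.scale * (14 - qp.zero_point);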

#ifdef __cplusplus
}  // extern C
#endif
#endif  // TENSORFLOW_LITE_C_C_API_TYPES_H_

@ -0,0 +1,990 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines common C types and APIs for implementing operations,
// delegates and other constructs in TensorFlow Lite. The actual operations
// and delegates can be defined using C++, but the interface between the
// interpreter and the operations is C.
//
// Summary of abstractions
//   TF_LITE_ENSURE - self-sufficient error checking
//   TfLiteStatus - status reporting
//   TfLiteIntArray - stores tensor shapes (dims)
//   TfLiteContext - allows an op to access the tensors
//   TfLiteTensor - tensor (a multidimensional array)
//   TfLiteNode - a single node or operation
//   TfLiteRegistration - the implementation of a conceptual operation
//   TfLiteDelegate - allows delegation of nodes to alternative backends
//
// Some abstractions in this file are created and managed by Interpreter.
//
// NOTE: The order of values in these structs is "semi-ABI stable". New values
// should be added only to the end of structs and never reordered.

#ifndef TENSORFLOW_LITE_C_COMMON_H_
#define TENSORFLOW_LITE_C_COMMON_H_

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/lite/c/c_api_types.h"  // IWYU pragma: export

#ifdef __cplusplus
extern "C" {
#endif  // __cplusplus

// The list of external context types known to TF Lite. This list exists
// solely to avoid conflicts and to ensure ops can share the external
// contexts they need. Access to the external contexts is controlled by one
// of the corresponding support files.
typedef enum TfLiteExternalContextType {
  kTfLiteEigenContext = 0,       // include eigen_support.h to use.
  kTfLiteGemmLowpContext = 1,    // include gemm_support.h to use.
  kTfLiteEdgeTpuContext = 2,     // Placeholder for Edge TPU support.
  kTfLiteCpuBackendContext = 3,  // include cpu_backend_context.h to use.
  kTfLiteMaxExternalContexts = 4
} TfLiteExternalContextType;

// Forward declare so dependent structs and methods can reference these types
// prior to the struct definitions.
struct TfLiteContext;
struct TfLiteDelegate;
struct TfLiteRegistration;

// An external context is a collection of information unrelated to the TF Lite
// framework, but useful to a subset of the ops. TF Lite knows very little
// about the actual contexts, but it keeps a list of them, and is able to
// refresh them if configurations like the number of recommended threads
// change.
typedef struct TfLiteExternalContext {
  TfLiteExternalContextType type;
  TfLiteStatus (*Refresh)(struct TfLiteContext* context);
} TfLiteExternalContext;

#define kTfLiteOptionalTensor (-1)

// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
// indices.
typedef struct TfLiteIntArray {
  int size;

#if defined(_MSC_VER)
  // Context for why this is needed is in http://b/189926408#comment21
  int data[1];
#elif (!defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
       __GNUC_MINOR__ >= 1) ||                                      \
    defined(HEXAGON) ||                                             \
    (defined(__clang__) && __clang_major__ == 7 && __clang_minor__ == 1)
  // gcc 6.1+ has a bug where flexible members aren't properly handled
  // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
  int data[0];
#else
  int data[];
#endif
} TfLiteIntArray;

// Given the size (number of elements) in a TfLiteIntArray, calculate its size
// in bytes.
size_t TfLiteIntArrayGetSizeInBytes(int size);

#ifndef TF_LITE_STATIC_MEMORY
// Create an array of a given `size` (uninitialized entries).
// This returns a pointer that you must free using TfLiteIntArrayFree().
TfLiteIntArray* TfLiteIntArrayCreate(int size);
#endif

// Check if two intarrays are equal. Returns 1 if they are equal, 0 otherwise.
int TfLiteIntArrayEqual(const TfLiteIntArray* a, const TfLiteIntArray* b);

// Check if an intarray equals an array. Returns 1 if equals, 0 otherwise.
int TfLiteIntArrayEqualsArray(const TfLiteIntArray* a, int b_size,
                              const int b_data[]);

#ifndef TF_LITE_STATIC_MEMORY
// Create a copy of an array passed as `src`.
// You are expected to free memory with TfLiteIntArrayFree.
TfLiteIntArray* TfLiteIntArrayCopy(const TfLiteIntArray* src);

// Free memory of array `a`.
void TfLiteIntArrayFree(TfLiteIntArray* a);
#endif  // TF_LITE_STATIC_MEMORY
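
// Usage sketch (illustrative, not part of the original header): build a
// shape array for a 1x224x224x3 tensor and release it when done.
//
//   TfLiteIntArray* dims = TfLiteIntArrayCreate(4);
//   dims->data[0] = 1;
//   dims->data[1] = 224;
//   dims->data[2] = 224;
//   dims->data[3] = 3;
//   ...  // e.g. hand ownership to context->ResizeTensor(), which takes it.
//   // Only free it yourself if ownership was NOT transferred:
//   TfLiteIntArrayFree(dims);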

// Fixed size list of floats. Used for per-channel quantization.
typedef struct TfLiteFloatArray {
  int size;
#if defined(_MSC_VER)
  // Context for why this is needed is in http://b/189926408#comment21
  float data[1];
#elif (!defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
       __GNUC_MINOR__ >= 1) ||                                      \
    defined(HEXAGON) ||                                             \
    (defined(__clang__) && __clang_major__ == 7 && __clang_minor__ == 1)
  // gcc 6.1+ has a bug where flexible members aren't properly handled
  // https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
  float data[0];
#else
  float data[];
#endif
} TfLiteFloatArray;

// Given the size (number of elements) in a TfLiteFloatArray, calculate its
// size in bytes.
int TfLiteFloatArrayGetSizeInBytes(int size);

#ifndef TF_LITE_STATIC_MEMORY
// Create an array of a given `size` (uninitialized entries).
// This returns a pointer that you must free using TfLiteFloatArrayFree().
TfLiteFloatArray* TfLiteFloatArrayCreate(int size);

// Free memory of array `a`.
void TfLiteFloatArrayFree(TfLiteFloatArray* a);
#endif  // TF_LITE_STATIC_MEMORY

// Since we must not depend on any libraries, define a minimal subset of
// error macros while avoiding names that have pre-conceived meanings like
// assert and check.

// Try to make all reporting calls through TF_LITE_KERNEL_LOG rather than
// calling the context->ReportError function directly, so that message strings
// can be stripped out if the binary size needs to be severely optimized.
#ifndef TF_LITE_STRIP_ERROR_STRINGS
#define TF_LITE_KERNEL_LOG(context, ...)            \
  do {                                              \
    (context)->ReportError((context), __VA_ARGS__); \
  } while (false)

#define TF_LITE_MAYBE_KERNEL_LOG(context, ...)        \
  do {                                                \
    if ((context) != nullptr) {                       \
      (context)->ReportError((context), __VA_ARGS__); \
    }                                                 \
  } while (false)
#else  // TF_LITE_STRIP_ERROR_STRINGS
#define TF_LITE_KERNEL_LOG(context, ...)
#define TF_LITE_MAYBE_KERNEL_LOG(context, ...)
#endif  // TF_LITE_STRIP_ERROR_STRINGS

// Check whether value is true, and if not return kTfLiteError from
// the current function (and report the error string msg).
#define TF_LITE_ENSURE_MSG(context, value, msg)        \
  do {                                                 \
    if (!(value)) {                                    \
      TF_LITE_KERNEL_LOG((context), __FILE__ " " msg); \
      return kTfLiteError;                             \
    }                                                  \
  } while (0)

// Check whether the value `a` is true, and if not return kTfLiteError from
// the current function, while also reporting the location of the error.
#define TF_LITE_ENSURE(context, a)                                      \
  do {                                                                  \
    if (!(a)) {                                                         \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s was not true.", __FILE__, \
                         __LINE__, #a);                                 \
      return kTfLiteError;                                              \
    }                                                                   \
  } while (0)

#define TF_LITE_ENSURE_STATUS(a) \
  do {                           \
    const TfLiteStatus s = (a);  \
    if (s != kTfLiteOk) {        \
      return s;                  \
    }                            \
  } while (0)

// Check whether the value `a == b` is true, and if not return kTfLiteError
// from the current function, while also reporting the location of the error.
// `a` and `b` may be evaluated more than once, so no side effects or
// extremely expensive computations should be done.
// NOTE: Use TF_LITE_ENSURE_TYPES_EQ if comparing TfLiteTypes.
#define TF_LITE_ENSURE_EQ(context, a, b)                                   \
  do {                                                                     \
    if ((a) != (b)) {                                                      \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
                         __LINE__, #a, #b, (a), (b));                      \
      return kTfLiteError;                                                 \
    }                                                                      \
  } while (0)

#define TF_LITE_ENSURE_TYPES_EQ(context, a, b)                             \
  do {                                                                     \
    if ((a) != (b)) {                                                      \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s != %s (%s != %s)", __FILE__, \
                         __LINE__, #a, #b, TfLiteTypeGetName(a),           \
                         TfLiteTypeGetName(b));                            \
      return kTfLiteError;                                                 \
    }                                                                      \
  } while (0)

#define TF_LITE_ENSURE_NEAR(context, a, b, epsilon)                       \
  do {                                                                    \
    auto delta = ((a) > (b)) ? ((a) - (b)) : ((b) - (a));                 \
    if (delta > epsilon) {                                                \
      TF_LITE_KERNEL_LOG((context), "%s:%d %s not near %s (%f != %f)",    \
                         __FILE__, __LINE__, #a, #b,                      \
                         static_cast<double>(a), static_cast<double>(b)); \
      return kTfLiteError;                                                \
    }                                                                     \
  } while (0)

#define TF_LITE_ENSURE_OK(context, status) \
  do {                                     \
    const TfLiteStatus s = (status);       \
    if ((s) != kTfLiteOk) {                \
      return s;                            \
    }                                      \
  } while (0)
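
// Usage sketch (illustrative, not part of the original header): a typical
// kernel Prepare() guards its assumptions with the macros above, returning
// kTfLiteError early when a check fails.
//
//   TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
//     TF_LITE_ENSURE_EQ(context, node->inputs->size, 1);
//     const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
//     TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
//     TF_LITE_ENSURE_STATUS(context->ResizeTensor(
//         context, &context->tensors[node->outputs->data[0]],
//         TfLiteIntArrayCopy(input->dims)));
//     return kTfLiteOk;
//   }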

// Single-precision complex data type compatible with the C99 definition.
typedef struct TfLiteComplex64 {
  float re, im;  // real and imaginary parts, respectively.
} TfLiteComplex64;

// Double-precision complex data type compatible with the C99 definition.
typedef struct TfLiteComplex128 {
  double re, im;  // real and imaginary parts, respectively.
} TfLiteComplex128;

// Half precision data type compatible with the C99 definition.
typedef struct TfLiteFloat16 {
  uint16_t data;
} TfLiteFloat16;

// Return the name of a given type, for error reporting purposes.
const char* TfLiteTypeGetName(TfLiteType type);

// SupportedQuantizationTypes.
typedef enum TfLiteQuantizationType {
  // No quantization.
  kTfLiteNoQuantization = 0,
  // Affine quantization (with support for per-channel quantization).
  // Corresponds to TfLiteAffineQuantization.
  kTfLiteAffineQuantization = 1,
} TfLiteQuantizationType;

// Structure specifying the quantization used by the tensor, if any.
typedef struct TfLiteQuantization {
  // The type of quantization held by params.
  TfLiteQuantizationType type;
  // Holds an optional reference to a quantization param structure. The actual
  // type depends on the value of the `type` field (see the comment there for
  // the values and corresponding types).
  void* params;
} TfLiteQuantization;

// Parameters for asymmetric quantization across a dimension (i.e.,
// per-output-channel quantization).
// quantized_dimension specifies which dimension the scales and zero_points
// correspond to.
// For a particular value in quantized_dimension, quantized values can be
// converted back to float using:
//     real_value = scale * (quantized_value - zero_point)
typedef struct TfLiteAffineQuantization {
  TfLiteFloatArray* scale;
  TfLiteIntArray* zero_point;
  int32_t quantized_dimension;
} TfLiteAffineQuantization;
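
// Worked example (illustrative, not part of the original header): for a
// filter quantized per output channel (quantized_dimension = 0) with
// scale = {0.5f, 0.25f} and zero_point = {0, 0}, the quantized value 8 in
// channel 0 dequantizes to 0.5f * (8 - 0) = 4.0f, while the same value in
// channel 1 dequantizes to 0.25f * (8 - 0) = 2.0f.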

/* A union of pointers that points to memory for a given tensor. */
typedef union TfLitePtrUnion {
  /* Do not access these members directly; if possible, use
   * GetTensorData<TYPE>(tensor) instead, otherwise only access .data, as
   * other members are deprecated. */
  int32_t* i32;
  uint32_t* u32;
  int64_t* i64;
  uint64_t* u64;
  float* f;
  TfLiteFloat16* f16;
  double* f64;
  char* raw;
  const char* raw_const;
  uint8_t* uint8;
  bool* b;
  int16_t* i16;
  TfLiteComplex64* c64;
  TfLiteComplex128* c128;
  int8_t* int8;
  /* Only use this member. */
  void* data;
} TfLitePtrUnion;

// Memory allocation strategies.
//  * kTfLiteMmapRo: Read-only memory-mapped data, or data externally
//    allocated.
//  * kTfLiteArenaRw: Arena allocated with no guarantees about persistence,
//    and available during eval.
//  * kTfLiteArenaRwPersistent: Arena allocated but persistent across eval,
//    and only available during eval.
//  * kTfLiteDynamic: Allocated during eval, or for string tensors.
//  * kTfLitePersistentRo: Allocated and populated during prepare. This is
//    useful for tensors that can be computed during prepare and treated
//    as constant inputs for downstream ops (also in prepare).
//  * kTfLiteCustom: Custom memory allocation provided by the user. See
//    TfLiteCustomAllocation below.
typedef enum TfLiteAllocationType {
  kTfLiteMemNone = 0,
  kTfLiteMmapRo,
  kTfLiteArenaRw,
  kTfLiteArenaRwPersistent,
  kTfLiteDynamic,
  kTfLitePersistentRo,
  kTfLiteCustom,
} TfLiteAllocationType;

// The delegates should use zero or positive integers to represent handles.
// -1 is reserved for unallocated status.
typedef int TfLiteBufferHandle;
enum {
  kTfLiteNullBufferHandle = -1,
};

// Storage format of each dimension in a sparse tensor.
typedef enum TfLiteDimensionType {
  kTfLiteDimDense = 0,
  kTfLiteDimSparseCSR,
} TfLiteDimensionType;

// Metadata to encode each dimension in a sparse tensor.
typedef struct TfLiteDimensionMetadata {
  TfLiteDimensionType format;
  int dense_size;
  TfLiteIntArray* array_segments;
  TfLiteIntArray* array_indices;
} TfLiteDimensionMetadata;

// Parameters used to encode a sparse tensor. For detailed explanation of
// each field please refer to lite/schema/schema.fbs.
typedef struct TfLiteSparsity {
  TfLiteIntArray* traversal_order;
  TfLiteIntArray* block_map;
  TfLiteDimensionMetadata* dim_metadata;
  int dim_metadata_size;
} TfLiteSparsity;

// Defines a custom memory allocation not owned by the runtime.
// `data` should be aligned to kDefaultTensorAlignment defined in
// lite/util.h. (Currently 64 bytes.)
// NOTE: See Interpreter.SetCustomAllocationForTensor for details on usage.
typedef struct TfLiteCustomAllocation {
  void* data;
  size_t bytes;
} TfLiteCustomAllocation;

// The flags used in `Interpreter::SetCustomAllocationForTensor`.
// Note that this is a bitmask, so the values should be 1, 2, 4, 8, ...etc.
typedef enum TfLiteCustomAllocationFlags {
  kTfLiteCustomAllocationFlagsNone = 0,
  // Skips checking whether allocation.data points to an aligned buffer as
  // expected by the TFLite runtime.
  // NOTE: Setting this flag can cause crashes when calling Invoke().
  // Use with caution.
  kTfLiteCustomAllocationFlagsSkipAlignCheck = 1,
} TfLiteCustomAllocationFlags;

// A tensor in the interpreter system which is a wrapper around a buffer of
// data including a dimensionality (or NULL if not currently defined).
#ifndef TF_LITE_STATIC_MEMORY
typedef struct TfLiteTensor {
  // The data type specification for data stored in `data`. This affects
  // what member of `data` union should be used.
  TfLiteType type;
  // A union of data pointers. The appropriate type should be used for a
  // typed tensor based on `type`.
  TfLitePtrUnion data;
  // A pointer to a structure representing the dimensionality interpretation
  // that the buffer should have. NOTE: the product of elements of `dims`
  // and the element datatype size should be equal to `bytes` below.
  TfLiteIntArray* dims;
  // Quantization information.
  TfLiteQuantizationParams params;
  // How memory is mapped
  //  kTfLiteMmapRo: Memory mapped read only.
  //    i.e. weights
  //  kTfLiteArenaRw: Arena allocated read write memory
  //    (i.e. temporaries, outputs).
  TfLiteAllocationType allocation_type;
  // The number of bytes required to store the data of this Tensor. I.e.
  // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
  // type is kTfLiteFloat32 and dims = {3, 2} then
  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
  size_t bytes;

  // An opaque pointer to a tflite::MMapAllocation
  const void* allocation;

  // Null-terminated name of this tensor.
  const char* name;

  // The delegate which knows how to handle `buffer_handle`.
  // WARNING: This is an experimental interface that is subject to change.
  struct TfLiteDelegate* delegate;

  // An integer buffer handle that can be handled by `delegate`.
  // The value is valid only when delegate is not null.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteBufferHandle buffer_handle;

  // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
  // responsible for setting data_is_stale to true.
  // `delegate->CopyFromBufferHandle` can be called to copy the data from
  // delegate buffer.
  // WARNING: This is an experimental interface that is subject to change.
  bool data_is_stale;

  // True if the tensor is a variable.
  bool is_variable;

  // Quantization information. Replaces params field above.
  TfLiteQuantization quantization;

  // Parameters used to encode a sparse tensor.
  // This is optional. The field is NULL if a tensor is dense.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteSparsity* sparsity;

  // Optional. Encodes shapes with unknown dimensions with -1. This field is
  // only populated when unknown dimensions exist in a read-write tensor
  // (i.e. an input or output tensor). (e.g. `dims` contains [1, 1, 1, 3] and
  // `dims_signature` contains [1, -1, -1, 3]).
  const TfLiteIntArray* dims_signature;
} TfLiteTensor;
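
// Usage sketch (illustrative, not part of the original header): reading a
// float32 tensor through the `data` union, with `type` and `bytes` checked
// first as the comments above require.
//
//   float SumTensor(const TfLiteTensor* t) {
//     if (t->type != kTfLiteFloat32) return 0.f;
//     const float* values = (const float*)t->data.data;
//     size_t count = t->bytes / sizeof(float);
//     float sum = 0.f;
//     for (size_t i = 0; i < count; ++i) sum += values[i];
//     return sum;
//   }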

// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs, user defined data and
// some node properties (like statefulness), not other features like the
// type.
typedef struct TfLiteNode {
  // Inputs to this node expressed as indices into the simulator's tensors.
  TfLiteIntArray* inputs;

  // Outputs to this node expressed as indices into the simulator's tensors.
  TfLiteIntArray* outputs;

  // Intermediate tensors to this node expressed as indices into the
  // simulator's tensors.
  TfLiteIntArray* intermediates;

  // Temporary tensors used during the computations. This usually contains no
  // tensors, but ops are allowed to change that if they need scratch space
  // of any sort.
  TfLiteIntArray* temporaries;

  // Opaque data provided by the node implementer through `Registration.init`.
  void* user_data;

  // Opaque data provided to the node if the node is a builtin. This is
  // usually a structure defined in builtin_op_data.h
  void* builtin_data;

  // Custom initial data. This is the opaque data provided in the flatbuffer.
  // WARNING: This is an experimental interface that is subject to change.
  const void* custom_initial_data;
  int custom_initial_data_size;

  // The pointer to the delegate. This is non-null only when the node is
  // created by calling `interpreter.ModifyGraphWithDelegate`.
  // WARNING: This is an experimental interface that is subject to change.
  struct TfLiteDelegate* delegate;

  // Whether this op might have side effect (e.g. stateful op).
  bool might_have_side_effect;
} TfLiteNode;
#else  // defined(TF_LITE_STATIC_MEMORY)?
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a
// micro inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
  // TODO(b/155784997): Consider consolidating these quantization fields:
  // Quantization information. Replaces params field above.
  TfLiteQuantization quantization;

  // Quantization information.
  TfLiteQuantizationParams params;

  // A union of data pointers. The appropriate type should be used for a
  // typed tensor based on `type`.
  TfLitePtrUnion data;

  // A pointer to a structure representing the dimensionality interpretation
  // that the buffer should have. NOTE: the product of elements of `dims`
  // and the element datatype size should be equal to `bytes` below.
  TfLiteIntArray* dims;

  // The number of bytes required to store the data of this Tensor. I.e.
  // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
  // type is kTfLiteFloat32 and dims = {3, 2} then
  // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
  size_t bytes;

  // The data type specification for data stored in `data`. This affects
  // what member of `data` union should be used.
  TfLiteType type;

  // How memory is mapped
  //  kTfLiteMmapRo: Memory mapped read only.
  //    i.e. weights
  //  kTfLiteArenaRw: Arena allocated read write memory
  //    (i.e. temporaries, outputs).
  TfLiteAllocationType allocation_type;

  // True if the tensor is a variable.
  bool is_variable;
} TfLiteTensor;

// Specific reduced TfLiteNode struct for TF Micro runtime. This struct
// contains only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
  // Inputs to this node expressed as indices into the simulator's tensors.
  TfLiteIntArray* inputs;

  // Outputs to this node expressed as indices into the simulator's tensors.
  TfLiteIntArray* outputs;

  // Intermediate tensors to this node expressed as indices into the
  // simulator's tensors.
  TfLiteIntArray* intermediates;

  // Opaque data provided by the node implementer through `Registration.init`.
  void* user_data;

  // Opaque data provided to the node if the node is a builtin. This is
  // usually a structure defined in builtin_op_data.h
  void* builtin_data;

  // Custom initial data. This is the opaque data provided in the flatbuffer.
  // WARNING: This is an experimental interface that is subject to change.
  const void* custom_initial_data;
  int custom_initial_data_size;
} TfLiteNode;
#endif  // TF_LITE_STATIC_MEMORY

// Light-weight tensor struct for TF Micro runtime. Provides the minimal
// amount of information required for a kernel to run during
// TfLiteRegistration::Eval.
// TODO(b/160955687): Move this field into TF_LITE_STATIC_MEMORY when TFLM
// builds with this flag by default internally.
typedef struct TfLiteEvalTensor {
  // A union of data pointers. The appropriate type should be used for a
  // typed tensor based on `type`.
  TfLitePtrUnion data;

  // A pointer to a structure representing the dimensionality interpretation
  // that the buffer should have.
  TfLiteIntArray* dims;

  // The data type specification for data stored in `data`. This affects
  // what member of `data` union should be used.
  TfLiteType type;
} TfLiteEvalTensor;

#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);

// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);

// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);

// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);

// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
                       TfLiteQuantizationParams quantization, char* buffer,
                       size_t size, TfLiteAllocationType allocation_type,
                       const void* allocation, bool is_variable,
                       TfLiteTensor* tensor);

// Copies the contents of `src` into `dst`.
// The function does nothing and returns kTfLiteOk if either `src` or `dst`
// is passed as nullptr.
// Returns kTfLiteError if `src` and `dst` don't have matching data size.
// Note the function copies contents, so it won't create a new data pointer
// or change the allocation type.
// All tensor-related properties will be copied from `src` to `dst`, like
// quantization, sparsity, ...
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst);

// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif  // TF_LITE_STATIC_MEMORY

// WARNING: This is an experimental interface that is subject to change.
//
// Currently, TfLiteDelegateParams has to be allocated in a way that it's
// trivially destructible. It will be stored as the `builtin_data` field in
// the `TfLiteNode` of the delegate node.
//
// See also the `CreateDelegateParams` function in `interpreter.cc` for
// details.
typedef struct TfLiteDelegateParams {
  struct TfLiteDelegate* delegate;
  TfLiteIntArray* nodes_to_replace;
  TfLiteIntArray* input_tensors;
  TfLiteIntArray* output_tensors;
} TfLiteDelegateParams;

typedef struct TfLiteContext {
  // Number of tensors in the context.
  size_t tensors_size;

  // The execution plan contains a list of the node indices in execution
  // order. execution_plan->size is the current number of nodes. And,
  // execution_plan->data[0] is the first node that needs to be run.
  // TfLiteDelegates can traverse the current execution plan by iterating
  // through each member of this array and using GetNodeAndRegistration() to
  // access details about a node. i.e.
  //
  //   TfLiteIntArray* execution_plan;
  //   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context,
  //                                                   &execution_plan));
  //   for (int exec_index = 0; exec_index < execution_plan->size;
  //        exec_index++) {
  //     int node_index = execution_plan->data[exec_index];
  //     TfLiteNode* node;
  //     TfLiteRegistration* reg;
  //     context->GetNodeAndRegistration(context, node_index, &node, &reg);
  //   }
  // Note: the memory pointed to by `*execution_plan` is OWNED by the TfLite
  // runtime. Future calls to GetExecutionPlan invalidate earlier outputs.
  // The following code snippet shows the issue of such an invocation
  // pattern. After calling CheckNode, subsequent access to `plan_1st` is
  // undefined.
  //
  //   void CheckNode(const TfLiteNode* node) {
  //     ...
  //     TfLiteIntArray* plan_2nd;
  //     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_2nd));
  //     ...
  //   }
  //
  //   TfLiteIntArray* plan_1st;
  //   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan_1st));
  //   for (int exec_index = 0; exec_index < plan_1st->size; exec_index++) {
  //     int node_index = plan_1st->data[exec_index];
  //     TfLiteNode* node;
  //     TfLiteRegistration* reg;
  //     context->GetNodeAndRegistration(context, node_index, &node, &reg);
  //     CheckNode(node);
  //   }
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
                                   TfLiteIntArray** execution_plan);

  // An array of tensors in the interpreter context (of length
  // `tensors_size`).
  TfLiteTensor* tensors;

  // opaque full context ptr (an opaque c++ data structure)
  void* impl_;

  // Request memory pointer be resized. Updates dimensions on the tensor.
  // NOTE: ResizeTensor takes ownership of newSize.
  TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
                               TfLiteIntArray* new_size);
  // Request that an error be reported with format string msg.
  void (*ReportError)(struct TfLiteContext*, const char* msg, ...);

  // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
  // non-null, the value pointed to by `first_new_tensor_index` will be set
  // to the index of the first new tensor.
  TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
                             int* first_new_tensor_index);

  // Get a Tensor node by node_index.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*GetNodeAndRegistration)(
      struct TfLiteContext*, int node_index, TfLiteNode** node,
      struct TfLiteRegistration** registration);

  // Replace ops with one or more stub delegate operations. This function
  // does not take ownership of `nodes_to_replace`.
  TfLiteStatus (*ReplaceNodeSubsetsWithDelegateKernels)(
      struct TfLiteContext*, struct TfLiteRegistration registration,
      const TfLiteIntArray* nodes_to_replace, struct TfLiteDelegate* delegate);

  // Number of threads that are recommended to subsystems like gemmlowp and
  // eigen.
  int recommended_num_threads;

  // Access external contexts by type.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
                                               TfLiteExternalContextType);
  // Set the value of an external context. Does not take ownership of the
  // pointer.
  // WARNING: This is an experimental interface that is subject to change.
  void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
                             TfLiteExternalContext*);

  // Flag for allowing float16 precision for FP32 calculation.
  // default: false.
  // WARNING: This is an experimental API and subject to change.
  bool allow_fp32_relax_to_fp16;

  // Pointer to the op-level profiler, if set; nullptr otherwise.
  void* profiler;

  // Allocate persistent buffer which has the same life time as the
  // interpreter. Returns nullptr on failure.
  // The memory is allocated from heap for TFL, and from tail in TFLM.
  // This method is only available in Init or Prepare stage.
  // WARNING: This is an experimental interface that is subject to change.
  void* (*AllocatePersistentBuffer)(struct TfLiteContext* ctx, size_t bytes);

  // Allocate a buffer which will be deallocated right after invoke phase.
  // The memory is allocated from heap in TFL, and from volatile arena in
  // TFLM.
  // This method is only available in invoke stage.
  // NOTE: If possible use RequestScratchBufferInArena method to avoid memory
  // allocation during inference time.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*AllocateBufferForEval)(struct TfLiteContext* ctx,
                                        size_t bytes, void** ptr);

  // Request a scratch buffer in the arena through static memory planning.
  // This method is only available in Prepare stage and the buffer is
  // allocated by the interpreter between Prepare and Eval stage. In Eval
  // stage, the GetScratchBuffer API can be used to fetch the address.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*RequestScratchBufferInArena)(struct TfLiteContext* ctx,
                                              size_t bytes, int* buffer_idx);

  // Get the scratch buffer pointer.
  // This method is only available in Eval stage.
  // WARNING: This is an experimental interface that is subject to change.
  void* (*GetScratchBuffer)(struct TfLiteContext* ctx, int buffer_idx);

  // Resize the memory pointer of the `tensor`. This method behaves the same
  // as `ResizeTensor`, except that it makes a copy of the shape array
  // internally so the shape array could be deallocated right afterwards.
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*ResizeTensorExplicit)(struct TfLiteContext* ctx,
                                       TfLiteTensor* tensor, int dims,
                                       const int* shape);

  // This method provides a preview of post-delegation partitioning. Each
  // TfLiteDelegateParams in the referenced array corresponds to one instance
  // of the delegate kernel.
  // Example usage:
  //
  //   TfLiteIntArray* nodes_to_replace = ...;
  //   TfLiteDelegateParams* params_array;
  //   int num_partitions = 0;
  //   TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
  //       context, nodes_to_replace, &params_array, &num_partitions));
  //   for (int idx = 0; idx < num_partitions; idx++) {
  //     const auto& partition_params = params_array[idx];
  //     ...
  //   }
  //
  // NOTE: The context owns the memory referenced by partition_params_array.
  // It will be cleared with another call to PreviewDelegatePartitioning, or
  // after TfLiteDelegateParams::Prepare returns.
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*PreviewDelegatePartitioning)(
      struct TfLiteContext* context, const TfLiteIntArray* nodes_to_replace,
      TfLiteDelegateParams** partition_params_array, int* num_partitions);

  // Returns a TfLiteTensor struct for a given index.
  // WARNING: This is an experimental interface that is subject to change.
  // WARNING: This method may not be available on all platforms.
  TfLiteTensor* (*GetTensor)(const struct TfLiteContext* context,
                             int tensor_idx);

  // Returns a TfLiteEvalTensor struct for a given index.
  // WARNING: This is an experimental interface that is subject to change.
  // WARNING: This method may not be available on all platforms.
  TfLiteEvalTensor* (*GetEvalTensor)(const struct TfLiteContext* context,
                                     int tensor_idx);

  // Retrieves named metadata buffer from the TFLite model.
  // Returns kTfLiteOk if metadata is successfully obtained from the
  // flatbuffer Model: that is, there exists a `metadata` entry with given
  // `name` string (see TFLite's schema.fbs).
  // The corresponding `buffer` information is populated in `ptr` & `bytes`.
  // The data from `ptr` is valid for the lifetime of the Interpreter.
  //
  // WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus (*GetModelMetadata)(const struct TfLiteContext* context,
                                   const char* name, const char** ptr,
                                   size_t* bytes);
} TfLiteContext;

typedef struct TfLiteRegistration {
  // Initializes the op from serialized data.
  // Called only *once* for the lifetime of the op, so any one-time
  // allocations should be made here (unless they depend on tensor sizes).
  //
  // If a built-in op:
  //   `buffer` is the op's params data (TfLiteLSTMParams*).
  //   `length` is zero.
  // If custom op:
  //   `buffer` is the op's `custom_options`.
  //   `length` is the size of the buffer.
  //
  // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
  // or an instance of a struct).
  //
  // The returned pointer will be stored with the node in the `user_data`
  // field, accessible within prepare and invoke functions below.
  // NOTE: if the data is already in the desired format, simply implement
  // this function to return `nullptr` and implement the free function to be
  // a no-op.
  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);

  // The pointer `buffer` is the data previously returned by an init
  // invocation.
  void (*free)(TfLiteContext* context, void* buffer);

  // prepare is called when the inputs this node depends on have been
  // resized. context->ResizeTensor() can be called to request output tensors
  // to be resized.
  // Can be called multiple times for the lifetime of the op.
  //
  // Returns kTfLiteOk on success.
  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);

  // Execute the node (should read node->inputs and output to node->outputs).
  // Returns kTfLiteOk on success.
  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);

  // profiling_string is called during summarization of profiling information
  // in order to group executions together. Providing a value here will cause
  // a given op to appear multiple times in the profiling report. This is
  // particularly useful for custom ops that can perform significantly
  // different calculations depending on their `user-data`.
  const char* (*profiling_string)(const TfLiteContext* context,
                                  const TfLiteNode* node);

  // Builtin codes. If this kernel refers to a builtin this is the code
  // of the builtin. This is so we can do marshaling to other frameworks like
  // NN API.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  int32_t builtin_code;

  // Custom op name. If the op is a builtin, this will be null.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  // WARNING: This is an experimental interface that is subject to change.
  const char* custom_name;

  // The version of the op.
  // Note: It is the responsibility of the registration binder to set this
  // properly.
  int version;
} TfLiteRegistration;
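
// Usage sketch (illustrative, not part of the original header): a minimal
// custom op registration wiring the lifecycle callbacks. `MyInit`, `MyFree`,
// `MyPrepare` and `MyInvoke` are hypothetical functions matching the
// signatures declared above.
//
//   const TfLiteRegistration* Register_MY_OP() {
//     static TfLiteRegistration r = {};
//     r.init = MyInit;
//     r.free = MyFree;
//     r.prepare = MyPrepare;
//     r.invoke = MyInvoke;
//     r.custom_name = "MY_OP";
//     r.version = 1;
//     return &r;
//   }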

// The flags used in `TfLiteDelegate`. Note that this is a bitmask, so the
// values should be 1, 2, 4, 8, ...etc.
typedef enum TfLiteDelegateFlags {
  kTfLiteDelegateFlagsNone = 0,
  // The flag is set if the delegate can handle dynamic sized tensors.
  // For example, the output shape of a `Resize` op with non-constant shape
  // can only be inferred when the op is invoked.
  // In this case, the Delegate is responsible for calling
  // `SetTensorToDynamic` to mark the tensor as a dynamic tensor, and calling
  // `ResizeTensor` when invoking the op.
  //
  // If the delegate isn't capable of handling dynamic tensors, this flag
  // needs to be set to false.
  kTfLiteDelegateFlagsAllowDynamicTensors = 1,

  // This flag can be used by delegates (that allow dynamic tensors) to
  // ensure applicable tensor shapes are automatically propagated in the case
  // of tensor resizing.
  // This means that non-dynamic (allocation_type != kTfLiteDynamic) I/O
  // tensors of a delegate kernel will have correct shapes before its
  // Prepare() method is called. The runtime leverages TFLite builtin ops in
  // the original execution plan to propagate shapes.
  //
  // A few points to note:
  // 1. This requires kTfLiteDelegateFlagsAllowDynamicTensors. If that flag
  //    is false, this one is redundant since the delegate kernels are
  //    re-initialized every time tensors are resized.
  // 2. Enabling this flag adds some overhead to AllocateTensors(), since
  //    extra work is required to prepare the original execution plan.
  // 3. This flag requires that the original execution plan only have ops
  //    with valid registrations (and not 'dummy' custom ops like with Flex).
  // WARNING: This feature is experimental and subject to change.
  kTfLiteDelegateFlagsRequirePropagatedShapes = 2
} TfLiteDelegateFlags;

// WARNING: This is an experimental interface that is subject to change.
typedef struct TfLiteDelegate {
  // Data that delegate needs to identify itself. This data is owned by the
  // delegate. The delegate is owned in the user code, so the delegate is
  // responsible for doing this when it is destroyed.
  void* data_;

  // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
  // delegate a view of the current graph through TfLiteContext*. It
  // typically will look at the nodes and call
  // ReplaceNodeSubsetsWithDelegateKernels() to ask the TensorFlow Lite
  // runtime to create macro-nodes to represent delegated subgraphs of the
  // original graph.
  TfLiteStatus (*Prepare)(TfLiteContext* context,
                          struct TfLiteDelegate* delegate);

  // Copy the data from delegate buffer handle into raw memory of the given
  // 'tensor'. Note that the delegate is allowed to allocate the raw bytes as
  // long as it follows the rules for kTfLiteDynamic tensors, in which case
  // this cannot be null.
  TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
                                       struct TfLiteDelegate* delegate,
                                       TfLiteBufferHandle buffer_handle,
                                       TfLiteTensor* tensor);

  // Copy the data from raw memory of the given 'tensor' to delegate buffer
  // handle. This can be null if the delegate doesn't use its own buffer.
  TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
                                     struct TfLiteDelegate* delegate,
                                     TfLiteBufferHandle buffer_handle,
                                     TfLiteTensor* tensor);

  // Free the Delegate Buffer Handle. Note: This only frees the handle, but
  // this doesn't release the underlying resource (e.g. textures). The
  // resources are either owned by application layer or the delegate.
  // This can be null if the delegate doesn't use its own buffer.
  void (*FreeBufferHandle)(TfLiteContext* context,
                           struct TfLiteDelegate* delegate,
                           TfLiteBufferHandle* handle);

  // Bitmask flags. See the comments in `TfLiteDelegateFlags`.
  int64_t flags;
} TfLiteDelegate;

// Build a 'null' delegate, with all the fields properly set to their default
// values.
TfLiteDelegate TfLiteDelegateCreate(void);
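
// Usage sketch (illustrative, not part of the original header): start from
// the 'null' delegate so unused callbacks stay at their defaults, then fill
// in what the backend needs. `MyDelegatePrepare` is a hypothetical function;
// remember the delegate must outlive the interpreter it is applied to.
//
//   TfLiteDelegate delegate = TfLiteDelegateCreate();
//   delegate.Prepare = MyDelegatePrepare;
//   delegate.flags = kTfLiteDelegateFlagsNone;
//   TfLiteInterpreterModifyGraphWithDelegate(interpreter, &delegate);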

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus
#endif  // TENSORFLOW_LITE_C_COMMON_H_
Binary file not shown.
Binary file not shown.
Binary file not shown.
77
plugin/android/src/main/cpp/CMakeLists.txt
Normal file

@ -0,0 +1,77 @@

# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html

# Sets the minimum version of CMake required to build the native library.

cmake_minimum_required(VERSION 3.18.1)

# Declares and names the project.

project("plugin")

# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.

# Configure imported libs.
set(dependency_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../../dependency)

add_library( tensorflowlite SHARED IMPORTED )
set_target_properties( tensorflowlite PROPERTIES IMPORTED_LOCATION
    ${dependency_DIR}/tensorflowlite/jni/${ANDROID_ABI}/libtensorflowlite_jni.so )

add_library( renderscript-intrinsics-replacement-toolkit SHARED IMPORTED )
set_target_properties( renderscript-intrinsics-replacement-toolkit PROPERTIES IMPORTED_LOCATION
    ${dependency_DIR}/renderscript-intrinsics-replacement-toolkit/jni/${ANDROID_ABI}/librenderscript-toolkit.so )

add_library( # Sets the name of the library.
             plugin

             # Sets the library as a shared library.
             SHARED

             # Provides a relative path to your source file(s).
             deep_lap_3.cpp
             exception.cpp
             stopwatch.cpp
             tflite_wrapper.cpp
             util.cpp
             zero_dce.cpp
             )
set_target_properties( plugin PROPERTIES COMPILE_OPTIONS -fopenmp )

# Searches for a specified prebuilt library and stores the path as a
# variable. Because CMake includes system libraries in the search path by
# default, you only need to specify the name of the public NDK library
# you want to add. CMake verifies that the library exists before
# completing its build.

find_library( # Sets the name of the path variable.
              log-lib

              # Specifies the name of the NDK library that
              # you want CMake to locate.
              log )

find_library( android-lib android )

target_include_directories( plugin PRIVATE
                            ${dependency_DIR}/tensorflowlite/headers
                            ${dependency_DIR}/renderscript-intrinsics-replacement-toolkit/headers )

# Specifies libraries CMake should link to your target library. You
# can link multiple libraries, such as libraries you define in this
# build script, prebuilt third-party libraries, or system libraries.

target_link_libraries( # Specifies the target library.
                       plugin

                       # Links the target library to the log library
                       # included in the NDK.
                       ${log-lib}
                       ${android-lib}
                       tensorflowlite
                       renderscript-intrinsics-replacement-toolkit
                       -fopenmp -static-openmp )
904
plugin/android/src/main/cpp/base_resample.h
Normal file

@ -0,0 +1,904 @@
/*
//
// Copyright (c) 1998-2019 Joe Bertolami. All Rights Reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Additional Information:
//
// For more information, visit http://www.bertolami.com.
//
*/

#ifndef __BASE_RESAMPLE_H__
#define __BASE_RESAMPLE_H__

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>

#ifndef __BASE_TYPES_H__
#define __BASE_TYPES_H__

#include <algorithm>
#include <cctype>
#include <cmath>
#include <vector>

namespace base {

typedef int64_t int64;
typedef int32_t int32;
typedef int16_t int16;
typedef int8_t int8;
typedef uint64_t uint64;
typedef uint32_t uint32;
typedef uint16_t uint16;
typedef uint8_t uint8;
typedef float float32;
typedef double float64;

} // namespace base

#endif // __BASE_TYPES_H__

namespace base {

enum KernelType : uint8 {
  KernelTypeNearest,
  KernelTypeAverage,
  KernelTypeBilinear,
  KernelTypeBicubic,
  KernelTypeMitchell,
  KernelTypeCardinal,
  KernelTypeBSpline,
  KernelTypeLanczos,
  KernelTypeLanczos2,
  KernelTypeLanczos3,
  KernelTypeLanczos4,
  KernelTypeLanczos5,
  KernelTypeCatmull,
  KernelTypeGaussian,
};

enum KernelDirection : uint8 {
  KernelDirectionHorizontal,
  KernelDirectionVertical,
};

#define BASE_PI (3.14159265359f)
// #define BLOCK_OFFSET_RGB24(ptr, width, x, y) (ptr + (3 * width) * y + 3 * x)
inline const uint8_t *block_offset(const uint8 *ptr, const uint32 width,
                                   const int32 x, const int32 y,
                                   const int channels) {
  return ptr + (channels * width) * y + channels * x;
}

inline uint8_t *block_offset(uint8 *ptr, const uint32 width, const int32 x,
                             const int32 y, const int channels) {
  return ptr + (channels * width) * y + channels * x;
}

inline int32 clip_range(int32 input, int32 low, int32 high) {
  return (input < low) ? low : (input > high) ? high : input;
}

/* Cubic weighting function

   Source: Mitchell, Netravali, "Reconstruction Filters in Computer Graphics"
   1988

   Several of the popular cubic functions used for bi-directional image
   filtering can be generated as a simple weight function with two parameters.
   Thus, we use a weight function to generate the majority of our bicubic
   kernels. */
inline float32 bicubic_weight(float32 f_b, float32 f_c, float32 distance) {
  /* Our bicubic function is designed to provide feedback over a radius of 2.0
   * pixels. */
  float32 distance2 = distance * distance;
  float32 distance3 = distance * distance * distance;
  float32 result = 0.0;

  if (distance < 1.0) {
    float32 cubic_term = (12.0 - 9.0 * f_b - 6.0 * f_c) * distance3;
    float32 quad_term = (-18.0 + 12.0 * f_b + 6.0 * f_c) * distance2;
    float32 const_term = (6.0 - 2.0 * f_b);
    result = (1.0f / 6.0f) * (cubic_term + quad_term + const_term);
  }

  else if (distance >= 1.0 && distance < 2.0) {
    float32 cubic_term = (-f_b - 6.0 * f_c) * distance3;
    float32 quad_term = (6.0 * f_b + 30.0 * f_c) * distance2;
    float32 lin_term = (-12.0 * f_b - 48.0 * f_c) * distance;
    float32 const_term = (8.0 * f_b + 24.0 * f_c);
    result = (1.0f / 6.0f) * (cubic_term + quad_term + lin_term + const_term);
  }

  if (result < 0) {
    result = 0.0;
  }

  return result;
}
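
/* Illustrative helper (not part of the original header): the named bicubic
   kernels dispatched by SampleKernel further below are just fixed (B, C)
   choices for bicubic_weight -- Mitchell (1/3, 1/3), Catmull (0, 0.5),
   Cardinal (0, 0.75), B-Spline (1, 0). For example, the Mitchell kernel
   evaluated at a given distance: */
inline float32 mitchell_weight(float32 distance) {
  return bicubic_weight(1.0f / 3.0f, 1.0f / 3.0f, distance);
}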
/* Gaussian weighting function. Our simple gaussian distribution function with
   mean of zero and std-dev (d):

                  1.0              -(x^2 / (2 * d * d))
   f(x) = ------------------- * e
                         0.5
           d * (2 * Pi)
*/

inline float32 gaussian_weight(float32 distance, float32 radius) {
  float32 range = distance / radius;

  /* Gaussian function with mean = 0 and variance = 0.1. */
  static const float32 variance = 0.1f;
  static const float32 stddev = sqrt(variance);
  static const float32 coeff = 1.0f / (stddev * sqrt(2.0 * BASE_PI));
  return coeff * exp(-1.0f * (range * range) / (2.0 * variance));
}

inline float32 sinc(float32 f_x) {
  if (0.0 == f_x)
    return 1.0;
  return sin(BASE_PI * f_x) / (BASE_PI * f_x);
}

inline float32 lanczos_weight(float32 f_n, float32 distance) {
  if (distance <= f_n) {
    return sinc(distance) * sinc(distance / f_n);
  }
  return 0.0f;
}
template <size_t ch>
bool SampleKernelBilinearH(const uint8 *src, uint32 src_width,
                           uint32 src_height, float32 f_x, float32 f_y,
                           uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  /* We do not bias our float coordinate by 0.5 because we wish
     to sample using the nearest 2 pixels to our coordinate. */
  int32 sample_x = f_x;
  int32 sample_y = f_y;
  const uint8 *pixels[2] = {nullptr};
  float32 f_delta = (float32)f_x - sample_x;

  /* compute our two pixels that will be interpolated together. */
  for (uint32 i = 0; i < 2; i++) {
    int32 src_x = clip_range(sample_x + i, 0, src_width - 1);
    int32 src_y = clip_range(sample_y, 0, src_height - 1);

    pixels[i] = block_offset(src, src_width, src_x, src_y, ch);
  }

  /* perform the interpolation of our lerp_pixels. */
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = pixels[0][i] * (1.0f - f_delta) + pixels[1][i] * f_delta;
  }

  return true;
}

template <size_t ch>
bool SampleKernelBilinearV(const uint8 *src, uint32 src_width,
                           uint32 src_height, float32 f_x, float32 f_y,
                           uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  /* We do not bias our float coordinate by 0.5 because we wish
     to sample using the nearest 2 pixels to our coordinate. */
  int32 sample_x = f_x;
  int32 sample_y = f_y;
  const uint8 *pixels[2] = {nullptr};
  float32 f_delta = (float32)f_y - sample_y;

  /* compute our two pixels that will be interpolated together. */
  for (uint32 i = 0; i < 2; i++) {
    int32 src_x = clip_range(sample_x, 0, src_width - 1);
    int32 src_y = clip_range(sample_y + i, 0, src_height - 1);

    pixels[i] = block_offset(src, src_width, src_x, src_y, ch);
  }

  /* perform the interpolation of our lerp_pixels. */
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = pixels[0][i] * (1.0f - f_delta) + pixels[1][i] * f_delta;
  }

  return true;
}
template <size_t ch>
bool SampleKernelBilinear(const uint8 *src, uint32 src_width, uint32 src_height,
                          KernelDirection direction, float32 f_x, float32 f_y,
                          uint8 *output) {
  switch (direction) {
  case KernelDirectionHorizontal:
    return SampleKernelBilinearH<ch>(src, src_width, src_height, f_x, f_y,
                                     output);
  case KernelDirectionVertical:
    return SampleKernelBilinearV<ch>(src, src_width, src_height, f_x, f_y,
                                     output);
  }

  return false;
}

template <size_t ch>
bool SampleKernelBicubicH(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 coeff_b,
                          float32 coeff_c, uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the bicubic weights and pixel values. */
  for (int32 i = -2; i < 2; i++) {
    int32 i_x = (int32)f_x + i;
    int32 i_y = (int32)f_y;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 x_delta = (float32)f_x - i_x;
    float32 distance = fabs(x_delta);
    float32 weight = bicubic_weight(coeff_b, coeff_c, distance);

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate bicubic weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our bicubic sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelBicubicV(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 coeff_b,
                          float32 coeff_c, uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the bicubic weights and pixel values. */
  for (int32 i = -2; i < 2; i++) {
    int32 i_x = (int32)f_x;
    int32 i_y = (int32)f_y + i;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 y_delta = (float32)f_y - i_y;
    float32 distance = fabs(y_delta);
    float32 weight = bicubic_weight(coeff_b, coeff_c, distance);
    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate bicubic weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our bicubic sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelBicubic(const uint8 *src, uint32 src_width, uint32 src_height,
                         KernelDirection direction, float32 f_x, float32 f_y,
                         float32 coeff_b, float32 coeff_c, uint8 *output) {
  switch (direction) {
  case KernelDirectionHorizontal:
    return SampleKernelBicubicH<ch>(src, src_width, src_height, f_x, f_y,
                                    coeff_b, coeff_c, output);
  case KernelDirectionVertical:
    return SampleKernelBicubicV<ch>(src, src_width, src_height, f_x, f_y,
                                    coeff_b, coeff_c, output);
  }

  return false;
}
template <size_t ch>
bool SampleKernelLanczosH(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 coeff_a,
                          uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = coeff_a;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the lanczos weights and pixel values. */
  for (int32 i = -radius; i < radius; i++) {
    int32 i_x = (int32)f_x + i;
    int32 i_y = (int32)f_y;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 x_delta = (float32)f_x - i_x;
    float32 distance = fabs(x_delta);
    float32 weight = lanczos_weight(coeff_a, distance);

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate lanczos weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelLanczosV(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 coeff_a,
                          uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = coeff_a;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the lanczos weights and pixel values. */
  for (int32 i = -radius; i < radius; i++) {
    int32 i_x = (int32)f_x;
    int32 i_y = (int32)f_y + i;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 y_delta = (float32)f_y - i_y;
    float32 distance = fabs(y_delta);
    float32 weight = lanczos_weight(coeff_a, distance);

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate lanczos weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelLanczos(const uint8 *src, uint32 src_width, uint32 src_height,
                         KernelDirection direction, float32 f_x, float32 f_y,
                         float32 coeff_a, uint8 *output) {
  switch (direction) {
  case KernelDirectionHorizontal:
    return SampleKernelLanczosH<ch>(src, src_width, src_height, f_x, f_y,
                                    coeff_a, output);
  case KernelDirectionVertical:
    return SampleKernelLanczosV<ch>(src, src_width, src_height, f_x, f_y,
                                    coeff_a, output);
  }

  return false;
}
template <size_t ch>
bool SampleKernelAverageH(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 h_ratio,
                          uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = h_ratio + 1.0f;
  float32 max_distance = h_ratio;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the filter weights and pixel values. */
  for (int32 i = -radius + 1; i <= radius; i++) {
    int32 i_x = (int32)f_x + i;
    int32 i_y = (int32)f_y;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 x_delta = (float32)f_x - i_x;
    float32 distance = fabs(x_delta);
    float32 weight = 0.0f;

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    if (h_ratio >= 1.0) {
      distance = std::min(max_distance, distance);
      weight = 1.0f - distance / max_distance;
    } else {
      if (distance >= 0.5f - h_ratio) {
        weight = 1.0f - distance;
      } else {
        /* our average kernel is smaller than a pixel and is fully contained
           within the source pixel, so we simply copy the value out. */
        for (unsigned j = 0; j < ch; ++j) {
          output[j] = src_pixel[j];
        }
        return true;
      }
    }

    /* accumulate weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelAverageV(const uint8 *src, uint32 src_width, uint32 src_height,
                          float32 f_x, float32 f_y, float32 v_ratio,
                          uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = v_ratio + 1.0f;
  float32 max_distance = v_ratio;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the filter weights and pixel values. */
  for (int32 i = -radius + 1; i <= radius; i++) {
    int32 i_x = (int32)f_x;
    int32 i_y = (int32)f_y + i;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 y_delta = (float32)f_y - i_y;
    float32 distance = fabs(y_delta);
    float32 weight = 0.0f;

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    if (v_ratio >= 1.0) {
      distance = std::min(max_distance, distance);
      weight = 1.0f - distance / max_distance;
    } else {
      if (distance >= 0.5f - v_ratio) {
        weight = 1.0f - distance;
      } else {
        /* our average kernel is smaller than a pixel and is fully contained
           within the source pixel, so we simply copy the value out. */
        for (unsigned j = 0; j < ch; ++j) {
          output[j] = src_pixel[j];
        }
        return true;
      }
    }

    /* accumulate weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}
template <size_t ch>
bool SampleKernelGaussianH(const uint8 *src, uint32 src_width,
                           uint32 src_height, float32 f_x, float32 f_y,
                           float32 h_ratio, uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = h_ratio + 1.0f;
  float32 max_distance = h_ratio;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the gaussian weights and pixel values. */
  for (int32 i = -radius; i <= radius; i++) {
    int32 i_x = (int32)f_x + i;
    int32 i_y = (int32)f_y;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 x_delta = (float32)f_x - i_x;
    float32 distance = fabs(x_delta);
    float32 weight = gaussian_weight(distance, max_distance);

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate gaussian weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}

template <size_t ch>
bool SampleKernelGaussianV(const uint8 *src, uint32 src_width,
                           uint32 src_height, float32 f_x, float32 f_y,
                           float32 v_ratio, uint8 *output) {
  if (!src || !src_width || !src_height || f_x < 0 || f_y < 0 || !output) {
    return false;
  }

  int32 radius = v_ratio + 1.0f;
  float32 max_distance = v_ratio;
  float32 sample_count = 0;
  float32 total_samples[ch] = {0};

  /* Scan the kernel space adding up the gaussian weights and pixel values. */
  for (int32 i = -radius; i <= radius; i++) {
    int32 i_x = (int32)f_x;
    int32 i_y = (int32)f_y + i;

    if (i_x < 0 || i_y < 0 || i_x > src_width - 1 || i_y > src_height - 1) {
      continue;
    }

    float32 y_delta = (float32)f_y - i_y;
    float32 distance = fabs(y_delta);
    float32 weight = gaussian_weight(distance, max_distance);

    const uint8 *src_pixel = block_offset(src, src_width, i_x, i_y, ch);

    /* accumulate gaussian weighted samples from the source. */
    for (unsigned j = 0; j < ch; ++j) {
      total_samples[j] += src_pixel[j] * weight;
    }

    /* record the total weights of the sample for later normalization. */
    sample_count += weight;
  }

  /* Normalize our weighted sum back to the valid pixel range. */
  float32 scale_factor = 1.0f / sample_count;
  for (unsigned i = 0; i < ch; ++i) {
    output[i] = clip_range(scale_factor * total_samples[i], 0, 255);
  }

  return true;
}
template <size_t ch>
bool SampleKernelAverage(const uint8 *src, uint32 src_width, uint32 src_height,
                         KernelDirection direction, float32 f_x, float32 f_y,
                         float32 h_ratio, float32 v_ratio, uint8 *output) {
  switch (direction) {
  case KernelDirectionHorizontal:
    return SampleKernelAverageH<ch>(src, src_width, src_height, f_x, f_y,
                                    h_ratio, output);
  case KernelDirectionVertical:
    return SampleKernelAverageV<ch>(src, src_width, src_height, f_x, f_y,
                                    v_ratio, output);
  }

  return false;
}

template <size_t ch>
bool SampleKernelGaussian(const uint8 *src, uint32 src_width, uint32 src_height,
                          KernelDirection direction, float32 f_x, float32 f_y,
                          float32 h_ratio, float32 v_ratio, uint8 *output) {
  switch (direction) {
  case KernelDirectionHorizontal:
    return SampleKernelGaussianH<ch>(src, src_width, src_height, f_x, f_y,
                                     h_ratio, output);
  case KernelDirectionVertical:
    return SampleKernelGaussianV<ch>(src, src_width, src_height, f_x, f_y,
                                     v_ratio, output);
  }

  return false;
}

template <size_t ch>
bool SampleKernelNearest(const uint8 *src, uint32 src_width, uint32 src_height,
                         float32 f_x, float32 f_y, uint8 *output) {
  if (!src || !src_width || !src_height || !output) {
    return false;
  }

  int32 i_x = (int32)(f_x + 0.5f);
  int32 i_y = (int32)(f_y + 0.5f);

  /* Floating point pixel coordinates are pixel-center based. Thus, a
     coordinate of (0,0) refers to the center of the first pixel in an image,
     and a coordinate of (0.5,0) refers to the border between the first and
     second pixels. */
  i_x = clip_range(i_x, 0, src_width - 1);
  i_y = clip_range(i_y, 0, src_height - 1);

  /* Sample our pixel and write it to the output buffer. */
  memcpy(output, block_offset(src, src_width, i_x, i_y, ch), ch);

  return true;
}
template <size_t ch>
bool SampleKernel(const uint8 *src, uint32 src_width, uint32 src_height,
                  KernelDirection direction, float32 f_x, float32 f_y,
                  KernelType type, float32 h_ratio, float32 v_ratio,
                  uint8 *output) {
  switch (type) {
  case KernelTypeNearest:
    return SampleKernelNearest<ch>(src, src_width, src_height, f_x, f_y,
                                   output);
  case KernelTypeBilinear:
    return SampleKernelBilinear<ch>(src, src_width, src_height, direction, f_x,
                                    f_y, output);
  case KernelTypeBicubic:
    return SampleKernelBicubic<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 0, 1, output);
  case KernelTypeCatmull:
    return SampleKernelBicubic<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 0, 0.5, output);
  case KernelTypeMitchell:
    return SampleKernelBicubic<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 1.0f / 3.0f, 1.0f / 3.0f, output);
  case KernelTypeCardinal:
    return SampleKernelBicubic<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 0.0f, 0.75f, output);
  case KernelTypeBSpline:
    return SampleKernelBicubic<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 1, 0, output);
  case KernelTypeLanczos:
    return SampleKernelLanczos<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 1, output);
  case KernelTypeLanczos2:
    return SampleKernelLanczos<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 2, output);
  case KernelTypeLanczos3:
    return SampleKernelLanczos<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 3, output);
  case KernelTypeLanczos4:
    return SampleKernelLanczos<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 4, output);
  case KernelTypeLanczos5:
    return SampleKernelLanczos<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, 5, output);
  case KernelTypeAverage:
    return SampleKernelAverage<ch>(src, src_width, src_height, direction, f_x,
                                   f_y, h_ratio, v_ratio, output);
  case KernelTypeGaussian:
    return SampleKernelGaussian<ch>(src, src_width, src_height, direction, f_x,
                                    f_y, h_ratio, v_ratio, output);
  }

  return false;
}
/* Resamples an N-channel RGB image using a bilinear, bicubic, or lanczos
   filter. */
template <size_t ch>
bool ResampleImage(const uint8 *src, uint32 src_width, uint32 src_height,
                   uint8 *dst, uint32 dst_width, uint32 dst_height,
                   KernelType type, ::std::string *errors = nullptr) {
  if (!src || !dst || !src_width || !src_height || !dst_width || !dst_height) {
    if (errors) {
      *errors = "Invalid parameter passed to ResampleImage.";
    }
    return false;
  }

  uint32 src_row_pitch = ch * src_width;
  uint32 dst_row_pitch = ch * dst_width;
  uint32 buffer_size = dst_row_pitch * src_height;
  uint32 dst_image_size = dst_row_pitch * dst_height;

  if (src_width == dst_width && src_height == dst_height) {
    /* no resampling needed, simply copy the image over. */
    memcpy(dst, src, dst_image_size);
    return true;
  }

  /* allocate a temporary buffer to hold our horizontal pass output. We're
     using unique_ptr rather than vector because we want a fast and smart way
     to allocate very large buffers without initialization. */
  ::std::unique_ptr<uint8[]> buffer(new uint8[buffer_size]);

  /* Prepare to perform our resample. This is perhaps the most important part
     of our resizer -- the calculation of our image ratios. These ratios are
     responsible for mapping between our integer pixel locations of the source
     image and our float sub-pixel coordinates within the source image that
     represent a reflection of our destination pixels.

     For a source 2x1 image and a destination 4x1 image:

              +------------+------------+
     Src:     |     0      |     1      |
              +------------+------------+
              |                         |
             0.0                       1.0
              |                         |
              +---+---+---+---+
     Dst:     | 0 | 1 | 2 | 3 |
              +---+---+---+---+

     o: Note that the centers of the first and last pixels in both our src and
        dst images line up with our float edges of 0.0 and 1.0.

     o: Our sub-pixel interpolated coordinates will always be >= 0 and <=
        src_width.

     o: Thus the src pixel coordinate of our final destination pixel will
        always be src_width - 1.
  */

  /* ratios define our kernel size and resample factor. */
  float32 h_ratio =
      (1 == dst_width ? 1.0f : ((float32)src_width - 1) / (dst_width - 1));
  float32 v_ratio =
      (1 == dst_height ? 1.0f : ((float32)src_height - 1) / (dst_height - 1));

  /* horizontal sampling first. */
  for (uint32 j = 0; j < src_height; j++)
    for (uint32 i = 0; i < dst_width; i++) {
      uint8 *output = block_offset(buffer.get(), dst_width, i, j, ch);

      /* Determine the sub-pixel location of our *target* (i,j) coordinate, in
         the space of our source image. */
      float32 f_x = (float32)i * h_ratio;
      float32 f_y = (float32)j;

      if (!SampleKernel<ch>(src, src_width, src_height,
                            KernelDirectionHorizontal, f_x, f_y, type, h_ratio,
                            v_ratio, output)) {
        if (errors) {
          *errors = "Failure during horizontal resample operation.";
        }
        return false;
      }
    }

  /* vertical sampling next. */
  for (uint32 j = 0; j < dst_height; j++)
    for (uint32 i = 0; i < dst_width; i++) {
      uint8 *output = block_offset(dst, dst_width, i, j, ch);

      /* Determine the sub-pixel location of our *target* (i,j) coordinate, in
         the space of our temp image. */
      float32 f_x = (float32)i;
      float32 f_y = (float32)j * v_ratio;

      if (!SampleKernel<ch>(buffer.get(), dst_width, src_height,
                            KernelDirectionVertical, f_x, f_y, type, h_ratio,
                            v_ratio, output)) {
        if (errors) {
          *errors = "Failure during vertical resample operation.";
        }
        return false;
      }
    }

  return true;
}

/* Resamples a 24 bit RGB image using a bilinear, bicubic, or lanczos filter. */
inline bool ResampleImage24(const uint8 *src, uint32 src_width,
                            uint32 src_height, uint8 *dst, uint32 dst_width,
                            uint32 dst_height, KernelType type,
                            ::std::string *errors = nullptr) {
  return ResampleImage<3>(src, src_width, src_height, dst, dst_width,
                          dst_height, type, errors);
}
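
/* Usage sketch (illustrative, not part of the original header): shrink a
   24-bit RGB photo to the 513x513 input expected by the plugin's DeepLab
   model, using the same Lanczos3 kernel that deep_lap_3.cpp requests. dst
   must point to a 513 * 513 * 3 byte buffer. */
inline bool ExampleResizeForDeepLab(const uint8 *src, uint32 src_width,
                                    uint32 src_height, uint8 *dst) {
  return ResampleImage24(src, src_width, src_height, dst, 513, 513,
                         KernelTypeLanczos3);
}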
} // namespace base

#endif // __BASE_RESAMPLE_H__
227 plugin/android/src/main/cpp/deep_lap_3.cpp Normal file
@@ -0,0 +1,227 @@
#include "base_resample.h"
#include "exception.h"
#include "log.h"
#include "stopwatch.h"
#include "tflite_wrapper.h"
#include "util.h"
#include <RenderScriptToolkit.h>
#include <algorithm>
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <cassert>
#include <exception>
#include <jni.h>
#include <omp.h>
#include <tensorflow/lite/c/c_api.h>

using namespace plugin;
using namespace renderscript;
using namespace std;
using namespace tflite;

namespace {

constexpr const char *MODEL = "lite-model_mobilenetv2-dm05-coco_dr_1.tflite";
constexpr size_t WIDTH = 513;
constexpr size_t HEIGHT = 513;
constexpr unsigned LABEL_COUNT = 21;

enum struct Label {
  BACKGROUND = 0,
  AEROPLANE,
  BICYCLE,
  BIRD,
  BOAT,
  BOTTLE,
  BUS,
  CAR,
  CAT,
  CHAIR,
  COW,
  DINING_TABLE,
  DOG,
  HORSE,
  MOTORBIKE,
  PERSON,
  POTTED_PLANT,
  SHEEP,
  SOFA,
  TRAIN,
  TV,
};

class DeepLab3 {
public:
  explicit DeepLab3(AAssetManager *const aam);
  DeepLab3(const DeepLab3 &) = delete;
  DeepLab3(DeepLab3 &&) = default;

  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height);

private:
  Model model;

  static constexpr const char *TAG = "DeepLap3";
};

class DeepLab3Portrait {
public:
  explicit DeepLab3Portrait(DeepLab3 &&deepLab);

  std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
                             const size_t height, const unsigned radius);

private:
  /**
   * Post-process the segment map.
   *
   * The resulting segment map will:
   * 1. Contain only the most significant label (the one with the most pixels)
   * 2. Have the label value set to 255
   * 3. Have the background set to 0
   *
   * @param segmentMap
   */
  void postProcessSegmentMap(std::vector<uint8_t> *segmentMap);

  std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
                               const size_t height,
                               const std::vector<uint8_t> &segmentMap,
                               const unsigned radius);

  DeepLab3 deepLab;

  static constexpr const char *TAG = "DeepLab3Portrait";
};

} // namespace

extern "C" JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3Portrait_inferNative(
    JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jint radius) {
  try {
    initOpenMp();
    auto aam = AAssetManager_fromJava(env, assetManager);
    DeepLab3Portrait model(DeepLab3{aam});
    RaiiContainer<jbyte> cImage(
        [&]() { return env->GetByteArrayElements(image, nullptr); },
        [&](jbyte *obj) {
          env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
        });
    const auto result = model.infer(reinterpret_cast<uint8_t *>(cImage.get()),
                                    width, height, radius);
    auto resultAry = env->NewByteArray(result.size());
    env->SetByteArrayRegion(resultAry, 0, result.size(),
                            reinterpret_cast<const int8_t *>(result.data()));
    return resultAry;
  } catch (const exception &e) {
    throwJavaException(env, e.what());
    return nullptr;
  }
}

namespace {

DeepLab3::DeepLab3(AAssetManager *const aam) : model(Asset(aam, MODEL)) {}

vector<uint8_t> DeepLab3::infer(const uint8_t *image, const size_t width,
                                const size_t height) {
  InterpreterOptions options;
  options.setNumThreads(getNumberOfProcessors());
  Interpreter interpreter(model, options);
  interpreter.allocateTensors();

  LOGI(TAG, "[infer] Convert bitmap to input");
  vector<uint8_t> inputBitmap(WIDTH * HEIGHT * 3);
  base::ResampleImage24(const_cast<uint8_t *>(image), width, height,
                        inputBitmap.data(), WIDTH, HEIGHT,
                        base::KernelTypeLanczos3);
  const auto input =
      rgb8ToRgbFloat(inputBitmap.data(), inputBitmap.size(), true);
  auto inputTensor = interpreter.getInputTensor(0);
  assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
                             input.size() * sizeof(float));

  LOGI(TAG, "[infer] Inferring");
  Stopwatch stopwatch;
  interpreter.invoke();
  LOGI(TAG, "[infer] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

  auto outputTensor = interpreter.getOutputTensor(0);
  vector<float> output(WIDTH * HEIGHT * LABEL_COUNT);
  assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
  TfLiteTensorCopyToBuffer(outputTensor, output.data(),
                           output.size() * sizeof(float));
  return argmax(output.data(), WIDTH, HEIGHT, LABEL_COUNT);
}

DeepLab3Portrait::DeepLab3Portrait(DeepLab3 &&deepLab)
    : deepLab(move(deepLab)) {}

vector<uint8_t> DeepLab3Portrait::infer(const uint8_t *image,
                                        const size_t width, const size_t height,
                                        const unsigned radius) {
  auto segmentMap = deepLab.infer(image, width, height);
  postProcessSegmentMap(&segmentMap);
  return enhance(image, width, height, segmentMap, radius);
}

void DeepLab3Portrait::postProcessSegmentMap(vector<uint8_t> *segmentMap) {
  // keep only the largest segment
  vector<uint8_t> &segmentMapRef = *segmentMap;
  vector<int> count(LABEL_COUNT);
  for (size_t i = 0; i < segmentMapRef.size(); ++i) {
    assert(segmentMapRef[i] < LABEL_COUNT);
    const auto label = std::min<unsigned>(segmentMapRef[i], LABEL_COUNT - 1);
    if (label != static_cast<unsigned>(Label::BACKGROUND)) {
      ++count[label];
    }
  }
  const auto keep = distance(
      count.data(), max_element(count.data(), count.data() + count.size()));
  LOGI(TAG, "[postProcessSegmentMap] Label to keep: %d",
       static_cast<int>(keep));
#pragma omp parallel for
  for (size_t i = 0; i < segmentMapRef.size(); ++i) {
    if (segmentMapRef[i] == keep) {
      segmentMapRef[i] = 0xFF;
    } else {
      segmentMapRef[i] = 0;
    }
  }
}

vector<uint8_t> DeepLab3Portrait::enhance(const uint8_t *image,
                                          const size_t width,
                                          const size_t height,
                                          const vector<uint8_t> &segmentMap,
                                          const unsigned radius) {
  LOGI(TAG, "[enhance] Enhancing image");
  // resize alpha to input size
  vector<uint8_t> alpha(width * height);
  base::ResampleImage<1>(segmentMap.data(), WIDTH, HEIGHT, alpha.data(), width,
                         height, base::KernelTypeLanczos3);
  // smoothen the edge
  vector<uint8_t> alphaFiltered(width * height);
  getToolkitInst().blur(alpha.data(), alphaFiltered.data(), width, height, 1,
                        16);
  alpha.clear();

  // blur input
  auto rgba8 = rgb8ToRgba8(image, width, height);
  vector<uint8_t> blur(width * height * 4);
  getToolkitInst().blur(rgba8.data(), blur.data(), width, height, 4, radius);

  // draw input on top of blurred image, with alpha map
  replaceChannel<4>(rgba8.data(), alphaFiltered.data(), width, height, 3);
  alphaFiltered.clear();
  alphaBlend(rgba8.data(), blur.data(), width, height);
  rgba8.clear();
  return rgba8ToRgb8(blur.data(), width, height);
}

} // namespace
16 plugin/android/src/main/cpp/deep_lap_3.h Normal file
@@ -0,0 +1,16 @@
#pragma once

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

/* Backs com.nkming.nc_photos.plugin.image_processor.DeepLab3Portrait
   .inferNative() on the Kotlin side. */
JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_DeepLab3Portrait_inferNative(
    JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jint radius);

#ifdef __cplusplus
}
#endif
9 plugin/android/src/main/cpp/exception.cpp Normal file
@@ -0,0 +1,9 @@
#include "exception.h"
#include <jni.h>

void throwJavaException(JNIEnv *env, const char *msg) {
  jclass clz = env->FindClass("com/nkming/nc_photos/plugin/NativeException");
  if (clz) {
    env->ThrowNew(clz, msg);
  }
}
5 plugin/android/src/main/cpp/exception.h Normal file
@@ -0,0 +1,5 @@
#pragma once

#include <jni.h>

void throwJavaException(JNIEnv *env, const char *msg);
15 plugin/android/src/main/cpp/log.h Normal file
@@ -0,0 +1,15 @@
#pragma once

#include <android/log.h>

#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, __VA_ARGS__)
#ifdef NDEBUG
#define LOGI(...)
#define LOGD(...)
#define LOGV(...)
#else
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, __VA_ARGS__)
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, __VA_ARGS__)
#define LOGV(...) __android_log_print(ANDROID_LOG_VERBOSE, __VA_ARGS__)
#endif
36 plugin/android/src/main/cpp/stopwatch.cpp Normal file
@@ -0,0 +1,36 @@
#include "stopwatch.h"
#include <chrono>

using namespace std;

Stopwatch::Stopwatch()
    : is_start_(true), beg_(chrono::steady_clock::now()),
      time_elapsed_(chrono::steady_clock::duration::zero()),
      offset_(chrono::steady_clock::duration::zero()) {}

void Stopwatch::resume() {
  if (!is_start_) {
    beg_ = chrono::steady_clock::now();
    is_start_ = true;
  }
}

void Stopwatch::pause() {
  if (is_start_) {
    time_elapsed_ += chrono::steady_clock::now() - beg_;
    is_start_ = false;
  }
}

chrono::steady_clock::duration Stopwatch::getTime() const {
  if (is_start_) {
    return time_elapsed_ + offset_ + (chrono::steady_clock::now() - beg_);
  } else {
    return time_elapsed_ + offset_;
  }
}

void Stopwatch::resetClock() {
  time_elapsed_ = chrono::steady_clock::duration::zero();
  offset_ = chrono::steady_clock::duration::zero();
}
68 plugin/android/src/main/cpp/stopwatch.h Normal file
@@ -0,0 +1,68 @@
#pragma once

#include <chrono>
#include <cstdint>

/**
 * High precision time measurement. Suitable for benchmarking code
 */
class Stopwatch {
public:
  /**
   * Constructor. The instance is started by default, so calling start()
   * right after the constructor is unnecessary
   */
  Stopwatch();

  void start() {
    stop();
    resetClock();
    resume();
  }

  /**
   * Stop the stopwatch. All calls to getters will return the same value
   * until starting again
   */
  void stop() { pause(); }

  /**
   * Resume a previously paused stopwatch
   */
  void resume();

  /**
   * Pause the stopwatch. Will continue counting on resume()
   */
  void pause();

  template <typename T> typename T::rep get() const {
    return std::chrono::duration_cast<T>(getTime()).count();
  }

  /**
   * Return the current measurement, in ns
   *
   * @return
   * @see get()
   */
  int64_t getNs() const { return get<std::chrono::nanoseconds>(); }

  /**
   * Return the current measurement, in ms
   *
   * @return
   */
  int64_t getMs() const { return get<std::chrono::milliseconds>(); }

  std::chrono::steady_clock::duration getTime() const;
  void setOffset(std::chrono::steady_clock::duration a_d) { offset_ = a_d; }
  std::chrono::steady_clock::duration getOffset() { return offset_; }
  std::chrono::steady_clock::duration getTimeElapsed() { return time_elapsed_; }
  void resetClock();

private:
  bool is_start_;
  std::chrono::time_point<std::chrono::steady_clock> beg_;
  std::chrono::steady_clock::duration time_elapsed_;
  std::chrono::steady_clock::duration offset_;
};
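
/* Usage sketch (illustrative helper, not part of the original header): time a
   callable in milliseconds, the way deep_lap_3.cpp wraps
   Interpreter::invoke(). The stopwatch starts counting on construction. */
template <typename F> int64_t timeMs(F &&f) {
  Stopwatch stopwatch;
  f();
  return stopwatch.getMs();
}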
102 plugin/android/src/main/cpp/tflite_wrapper.cpp Normal file
@@ -0,0 +1,102 @@
#include "tflite_wrapper.h"
#include "util.h"
#include <exception>
#include <stdexcept>
#include <tensorflow/lite/c/c_api.h>

using namespace plugin;
using namespace std;

namespace tflite {

Model::Model(Asset &&a) : asset(std::move(a)) {
  model = TfLiteModelCreate(asset.getBuffer(), asset.getSize());
  if (!model) {
    throw runtime_error("Error loading model file");
  }
}

Model::Model(Model &&rhs) : asset(std::move(rhs.asset)), model(rhs.model) {
  rhs.model = nullptr;
}

Model::~Model() {
  if (model) {
    TfLiteModelDelete(model);
    model = nullptr;
  }
}

InterpreterOptions::InterpreterOptions() {
  options = TfLiteInterpreterOptionsCreate();
  if (!options) {
    throw runtime_error("Error calling TfLiteInterpreterOptionsCreate");
  }
}

InterpreterOptions::~InterpreterOptions() {
  if (options) {
    TfLiteInterpreterOptionsDelete(options);
  }
}

void InterpreterOptions::setNumThreads(const int num_threads) {
  TfLiteInterpreterOptionsSetNumThreads(options, num_threads);
}

void InterpreterOptions::addDelegate(TfLiteDelegate *delegate) {
  TfLiteInterpreterOptionsAddDelegate(options, delegate);
}

Interpreter::Interpreter(const Model &model) : Interpreter(model, nullptr) {}

Interpreter::Interpreter(const Model &model, const InterpreterOptions &options)
    : Interpreter(model, &options) {}

Interpreter::Interpreter(const Model &model,
                         const InterpreterOptions *options) {
  interpreter =
      TfLiteInterpreterCreate(model.get(), options ? options->get() : nullptr);
  if (!interpreter) {
    throw runtime_error("Error creating interpreter");
  }
}

Interpreter::~Interpreter() {
  if (interpreter) {
    TfLiteInterpreterDelete(interpreter);
    interpreter = nullptr;
  }
}

int32_t Interpreter::getInputTensorCount() const {
  return TfLiteInterpreterGetInputTensorCount(interpreter);
}

TfLiteTensor *Interpreter::getInputTensor(const int32_t inputIndex) {
  return TfLiteInterpreterGetInputTensor(interpreter, inputIndex);
}

TfLiteStatus Interpreter::resizeInputTensor(const int32_t inputIndex,
                                            const int *inputDims,
                                            const int32_t inputDimsSize) {
  return TfLiteInterpreterResizeInputTensor(interpreter, inputIndex, inputDims,
                                            inputDimsSize);
}

TfLiteStatus Interpreter::allocateTensors() {
  return TfLiteInterpreterAllocateTensors(interpreter);
}

TfLiteStatus Interpreter::invoke() {
  return TfLiteInterpreterInvoke(interpreter);
}

int32_t Interpreter::getOutputTensorCount() const {
  return TfLiteInterpreterGetOutputTensorCount(interpreter);
}

const TfLiteTensor *Interpreter::getOutputTensor(const int32_t outputIndex) {
  return TfLiteInterpreterGetOutputTensor(interpreter, outputIndex);
}

} // namespace tflite
58 plugin/android/src/main/cpp/tflite_wrapper.h Normal file
@@ -0,0 +1,58 @@
#pragma once

#include "util.h"
#include <cstdint>
#include <tensorflow/lite/c/c_api.h>

namespace tflite {

class Model {
public:
  explicit Model(plugin::Asset &&asset);
  Model(Model &&rhs);
  ~Model();

  const TfLiteModel *get() const { return model; }
  const void *getBuffer() const;
  const size_t getSize() const;

private:
  plugin::Asset asset;
  TfLiteModel *model;
};

class InterpreterOptions {
public:
  InterpreterOptions();
  ~InterpreterOptions();

  const TfLiteInterpreterOptions *get() const { return options; }
  void setNumThreads(const int num_threads);
  void addDelegate(TfLiteDelegate *delegate);

private:
  TfLiteInterpreterOptions *options = nullptr;
};

class Interpreter {
public:
  explicit Interpreter(const Model &model);
  Interpreter(const Model &model, const InterpreterOptions &options);
  ~Interpreter();

  int32_t getInputTensorCount() const;
  TfLiteTensor *getInputTensor(const int32_t inputIndex);
  TfLiteStatus resizeInputTensor(const int32_t inputIndex, const int *inputDims,
                                 const int32_t inputDimsSize);
  TfLiteStatus allocateTensors();
  TfLiteStatus invoke();
  int32_t getOutputTensorCount() const;
  const TfLiteTensor *getOutputTensor(const int32_t outputIndex);

private:
  Interpreter(const Model &model, const InterpreterOptions *options);

  TfLiteInterpreter *interpreter = nullptr;
};
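
/* Usage sketch (illustrative, not part of the original API): the call
   sequence used by deep_lap_3.cpp and zero_dce.cpp, condensed. Assumes the
   input/output buffers are already sized to match the model's tensors. */
inline void exampleRunModel(const Model &model, const float *input,
                            size_t inputSize, float *output, size_t outputSize,
                            int numThreads) {
  InterpreterOptions options;
  options.setNumThreads(numThreads);
  Interpreter interpreter(model, options);
  interpreter.allocateTensors();
  TfLiteTensorCopyFromBuffer(interpreter.getInputTensor(0), input,
                             inputSize * sizeof(float));
  interpreter.invoke();
  TfLiteTensorCopyToBuffer(interpreter.getOutputTensor(0), output,
                           outputSize * sizeof(float));
}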
} // namespace tflite
194 plugin/android/src/main/cpp/util.cpp Normal file
@@ -0,0 +1,194 @@
#include "util.h"
#include "log.h"
#include <RenderScriptToolkit.h>
#include <algorithm>
#include <android/asset_manager.h>
#include <cstdarg>
#include <cstring>
#include <exception>
#include <iterator>
#include <memory>
#include <omp.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tensorflow/lite/c/c_api.h>
#include <vector>

using namespace plugin;
using namespace std;

namespace plugin {

Asset::Asset(AAssetManager *const aam, const string &name, const int mode) {
  asset = AAssetManager_open(aam, name.c_str(), mode);
  if (!asset) {
    throw runtime_error("Error loading asset file");
  }
}

Asset::Asset(Asset &&rhs) {
  if (this != &rhs) {
    asset = rhs.asset;
    rhs.asset = nullptr;
  }
}

Asset::~Asset() {
  if (asset) {
    AAsset_close(asset);
    asset = nullptr;
  }
}

const void *Asset::getBuffer() const { return AAsset_getBuffer(asset); }

const size_t Asset::getSize() const {
  return static_cast<size_t>(AAsset_getLength(asset));
}

void initOpenMp() {
  const auto count = omp_get_num_procs();
  LOGI("OpenMp", "Number of threads: %d", count);
  omp_set_num_threads(count);
}

int getNumberOfProcessors() { return omp_get_num_procs(); }

renderscript::RenderScriptToolkit &getToolkitInst() {
  static renderscript::RenderScriptToolkit inst(getNumberOfProcessors());
  return inst;
}

string strprintf(const char *format, ...) {
  va_list arg;
  va_start(arg, format);

  va_list arg_copy;
  va_copy(arg_copy, arg);
  const int size = vsnprintf(nullptr, 0, format, arg_copy);
  va_end(arg_copy);

  if (size < 0) {
    va_end(arg);
    return "";
  }

  string str(size + 1, '\0');
  vsnprintf(&str[0], size + 1, format, arg);
  // We don't want the null char
  str.pop_back();

  va_end(arg);
  return str;
}

string toString(const TfLiteTensor &tensor) {
  const auto numDims = TfLiteTensorNumDims(&tensor);
  stringstream ss;
  ss << "[";
  for (int i = 0; i < numDims; ++i) {
    ss << TfLiteTensorDim(&tensor, i) << ", ";
  }
  ss << "]";

  return strprintf("TfLiteTensor {"
                   "\"type\": %d, "
                   "\"dimension\": %s, "
                   "\"byteSize\": %zu, "
                   "}",
                   TfLiteTensorType(&tensor), ss.str().c_str(),
                   TfLiteTensorByteSize(&tensor));
}

vector<float> rgb8ToRgbFloat(const uint8_t *rgb8, const size_t size,
                             const bool shouldNormalize) {
  vector<float> rgbF(size);
#pragma omp parallel for
  for (size_t i = 0; i < size; ++i) {
    if (shouldNormalize) {
      rgbF[i] = rgb8[i] / 255.0f;
    } else {
      rgbF[i] = rgb8[i];
    }
  }
  return rgbF;
}

vector<uint8_t> rgbFloatToRgb8(const float *rgbF, const size_t size,
                               const bool isNormalized) {
  vector<uint8_t> rgb8(size);
#pragma omp parallel for
  for (size_t i = 0; i < size; ++i) {
    if (isNormalized) {
      rgb8[i] = clamp<int>(0, rgbF[i] * 255, 255);
    } else {
      rgb8[i] = clamp<int>(0, rgbF[i], 255);
    }
  }
  return rgb8;
}

vector<uint8_t> rgb8ToRgba8(const uint8_t *rgb8, const size_t width,
                            const size_t height) {
  vector<uint8_t> rgba8(width * height * 4);
#pragma omp parallel for
  for (size_t y = 0; y < height; ++y) {
    for (size_t x = 0; x < width; ++x) {
      const auto i1 = y * width + x;
      const auto i3 = i1 * 3;
      const auto i4 = i1 * 4;
      memcpy(rgba8.data() + i4, rgb8 + i3, 3);
      rgba8[i4 + 3] = 0xFF;
    }
  }
  return rgba8;
}

vector<uint8_t> rgba8ToRgb8(const uint8_t *rgba8, const size_t width,
                            const size_t height) {
  vector<uint8_t> rgb8(width * height * 3);
#pragma omp parallel for
  for (size_t y = 0; y < height; ++y) {
    for (size_t x = 0; x < width; ++x) {
      const auto i1 = y * width + x;
      const auto i3 = i1 * 3;
      const auto i4 = i1 * 4;
      memcpy(rgb8.data() + i3, rgba8 + i4, 3);
    }
  }
  return rgb8;
}

void alphaBlend(const uint8_t *src, uint8_t *dst, const size_t width,
                const size_t height) {
#pragma omp parallel for
  for (size_t y = 0; y < height; ++y) {
    for (size_t x = 0; x < width; ++x) {
      const auto i4 = (y * width + x) * 4;
      const auto srcA = src[i4 + 3] / 255.0f;
      const auto dstA = dst[i4 + 3] / 255.0f;
      const auto endA = srcA + dstA * (1 - srcA);
      // rgb: standard "over" compositing of src on top of dst
      for (int i = 0; i < 3; ++i) {
        dst[i4 + i] =
            (src[i4 + i] * srcA + dst[i4 + i] * dstA * (1 - srcA)) / endA;
      }
      // a
      dst[i4 + 3] = clamp<int>(0, endA * 255, 255);
    }
  }
}

vector<uint8_t> argmax(const float *output, const size_t width,
                       const size_t height, const unsigned channel) {
  vector<uint8_t> product(width * height);
  size_t j = 0;
  for (size_t i = 0; i < width * height; ++i) {
    const float *point = output + j;
    const auto maxIt = max_element(point, point + channel);
    product[i] = distance(point, maxIt);
    j += channel;
  }
  return product;
}
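
/* Worked example for argmax (hypothetical values): with width = 2, height = 1
   and channel = 3,
     output = { 0.1f, 0.7f, 0.2f,   0.9f, 0.05f, 0.05f }
   yields { 1, 0 } -- the index of the strongest channel per pixel. This is how
   deep_lap_3.cpp collapses DeepLab's 21-channel logits into a label map. */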
} // namespace plugin
101 plugin/android/src/main/cpp/util.h Normal file
@@ -0,0 +1,101 @@
#pragma once

#include <android/asset_manager.h>
#include <cstdint>
#include <functional>
#include <omp.h>
#include <string>
#include <tensorflow/lite/c/c_api.h>
#include <vector>

namespace renderscript {
class RenderScriptToolkit;
}

namespace plugin {

template <typename T> class RaiiContainer {
public:
  RaiiContainer(const std::function<T *()> &constructor,
                const std::function<void(T *)> &destructor)
      : destructor(destructor) {
    obj = constructor();
  }

  ~RaiiContainer() {
    if (obj) {
      destructor(obj);
    }
  }

  T &operator*() { return *obj; }

  const T &operator*() const { return *obj; }

  T *operator->() { return obj; }

  const T *operator->() const { return obj; }

  T *get() { return obj; }

private:
  T *obj = nullptr;
  std::function<void(T *)> destructor;
};
class Asset {
|
||||
public:
|
||||
Asset(AAssetManager *const aam, const std::string &name,
|
||||
const int mode = AASSET_MODE_BUFFER);
|
||||
Asset(const Asset &) = delete;
|
||||
Asset(Asset &&rhs);
|
||||
~Asset();
|
||||
|
||||
const void *getBuffer() const;
|
||||
const size_t getSize() const;
|
||||
|
||||
private:
|
||||
AAsset *asset = nullptr;
|
||||
};
|
||||
|
||||
template <typename T> inline T clamp(const T &min, const T &x, const T &max) {
|
||||
return std::max(min, std::min(x, max));
|
||||
}
|
||||
|
||||
void initOpenMp();
|
||||
int getNumberOfProcessors();
|
||||
|
||||
renderscript::RenderScriptToolkit &getToolkitInst();
|
||||
|
||||
std::string strprintf(const char *format, ...);
|
||||
std::string toString(const TfLiteTensor &tensor);
|
||||
|
||||
std::vector<float> rgb8ToRgbFloat(const uint8_t *rgb, const size_t size,
|
||||
const bool shouldNormalize);
|
||||
std::vector<uint8_t> rgbFloatToRgb8(const float *rgbF, const size_t size,
|
||||
const bool isNormalized);
|
||||
std::vector<uint8_t> rgb8ToRgba8(const uint8_t *rgb8, const size_t width,
|
||||
const size_t height);
|
||||
std::vector<uint8_t> rgba8ToRgb8(const uint8_t *rgba8, const size_t width,
|
||||
const size_t height);
|
||||
|
||||
template <size_t ch>
|
||||
void replaceChannel(uint8_t *dst, const uint8_t *src, const size_t width,
|
||||
const size_t height, const unsigned targetChannel) {
|
||||
#pragma omp parallel for
|
||||
for (size_t y = 0; y < height; ++y) {
|
||||
for (size_t x = 0; x < width; ++x) {
|
||||
const auto i1 = y * width + x;
|
||||
const auto iN = i1 * ch;
|
||||
dst[iN + targetChannel] = src[i1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void alphaBlend(const uint8_t *src, uint8_t *dst, const size_t width,
|
||||
const size_t height);
|
||||
|
||||
std::vector<uint8_t> argmax(const float *output, const size_t width,
|
||||
const size_t height, const unsigned channel);
|
||||
|
||||
} // namespace plugin
|
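
(Editor's note: a usage sketch, not part of the commit. RaiiContainer ties a C-style create/destroy pair to scope exit, which is how the JNI and TFLite C-API handles in this plugin avoid leaking on early return or exception. For example, with the TFLite C API already included above:)

#include <tensorflow/lite/c/c_api.h>

#include "util.h" // assumed path of plugin::RaiiContainer

void example() {
  // the options object is destroyed when the scope exits, even if an
  // exception is thrown in between
  plugin::RaiiContainer<TfLiteInterpreterOptions> options(
      []() { return TfLiteInterpreterOptionsCreate(); },
      [](TfLiteInterpreterOptions *obj) {
        TfLiteInterpreterOptionsDelete(obj);
      });
  TfLiteInterpreterOptionsSetNumThreads(options.get(), 4);
}
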
142
plugin/android/src/main/cpp/zero_dce.cpp
Normal file

@ -0,0 +1,142 @@
#include "base_resample.h"
|
||||
#include "exception.h"
|
||||
#include "log.h"
|
||||
#include "stopwatch.h"
|
||||
#include "tflite_wrapper.h"
|
||||
#include "util.h"
|
||||
#include <algorithm>
|
||||
#include <android/asset_manager.h>
|
||||
#include <android/asset_manager_jni.h>
|
||||
#include <cassert>
|
||||
#include <exception>
|
||||
#include <jni.h>
|
||||
#include <omp.h>
|
||||
#include <tensorflow/lite/c/c_api.h>
|
||||
|
||||
using namespace plugin;
|
||||
using namespace std;
|
||||
using namespace tflite;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr const char *MODEL = "zero_dce_lite_200x300_iter8_60.tflite";
|
||||
constexpr size_t WIDTH = 300;
|
||||
constexpr size_t HEIGHT = 200;
|
||||
|
||||
class ZeroDce {
|
||||
public:
|
||||
explicit ZeroDce(AAssetManager *const aam);
|
||||
|
||||
std::vector<uint8_t> infer(const uint8_t *image, const size_t width,
|
||||
const size_t height, const unsigned iteration);
|
||||
|
||||
private:
|
||||
std::vector<uint8_t> inferAlphaMaps(const uint8_t *image, const size_t width,
|
||||
const size_t height);
|
||||
std::vector<uint8_t> enhance(const uint8_t *image, const size_t width,
|
||||
const size_t height,
|
||||
const std::vector<uint8_t> &alphaMaps,
|
||||
const unsigned iteration);
|
||||
|
||||
Model model;
|
||||
|
||||
static constexpr const char *TAG = "ZeroDce";
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" JNIEXPORT jbyteArray JNICALL
|
||||
Java_com_nkming_nc_1photos_plugin_image_1processor_ZeroDce_inferNative(
|
||||
JNIEnv *env, jobject *thiz, jobject assetManager, jbyteArray image,
|
||||
jint width, jint height, jint iteration) {
|
||||
try {
|
||||
initOpenMp();
|
||||
auto aam = AAssetManager_fromJava(env, assetManager);
|
||||
ZeroDce model(aam);
|
||||
RaiiContainer<jbyte> cImage(
|
||||
[&]() { return env->GetByteArrayElements(image, nullptr); },
|
||||
[&](jbyte *obj) {
|
||||
env->ReleaseByteArrayElements(image, obj, JNI_ABORT);
|
||||
});
|
||||
const auto result = model.infer(reinterpret_cast<uint8_t *>(cImage.get()),
|
||||
width, height, iteration);
|
||||
auto resultAry = env->NewByteArray(result.size());
|
||||
env->SetByteArrayRegion(resultAry, 0, result.size(),
|
||||
reinterpret_cast<const int8_t *>(result.data()));
|
||||
return resultAry;
|
||||
} catch (const exception &e) {
|
||||
throwJavaException(env, e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
namespace {

ZeroDce::ZeroDce(AAssetManager *const aam) : model(Asset(aam, MODEL)) {}

vector<uint8_t> ZeroDce::infer(const uint8_t *image, const size_t width,
                               const size_t height, const unsigned iteration) {
  const auto alphaMaps = inferAlphaMaps(image, width, height);
  return enhance(image, width, height, alphaMaps, iteration);
}

vector<uint8_t> ZeroDce::inferAlphaMaps(const uint8_t *image,
                                        const size_t width,
                                        const size_t height) {
  InterpreterOptions options;
  options.setNumThreads(getNumberOfProcessors());
  Interpreter interpreter(model, options);
  interpreter.allocateTensors();

  LOGI(TAG, "[inferAlphaMaps] Convert bitmap to input");
  vector<uint8_t> inputBitmap(WIDTH * HEIGHT * 3);
  base::ResampleImage24(image, width, height, inputBitmap.data(), WIDTH, HEIGHT,
                        base::KernelTypeLanczos3);
  const auto input =
      rgb8ToRgbFloat(inputBitmap.data(), inputBitmap.size(), true);
  auto inputTensor = interpreter.getInputTensor(0);
  assert(TfLiteTensorByteSize(inputTensor) == input.size() * sizeof(float));
  TfLiteTensorCopyFromBuffer(inputTensor, input.data(),
                             input.size() * sizeof(float));

  LOGI(TAG, "[inferAlphaMaps] Inferring");
  Stopwatch stopwatch;
  interpreter.invoke();
  LOGI(TAG, "[inferAlphaMaps] Elapsed: %.3fs", stopwatch.getMs() / 1000.0f);

  auto outputTensor = interpreter.getOutputTensor(1);
  vector<float> output(input.size());
  assert(TfLiteTensorByteSize(outputTensor) == output.size() * sizeof(float));
  TfLiteTensorCopyToBuffer(outputTensor, output.data(),
                           output.size() * sizeof(float));
  // the model outputs negative values; take their absolute value before
  // quantizing back to 8-bit
  for (size_t i = 0; i < output.size(); ++i) {
    output[i] = fabsf(output[i]);
  }
  return rgbFloatToRgb8(output.data(), output.size(), true);
}
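
(Editor's note, my reading of the code rather than anything stated in the commit: the curve-parameter maps leave the network as floats, which the Zero-DCE paper bounds to $[-1, 1]$ via a tanh, and the round trip through fabsf and rgbFloatToRgb8 keeps only an 8-bit magnitude,

$$a_8 = \operatorname{clamp}\bigl(\lfloor |A| \cdot 255 \rfloor,\, 0,\, 255\bigr),$$

so the sign is discarded and the maps are quantized to 1/255 steps before enhance() rescales them back to $[0, 1]$.)
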

vector<uint8_t> ZeroDce::enhance(const uint8_t *image, const size_t width,
                                 const size_t height,
                                 const vector<uint8_t> &alphaMaps,
                                 const unsigned iteration) {
  LOGI(TAG, "[enhance] Enhancing image, iteration: %d", iteration);
  // upscale the alpha maps back to the full image size
  vector<uint8_t> filter(width * height * 3);
  base::ResampleImage24(alphaMaps.data(), WIDTH, HEIGHT, filter.data(), width,
                        height, base::KernelTypeBicubic);

  vector<uint8_t> output(width * height * 3);
#pragma omp parallel for
  for (size_t i = 0; i < filter.size(); ++i) {
    auto s = image[i] / 255.0f;
    const auto f = filter[i] / 255.0f;
    for (unsigned j = 0; j < iteration; ++j) {
      s += -f * (std::pow(s, 2.0f) - s);
    }
    output[i] = std::max(0, std::min(static_cast<int>(s * 255), 255));
  }
  return output;
}

} // namespace
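
(Editor's note: the inner loop of enhance is the iterated light-enhancement curve from the Zero-DCE paper. With $s_0 = I(x)/255$ a channel value in $[0, 1]$ and $f$ its curve parameter, each pass computes

$$s_{n+1} = s_n + f\,(s_n - s_n^2) = s_n + f\,s_n(1 - s_n),$$

run iteration times before quantizing back to 8-bit. Since $s_n(1 - s_n) \ge 0$ on $[0, 1]$, a positive $f$ lifts mid-tones while leaving pure black ($s = 0$) and pure white ($s = 1$) fixed, which is what makes the curve suitable for low-light enhancement.)
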
16
plugin/android/src/main/cpp/zero_dce.h
Normal file

@ -0,0 +1,16 @@
#pragma once

#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

JNIEXPORT jbyteArray JNICALL
Java_com_nkming_nc_1photos_plugin_image_1processor_ZeroDce_inferNative(
    JNIEnv *env, jobject thiz, jobject assetManager, jbyteArray image,
    jint width, jint height, jint iteration);

#ifdef __cplusplus
}
#endif

@ -3,3 +3,5 @@ package com.nkming.nc_photos.plugin

class PermissionException(message: String) : Exception(message)

class HttpException(statusCode: Int, message: String) : Exception(message)

class NativeException(message: String) : Exception(message)

@ -13,6 +13,10 @@ class NcPhotosPlugin : FlutterPlugin, ActivityAware,
	PluginRegistry.ActivityResultListener,
	PluginRegistry.RequestPermissionsResultListener {
	companion object {
		init {
			System.loadLibrary("plugin")
		}

		const val ACTION_SHOW_IMAGE_PROCESSOR_RESULT =
			K.ACTION_SHOW_IMAGE_PROCESSOR_RESULT
		const val EXTRA_IMAGE_RESULT_URI = K.EXTRA_IMAGE_RESULT_URI

@ -1,17 +1,11 @@

package com.nkming.nc_photos.plugin.image_processor

import android.content.Context
import android.content.res.AssetManager
import android.graphics.*
import android.net.Uri
import com.google.android.renderscript.Toolkit
import com.nkming.nc_photos.plugin.BitmapResizeMethod
import com.nkming.nc_photos.plugin.BitmapUtil
import com.nkming.nc_photos.plugin.logI
import com.nkming.nc_photos.plugin.transform
import org.tensorflow.lite.Interpreter
import java.io.File
import java.nio.ByteBuffer
import java.nio.FloatBuffer

/**
 * DeepLab is a state-of-the-art deep learning model for semantic image

@ -20,132 +14,35 @@ import java.nio.FloatBuffer
 *
 * See: https://github.com/tensorflow/models/tree/master/research/deeplab
 */
private class DeepLab3(context: Context) {
	companion object {
		private const val MODEL = "lite-model_mobilenetv2-dm05-coco_dr_1.tflite"
		const val WIDTH = 513
		const val HEIGHT = 513

		private const val TAG = "DeepLab3"
	}

	enum class Label(val value: Int) {
		BACKGROUND(0),
		AEROPLANE(1),
		BICYCLE(2),
		BIRD(3),
		BOAT(4),
		BOTTLE(5),
		BUS(6),
		CAR(7),
		CAT(8),
		CHAIR(9),
		COW(10),
		DINING_TABLE(11),
		DOG(12),
		HORSE(13),
		MOTORBIKE(14),
		PERSON(15),
		POTTED_PLANT(16),
		SHEEP(17),
		SOFA(18),
		TRAIN(19),
		TV(20),
	}

	fun infer(imageUri: Uri): ByteBuffer {
		val interpreter =
			Interpreter(TfLiteHelper.loadModelFromAsset(context, MODEL))
		interpreter.allocateTensors()

		logI(TAG, "Converting bitmap to input")
		val inputBitmap =
			BitmapUtil.loadImageFixed(context, imageUri, WIDTH, HEIGHT)
		val input = TfLiteHelper.bitmapToRgbFloatArray(inputBitmap)
		val output = FloatBuffer.allocate(WIDTH * HEIGHT * Label.values().size)
		logI(TAG, "Inferring")
		interpreter.run(input, output)
		return TfLiteHelper.argmax(output, WIDTH, HEIGHT, Label.values().size)
	}

	private val context = context
}

class DeepLab3Portrait(
	context: Context, maxWidth: Int, maxHeight: Int, radius: Int
) {
	companion object {
		private const val TAG = "DeepLab3Portrait"
	}

	fun infer(imageUri: Uri): Bitmap {
		val segmentMap = deepLab.infer(imageUri).also {
			postProcessSegmentMap(it)
		}
		return enhance(imageUri, segmentMap, radius)
	}

	/**
	 * Post-process the segment map.
	 *
	 * The resulting segment map will:
	 * 1. Contain only the most significant label (the one with the most pixels)
	 * 2. The label value set to 255
	 * 3. The background set to 0
	 *
	 * @param segmentMap
	 */
	private fun postProcessSegmentMap(segmentMap: ByteBuffer) {
		// keep only the largest segment
		val count = mutableMapOf<Byte, Int>()
		segmentMap.array().forEach {
			if (it != DeepLab3.Label.BACKGROUND.value.toByte()) {
				count[it] = (count[it] ?: 0) + 1
			}
		}
		val keep = count.maxByOrNull { it.value }?.key
		segmentMap.array().transform { if (it == keep) 0xFF.toByte() else 0 }
	}

	private fun enhance(
		imageUri: Uri, segmentMap: ByteBuffer, radius: Int
	): Bitmap {
		logI(TAG, "[enhance] Enhancing image")
		// downscale original to prevent OOM
		val orig = BitmapUtil.loadImage(
		val width: Int
		val height: Int
		val rgb8Image = BitmapUtil.loadImage(
			context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT,
			isAllowSwapSide = true, shouldUpscale = false
		)
		val bg = Toolkit.blur(orig, radius)
		).let {
			width = it.width
			height = it.height
			val rgb8 = TfLiteHelper.bitmapToRgb8Array(it)
			it.recycle()
			rgb8
		}
		val am = context.assets

		var alpha = Bitmap.createBitmap(
			DeepLab3.WIDTH, DeepLab3.HEIGHT, Bitmap.Config.ALPHA_8
		)
		alpha.copyPixelsFromBuffer(segmentMap)
		alpha = Bitmap.createScaledBitmap(alpha, orig.width, orig.height, true)
		// blur the mask to smoothen the edge
		alpha = Toolkit.blur(alpha, 16)
		File(context.filesDir, "alpha.png").outputStream().use {
			alpha.compress(Bitmap.CompressFormat.PNG, 50, it)
		return inferNative(am, rgb8Image, width, height, radius).let {
			TfLiteHelper.rgb8ArrayToBitmap(it, width, height)
		}

		val shader = ComposeShader(
			BitmapShader(orig, Shader.TileMode.CLAMP, Shader.TileMode.CLAMP),
			BitmapShader(alpha, Shader.TileMode.CLAMP, Shader.TileMode.CLAMP),
			PorterDuff.Mode.DST_ATOP
		)
		val paint = Paint().apply {
			setShader(shader)
		}
		Canvas(bg).apply {
			drawRect(0f, 0f, orig.width.toFloat(), orig.height.toFloat(), paint)
		}
		return bg
	}

	private external fun inferNative(
		am: AssetManager, image: ByteArray, width: Int, height: Int, radius: Int
	): ByteArray

	private val context = context
	private val maxWidth = maxWidth
	private val maxHeight = maxHeight
	private val radius = radius
	private val deepLab = DeepLab3(context)
}

@ -1,67 +1,46 @@

package com.nkming.nc_photos.plugin.image_processor

import android.content.Context
import android.graphics.Bitmap
import java.io.FileInputStream
import java.nio.ByteBuffer
import java.nio.FloatBuffer
import java.nio.IntBuffer
import java.nio.channels.FileChannel
import kotlin.math.abs

interface TfLiteHelper {
	companion object {
		/**
		 * Load a TFLite model from the assets dir
		 *
		 * @param context
		 * @param name Name of the model file
		 * @return
		 */
		fun loadModelFromAsset(context: Context, name: String): ByteBuffer {
			val fd = context.assets.openFd(name)
			val istream = FileInputStream(fd.fileDescriptor)
			val channel = istream.channel
			return channel.map(
				FileChannel.MapMode.READ_ONLY, fd.startOffset, fd.declaredLength
			)
		}

		/**
		 * Convert an ARGB_8888 Android bitmap to a float RGB buffer
		 * Convert an ARGB_8888 Android bitmap to an RGB8 byte array
		 *
		 * @param bitmap
		 * @return
		 */
		fun bitmapToRgbFloatArray(bitmap: Bitmap): FloatBuffer {
		fun bitmapToRgb8Array(bitmap: Bitmap): ByteArray {
			val buffer = IntBuffer.allocate(bitmap.width * bitmap.height)
			bitmap.copyPixelsToBuffer(buffer)
			val input = FloatBuffer.allocate(bitmap.width * bitmap.height * 3)
			buffer.array().forEach {
				input.put((it and 0xFF) / 255.0f)
				input.put((it shr 8 and 0xFF) / 255.0f)
				input.put((it shr 16 and 0xFF) / 255.0f)
			val rgb8 = ByteArray(bitmap.width * bitmap.height * 3)
			buffer.array().forEachIndexed { i, it ->
				run {
					rgb8[i * 3] = (it and 0xFF).toByte()
					rgb8[i * 3 + 1] = (it shr 8 and 0xFF).toByte()
					rgb8[i * 3 + 2] = (it shr 16 and 0xFF).toByte()
				}
			}
			input.rewind()
			return input
			return rgb8
		}

		/**
		 * Convert a float RGB buffer to an ARGB_8888 Android bitmap
		 * Convert an RGB8 byte array to an ARGB_8888 Android bitmap
		 *
		 * @param output
		 * @param rgb8
		 * @param width
		 * @param height
		 * @return
		 */
		fun rgbFloatArrayToBitmap(
			output: FloatBuffer, width: Int, height: Int
		fun rgb8ArrayToBitmap(
			rgb8: ByteArray, width: Int, height: Int
		): Bitmap {
			val buffer = IntBuffer.allocate(width * height)
			var i = 0
			var pixel = 0
			output.array().forEach {
				val value = (abs(it * 255f)).toInt().coerceIn(0, 255)
			rgb8.forEach {
				val value = it.toInt() and 0xFF
				when (i++) {
					0 -> {
						// A

@ -88,20 +67,5 @@ interface TfLiteHelper {
			outputBitmap.copyPixelsFromBuffer(buffer)
			return outputBitmap
		}

		fun argmax(
			output: FloatBuffer, width: Int, height: Int, channel: Int
		): ByteBuffer {
			val product = ByteBuffer.allocate(width * height)
			val array = output.array()
			var j = 0
			for (i in 0 until width * height) {
				val pixel = array.slice(j until j + channel)
				val max = pixel.indices.maxByOrNull { pixel[it] }!!
				product.put(i, max.toByte())
				j += channel
			}
			return product
		}
	}
}

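(Editor's note on the channel order in bitmapToRgb8Array / rgb8ArrayToBitmap above: an ARGB_8888 bitmap copied through an IntBuffer yields ints in native little-endian order, so the low byte is red, not alpha. A standalone C++ sketch of the same unpacking, with hypothetical values:)

#include <cstdint>

// Unpack one ARGB_8888 pixel as read via Bitmap.copyPixelsToBuffer():
// in memory the bytes are R, G, B, A, so interpreted as a little-endian
// uint32_t the red channel sits in the low byte.
void unpack(uint32_t px, uint8_t out[4]) {
  out[0] = px & 0xFF;         // R
  out[1] = (px >> 8) & 0xFF;  // G
  out[2] = (px >> 16) & 0xFF; // B
  out[3] = (px >> 24) & 0xFF; // A
}
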
@ -1,106 +1,38 @@

package com.nkming.nc_photos.plugin.image_processor

import android.content.Context
import android.content.res.AssetManager
import android.graphics.Bitmap
import android.net.Uri
import com.nkming.nc_photos.plugin.BitmapResizeMethod
import com.nkming.nc_photos.plugin.BitmapUtil
import com.nkming.nc_photos.plugin.logI
import org.tensorflow.lite.Interpreter
import java.nio.FloatBuffer
import java.nio.IntBuffer
import kotlin.math.pow

class ZeroDce(context: Context, maxWidth: Int, maxHeight: Int, iteration: Int) {
	companion object {
		private const val TAG = "ZeroDce"
		private const val MODEL = "zero_dce_lite_200x300_iter8_60.tflite"
		private const val WIDTH = 300
		private const val HEIGHT = 200
	}

	fun infer(imageUri: Uri): Bitmap {
		val alphaMaps = inferAlphaMaps(imageUri)
		return enhance(imageUri, alphaMaps, iteration)
	}

	private fun inferAlphaMaps(imageUri: Uri): Bitmap {
		val interpreter =
			Interpreter(TfLiteHelper.loadModelFromAsset(context, MODEL))
		interpreter.allocateTensors()

		logI(TAG, "Converting bitmap to input")
		val inputBitmap =
			BitmapUtil.loadImageFixed(context, imageUri, WIDTH, HEIGHT)
		val inputs = arrayOf(TfLiteHelper.bitmapToRgbFloatArray(inputBitmap))
		val outputs = mapOf(
			0 to FloatBuffer.allocate(inputs[0].capacity()),
			1 to FloatBuffer.allocate(inputs[0].capacity())
		)
		logI(TAG, "Inferring")
		interpreter.runForMultipleInputsOutputs(inputs, outputs)

		return TfLiteHelper.rgbFloatArrayToBitmap(
			outputs[1]!!, inputBitmap.width, inputBitmap.height
		)
	}

	private fun enhance(
		imageUri: Uri, alphaMaps: Bitmap, iteration: Int
	): Bitmap {
		logI(TAG, "Enhancing image, iteration: $iteration")
		// we can't work with FloatBuffer directly here as a FloatBuffer is way
		// too large to fit in Android's heap limit
		// downscale original to prevent OOM
		val width: Int
		val height: Int
		val imgBuf: IntBuffer
		BitmapUtil.loadImage(
		val rgb8Image = BitmapUtil.loadImage(
			context, imageUri, maxWidth, maxHeight, BitmapResizeMethod.FIT,
			isAllowSwapSide = true, shouldUpscale = false
		).apply {
			width = this.width
			height = this.height
			imgBuf = IntBuffer.allocate(width * height)
			copyPixelsToBuffer(imgBuf)
			recycle()
		).let {
			width = it.width
			height = it.height
			val rgb8 = TfLiteHelper.bitmapToRgb8Array(it)
			it.recycle()
			rgb8
		}
		imgBuf.rewind()
		val am = context.assets

		// resize aMaps
		val filterBuf: IntBuffer
		Bitmap.createScaledBitmap(alphaMaps, width, height, true).apply {
			filterBuf = IntBuffer.allocate(width * height)
			copyPixelsToBuffer(filterBuf)
			recycle()
		return inferNative(am, rgb8Image, width, height, iteration).let {
			TfLiteHelper.rgb8ArrayToBitmap(it, width, height)
		}
		filterBuf.rewind()

		val src = imgBuf.array()
		val filter = filterBuf.array()
		for (i in src.indices) {
			var sr = (src[i] and 0xFF) / 255f
			var sg = (src[i] shr 8 and 0xFF) / 255f
			var sb = (src[i] shr 16 and 0xFF) / 255f
			val fr = (filter[i] and 0xFF) / 255f
			val fg = (filter[i] shr 8 and 0xFF) / 255f
			val fb = (filter[i] shr 16 and 0xFF) / 255f
			for (j in 0 until iteration) {
				sr += -fr * (sr.pow(2f) - sr)
				sg += -fg * (sg.pow(2f) - sg)
				sb += -fb * (sb.pow(2f) - sb)
			}
			src[i] = (0xFF shl 24) or
				((sr * 255).toInt().coerceIn(0, 255)) or
				((sg * 255).toInt().coerceIn(0, 255) shl 8) or
				((sb * 255).toInt().coerceIn(0, 255) shl 16)
		}
		return Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888)
			.apply {
				copyPixelsFromBuffer(imgBuf)
			}
	}

	private external fun inferNative(
		am: AssetManager, image: ByteArray, width: Int, height: Int,
		iteration: Int
	): ByteArray

	private val context = context
	private val maxWidth = maxWidth
	private val maxHeight = maxHeight