Run models with fixed input size on GPU

Ming Ming 2022-09-25 22:58:43 +08:00
parent fa0d275fad
commit c5584cc550
5 changed files with 38 additions and 0 deletions


@@ -12,6 +12,7 @@
 #include <exception>
 #include <jni.h>
 #include <tensorflow/lite/c/c_api.h>
+#include <tensorflow/lite/delegates/gpu/delegate.h>
 #include <vector>

 using namespace plugin;

@@ -184,6 +185,11 @@ ArbitraryStyleTransfer::transfer(const uint8_t *image, const size_t width,
 vector<float> ArbitraryStyleTransfer::predictStyle(const uint8_t *style) {
   InterpreterOptions options;
   options.setNumThreads(getNumberOfProcessors());
+  auto gpuOptions = TfLiteGpuDelegateOptionsV2Default();
+  auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions));
+  options.addDelegate(gpuDelegate.get());
   Interpreter interpreter(predictModel, options);
   interpreter.allocateTensors();
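
The same three added lines recur at every inference entry point in this commit. For reference, a minimal sketch of what the plugin's InterpreterOptions / Interpreter wrappers presumably boil down to in the stock TFLite C API (the wrapper types are this project's own; runOnGpu, the model path, and the omitted I/O are illustrative assumptions, not part of the commit):

#include <tensorflow/lite/c/c_api.h>
#include <tensorflow/lite/delegates/gpu/delegate.h>

// Sketch only: create a GPU delegate, attach it to the interpreter options,
// then build the interpreter. Error handling is omitted for brevity.
void runOnGpu(const char *modelPath) {
  TfLiteModel *model = TfLiteModelCreateFromFile(modelPath);
  TfLiteInterpreterOptions *options = TfLiteInterpreterOptionsCreate();

  TfLiteGpuDelegateOptionsV2 gpuOptions = TfLiteGpuDelegateOptionsV2Default();
  TfLiteDelegate *gpuDelegate = TfLiteGpuDelegateV2Create(&gpuOptions);
  TfLiteInterpreterOptionsAddDelegate(options, gpuDelegate);

  TfLiteInterpreter *interpreter = TfLiteInterpreterCreate(model, options);
  TfLiteInterpreterAllocateTensors(interpreter);
  // ... copy the input into TfLiteInterpreterGetInputTensor(interpreter, 0),
  // call TfLiteInterpreterInvoke(interpreter), then read the output tensor ...

  // The delegate must outlive the interpreter, so tear down in this order.
  TfLiteInterpreterDelete(interpreter);
  TfLiteGpuDelegateV2Delete(gpuDelegate);
  TfLiteInterpreterOptionsDelete(options);
  TfLiteModelDelete(model);
}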


@@ -14,6 +14,7 @@
 #include <exception>
 #include <jni.h>
 #include <tensorflow/lite/c/c_api.h>
+#include <tensorflow/lite/delegates/gpu/delegate.h>

 #include "./filter/saturation.h"

@@ -207,6 +208,11 @@ vector<uint8_t> DeepLab3::infer(const uint8_t *image, const size_t width,
                                 const size_t height) {
   InterpreterOptions options;
   options.setNumThreads(getNumberOfProcessors());
+  auto gpuOptions = TfLiteGpuDelegateOptionsV2Default();
+  auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions));
+  options.addDelegate(gpuDelegate.get());
   Interpreter interpreter(model, options);
   interpreter.allocateTensors();


@@ -2,6 +2,7 @@
 #include "util.h"
 #include <exception>
 #include <tensorflow/lite/c/c_api.h>
+#include <tensorflow/lite/delegates/gpu/delegate.h>

 using namespace plugin;
 using namespace std;

@@ -99,4 +100,10 @@ const TfLiteTensor *Interpreter::getOutputTensor(const int32_t outputIndex) {
   return TfLiteInterpreterGetOutputTensor(interpreter, outputIndex);
 }

+AutoTfLiteDelegate::~AutoTfLiteDelegate() {
+  if (inst) {
+    TfLiteGpuDelegateV2Delete(inst);
+  }
+}
+
 } // namespace tflite


@@ -55,4 +55,17 @@ private:
   TfLiteInterpreter *interpreter = nullptr;
 };

+class AutoTfLiteDelegate {
+public:
+  explicit AutoTfLiteDelegate(TfLiteDelegate *inst) : inst(inst) {}
+  ~AutoTfLiteDelegate();
+
+  TfLiteDelegate &operator*() { return *inst; }
+  TfLiteDelegate *operator->() { return inst; }
+  TfLiteDelegate *get() { return inst; }
+
+private:
+  TfLiteDelegate *const inst;
+};
+
 } // namespace tflite
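
AutoTfLiteDelegate is a small RAII guard: its out-of-line destructor (added in the interpreter source above) calls TfLiteGpuDelegateV2Delete, so each call site can hold the delegate as a local and never leak it on an early return or exception. Because locals are destroyed in reverse order of declaration, the call sites in this commit are safe as written: the Interpreter, which borrows the delegate, goes out of scope before the AutoTfLiteDelegate that owns it. One caveat the commit does not address: the const member suppresses copy assignment, but the implicit copy constructor is still available, so copying the guard would double-free the delegate. A conventional hardening sketch (an assumption of mine, not the author's code):

// Sketch: forbid copies so the destructor frees inst exactly once.
class AutoTfLiteDelegate {
public:
  explicit AutoTfLiteDelegate(TfLiteDelegate *inst) : inst(inst) {}
  ~AutoTfLiteDelegate();

  AutoTfLiteDelegate(const AutoTfLiteDelegate &) = delete;
  AutoTfLiteDelegate &operator=(const AutoTfLiteDelegate &) = delete;

  TfLiteDelegate *get() { return inst; }

private:
  TfLiteDelegate *const inst;
};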


@@ -11,6 +11,7 @@
 #include <exception>
 #include <jni.h>
 #include <tensorflow/lite/c/c_api.h>
+#include <tensorflow/lite/delegates/gpu/delegate.h>

 using namespace plugin;
 using namespace std;

@@ -84,6 +85,11 @@ vector<uint8_t> ZeroDce::inferAlphaMaps(const uint8_t *image,
                                         const size_t height) {
   InterpreterOptions options;
   options.setNumThreads(getNumberOfProcessors());
+  auto gpuOptions = TfLiteGpuDelegateOptionsV2Default();
+  auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions));
+  options.addDelegate(gpuDelegate.get());
   Interpreter interpreter(model, options);
   interpreter.allocateTensors();
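
Every call site uses TfLiteGpuDelegateOptionsV2Default() unchanged. If tuning ever becomes necessary, the options struct can be adjusted before the delegate is created; a hedged sketch using fields from delegates/gpu/delegate.h (whether these settings would actually help here is an assumption, not something this commit establishes):

auto gpuOptions = TfLiteGpuDelegateOptionsV2Default();
// Trade precision for speed: allow fp16 math and prioritize low latency.
gpuOptions.is_precision_loss_allowed = 1;
gpuOptions.inference_priority1 = TFLITE_GPU_INFERENCE_PRIORITY_MIN_LATENCY;
auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions));
options.addDelegate(gpuDelegate.get());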