diff --git a/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp b/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp index d34d9a53..29a88324 100644 --- a/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp +++ b/plugin/android/src/main/cpp/arbitrary_style_transfer.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include using namespace plugin; @@ -184,6 +185,11 @@ ArbitraryStyleTransfer::transfer(const uint8_t *image, const size_t width, vector ArbitraryStyleTransfer::predictStyle(const uint8_t *style) { InterpreterOptions options; options.setNumThreads(getNumberOfProcessors()); + + auto gpuOptions = TfLiteGpuDelegateOptionsV2Default(); + auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions)); + options.addDelegate(gpuDelegate.get()); + Interpreter interpreter(predictModel, options); interpreter.allocateTensors(); diff --git a/plugin/android/src/main/cpp/deep_lap_3.cpp b/plugin/android/src/main/cpp/deep_lap_3.cpp index dc4fe944..bbfce12c 100644 --- a/plugin/android/src/main/cpp/deep_lap_3.cpp +++ b/plugin/android/src/main/cpp/deep_lap_3.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include "./filter/saturation.h" @@ -207,6 +208,11 @@ vector DeepLab3::infer(const uint8_t *image, const size_t width, const size_t height) { InterpreterOptions options; options.setNumThreads(getNumberOfProcessors()); + + auto gpuOptions = TfLiteGpuDelegateOptionsV2Default(); + auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions)); + options.addDelegate(gpuDelegate.get()); + Interpreter interpreter(model, options); interpreter.allocateTensors(); diff --git a/plugin/android/src/main/cpp/tflite_wrapper.cpp b/plugin/android/src/main/cpp/tflite_wrapper.cpp index 9978fb9b..8a2fd177 100644 --- a/plugin/android/src/main/cpp/tflite_wrapper.cpp +++ b/plugin/android/src/main/cpp/tflite_wrapper.cpp @@ -2,6 +2,7 @@ #include "util.h" #include #include +#include using namespace plugin; using namespace std; @@ -99,4 +100,10 @@ const TfLiteTensor *Interpreter::getOutputTensor(const int32_t outputIndex) { return TfLiteInterpreterGetOutputTensor(interpreter, outputIndex); } +AutoTfLiteDelegate::~AutoTfLiteDelegate() { + if (inst) { + TfLiteGpuDelegateV2Delete(inst); + } +} + } // namespace tflite \ No newline at end of file diff --git a/plugin/android/src/main/cpp/tflite_wrapper.h b/plugin/android/src/main/cpp/tflite_wrapper.h index 80d4fbdc..0e57bd13 100644 --- a/plugin/android/src/main/cpp/tflite_wrapper.h +++ b/plugin/android/src/main/cpp/tflite_wrapper.h @@ -55,4 +55,17 @@ private: TfLiteInterpreter *interpreter = nullptr; }; +class AutoTfLiteDelegate { +public: + explicit AutoTfLiteDelegate(TfLiteDelegate *inst) : inst(inst) {} + ~AutoTfLiteDelegate(); + + TfLiteDelegate &operator*() { return *inst; } + TfLiteDelegate *operator->() { return inst; } + TfLiteDelegate *get() { return inst; } + +private: + TfLiteDelegate *const inst; +}; + } // namespace tflite diff --git a/plugin/android/src/main/cpp/zero_dce.cpp b/plugin/android/src/main/cpp/zero_dce.cpp index 55e9763c..d43c9ff6 100644 --- a/plugin/android/src/main/cpp/zero_dce.cpp +++ b/plugin/android/src/main/cpp/zero_dce.cpp @@ -11,6 +11,7 @@ #include #include #include +#include using namespace plugin; using namespace std; @@ -84,6 +85,11 @@ vector ZeroDce::inferAlphaMaps(const uint8_t *image, const size_t height) { InterpreterOptions options; options.setNumThreads(getNumberOfProcessors()); + + auto gpuOptions = TfLiteGpuDelegateOptionsV2Default(); + auto gpuDelegate = AutoTfLiteDelegate(TfLiteGpuDelegateV2Create(&gpuOptions)); + options.addDelegate(gpuDelegate.get()); + Interpreter interpreter(model, options); interpreter.allocateTensors();