1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/c/c_api_experimental.h"
17
18 #include <stdint.h>
19
20 #include <memory>
21
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/c/c_api.h"
24 #include "tensorflow/lite/c/c_api_internal.h"
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/signature_runner.h"
27
28 extern "C" {
29
TfLiteInterpreterResetVariableTensors(TfLiteInterpreter * interpreter)30 TfLiteStatus TfLiteInterpreterResetVariableTensors(
31 TfLiteInterpreter* interpreter) {
32 return interpreter->impl->ResetVariableTensors();
33 }
34
TfLiteInterpreterOptionsAddBuiltinOp(TfLiteInterpreterOptions * options,TfLiteBuiltinOperator op,const TfLiteRegistration * registration,int32_t min_version,int32_t max_version)35 void TfLiteInterpreterOptionsAddBuiltinOp(
36 TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
37 const TfLiteRegistration* registration, int32_t min_version,
38 int32_t max_version) {
39 options->mutable_op_resolver.AddBuiltin(
40 static_cast<tflite::BuiltinOperator>(op), registration, min_version,
41 max_version);
42 }
43
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)44 TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
45 const TfLiteModel* model,
46 const TfLiteInterpreterOptions* optional_options) {
47 tflite::MutableOpResolver resolver;
48 return tflite::internal::InterpreterCreateWithOpResolver(
49 model, optional_options, &resolver);
50 }
51
TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions * options,const char * name,const TfLiteRegistration * registration,int32_t min_version,int32_t max_version)52 void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options,
53 const char* name,
54 const TfLiteRegistration* registration,
55 int32_t min_version,
56 int32_t max_version) {
57 options->mutable_op_resolver.AddCustom(name, registration, min_version,
58 max_version);
59 }
60
TfLiteInterpreterOptionsSetOpResolver(TfLiteInterpreterOptions * options,const TfLiteRegistration * (* find_builtin_op)(void * user_data,TfLiteBuiltinOperator op,int version),const TfLiteRegistration * (* find_custom_op)(void * user_data,const char * op,int version),void * op_resolver_user_data)61 void TfLiteInterpreterOptionsSetOpResolver(
62 TfLiteInterpreterOptions* options,
63 const TfLiteRegistration* (*find_builtin_op)(void* user_data,
64 TfLiteBuiltinOperator op,
65 int version),
66 const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
67 int version),
68 void* op_resolver_user_data) {
69 options->op_resolver_callbacks.find_builtin_op = find_builtin_op;
70 options->op_resolver_callbacks.find_custom_op = find_custom_op;
71 options->op_resolver_callbacks.user_data = op_resolver_user_data;
72 }
73
TfLiteInterpreterOptionsSetOpResolverV1(TfLiteInterpreterOptions * options,const TfLiteRegistration_V1 * (* find_builtin_op_v1)(void * user_data,TfLiteBuiltinOperator op,int version),const TfLiteRegistration_V1 * (* find_custom_op_v1)(void * user_data,const char * op,int version),void * op_resolver_user_data)74 void TfLiteInterpreterOptionsSetOpResolverV1(
75 TfLiteInterpreterOptions* options,
76 const TfLiteRegistration_V1* (*find_builtin_op_v1)(void* user_data,
77 TfLiteBuiltinOperator op,
78 int version),
79 const TfLiteRegistration_V1* (*find_custom_op_v1)(void* user_data,
80 const char* op,
81 int version),
82 void* op_resolver_user_data) {
83 options->op_resolver_callbacks.find_builtin_op_v1 = find_builtin_op_v1;
84 options->op_resolver_callbacks.find_custom_op_v1 = find_custom_op_v1;
85 options->op_resolver_callbacks.user_data = op_resolver_user_data;
86 }
87
TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions * options,bool enable)88 void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
89 bool enable) {
90 options->use_nnapi = enable;
91 }
92
TfLiteInterpreterOptionsSetEnableDelegateFallback(TfLiteInterpreterOptions * options,bool enable)93 void TfLiteInterpreterOptionsSetEnableDelegateFallback(
94 TfLiteInterpreterOptions* options, bool enable) {
95 options->enable_delegate_fallback = enable;
96 }
97
TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter * interpreter,bool allow_buffer_handle_output)98 void TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter* interpreter,
99 bool allow_buffer_handle_output) {
100 interpreter->impl->SetAllowBufferHandleOutput(allow_buffer_handle_output);
101 }
102
TfLiteInterpreterModifyGraphWithDelegate(const TfLiteInterpreter * interpreter,TfLiteDelegate * delegate)103 TfLiteStatus TfLiteInterpreterModifyGraphWithDelegate(
104 const TfLiteInterpreter* interpreter, TfLiteDelegate* delegate) {
105 return interpreter->impl->ModifyGraphWithDelegate(delegate);
106 }
107
TfLiteInterpreterGetInputTensorIndex(const TfLiteInterpreter * interpreter,int32_t input_index)108 int32_t TfLiteInterpreterGetInputTensorIndex(
109 const TfLiteInterpreter* interpreter, int32_t input_index) {
110 return interpreter->impl->inputs()[input_index];
111 }
112
TfLiteInterpreterGetOutputTensorIndex(const TfLiteInterpreter * interpreter,int32_t output_index)113 int32_t TfLiteInterpreterGetOutputTensorIndex(
114 const TfLiteInterpreter* interpreter, int32_t output_index) {
115 return interpreter->impl->outputs()[output_index];
116 }
117
TfLiteInterpreterGetSignatureCount(const TfLiteInterpreter * interpreter)118 int32_t TfLiteInterpreterGetSignatureCount(
119 const TfLiteInterpreter* interpreter) {
120 return static_cast<int32_t>(interpreter->impl->signature_keys().size());
121 }
122
TfLiteInterpreterGetSignatureKey(const TfLiteInterpreter * interpreter,int32_t signature_index)123 const char* TfLiteInterpreterGetSignatureKey(
124 const TfLiteInterpreter* interpreter, int32_t signature_index) {
125 int32_t signature_count = TfLiteInterpreterGetSignatureCount(interpreter);
126 if (signature_index < 0 || signature_index >= signature_count) {
127 return nullptr;
128 }
129 return interpreter->impl->signature_keys()[signature_index]->c_str();
130 }
131
TfLiteInterpreterGetSignatureRunner(const TfLiteInterpreter * interpreter,const char * signature_key)132 TfLiteSignatureRunner* TfLiteInterpreterGetSignatureRunner(
133 const TfLiteInterpreter* interpreter, const char* signature_key) {
134 tflite::SignatureRunner* signature_runner =
135 interpreter->impl->GetSignatureRunner(signature_key);
136 if (!signature_runner) return nullptr;
137 return new TfLiteSignatureRunner{signature_runner};
138 }
139
TfLiteSignatureRunnerGetInputCount(const TfLiteSignatureRunner * signature_runner)140 size_t TfLiteSignatureRunnerGetInputCount(
141 const TfLiteSignatureRunner* signature_runner) {
142 return signature_runner->impl->input_size();
143 }
144
TfLiteSignatureRunnerGetInputName(const TfLiteSignatureRunner * signature_runner,const int32_t input_index)145 const char* TfLiteSignatureRunnerGetInputName(
146 const TfLiteSignatureRunner* signature_runner, const int32_t input_index) {
147 int32_t input_count = TfLiteSignatureRunnerGetInputCount(signature_runner);
148 if (input_index < 0 || input_index >= input_count) {
149 return nullptr;
150 }
151 return signature_runner->impl->input_names()[input_index];
152 }
153
TfLiteSignatureRunnerResizeInputTensor(TfLiteSignatureRunner * signature_runner,const char * input_name,const int * input_dims,int32_t input_dims_size)154 TfLiteStatus TfLiteSignatureRunnerResizeInputTensor(
155 TfLiteSignatureRunner* signature_runner, const char* input_name,
156 const int* input_dims, int32_t input_dims_size) {
157 std::vector<int> dims{input_dims, input_dims + input_dims_size};
158 return signature_runner->impl->ResizeInputTensorStrict(input_name, dims);
159 }
160
TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner * signature_runner)161 TfLiteStatus TfLiteSignatureRunnerAllocateTensors(
162 TfLiteSignatureRunner* signature_runner) {
163 return signature_runner->impl->AllocateTensors();
164 }
165
TfLiteSignatureRunnerGetInputTensor(TfLiteSignatureRunner * signature_runner,const char * input_name)166 TfLiteTensor* TfLiteSignatureRunnerGetInputTensor(
167 TfLiteSignatureRunner* signature_runner, const char* input_name) {
168 return signature_runner->impl->input_tensor(input_name);
169 }
170
TfLiteSignatureRunnerInvoke(TfLiteSignatureRunner * signature_runner)171 TfLiteStatus TfLiteSignatureRunnerInvoke(
172 TfLiteSignatureRunner* signature_runner) {
173 return signature_runner->impl->Invoke();
174 }
175
TfLiteSignatureRunnerGetOutputCount(const TfLiteSignatureRunner * signature_runner)176 size_t TfLiteSignatureRunnerGetOutputCount(
177 const TfLiteSignatureRunner* signature_runner) {
178 return signature_runner->impl->output_size();
179 }
180
TfLiteSignatureRunnerGetOutputName(const TfLiteSignatureRunner * signature_runner,int32_t output_index)181 const char* TfLiteSignatureRunnerGetOutputName(
182 const TfLiteSignatureRunner* signature_runner, int32_t output_index) {
183 int32_t output_count = TfLiteSignatureRunnerGetOutputCount(signature_runner);
184 if (output_index < 0 || output_index >= output_count) {
185 return nullptr;
186 }
187 return signature_runner->impl->output_names()[output_index];
188 }
189
TfLiteSignatureRunnerGetOutputTensor(const TfLiteSignatureRunner * signature_runner,const char * output_name)190 const TfLiteTensor* TfLiteSignatureRunnerGetOutputTensor(
191 const TfLiteSignatureRunner* signature_runner, const char* output_name) {
192 return signature_runner->impl->output_tensor(output_name);
193 }
194
TfLiteSignatureRunnerDelete(TfLiteSignatureRunner * signature_runner)195 void TfLiteSignatureRunnerDelete(TfLiteSignatureRunner* signature_runner) {
196 delete signature_runner;
197 }
198
199 } // extern "C"
200