xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/c/c_api_experimental.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/c/c_api_experimental.h"
17 
18 #include <stdint.h>
19 
20 #include <memory>
21 
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/c/c_api.h"
24 #include "tensorflow/lite/c/c_api_internal.h"
25 #include "tensorflow/lite/interpreter.h"
26 #include "tensorflow/lite/signature_runner.h"
27 
28 extern "C" {
29 
TfLiteInterpreterResetVariableTensors(TfLiteInterpreter * interpreter)30 TfLiteStatus TfLiteInterpreterResetVariableTensors(
31     TfLiteInterpreter* interpreter) {
32   return interpreter->impl->ResetVariableTensors();
33 }
34 
TfLiteInterpreterOptionsAddBuiltinOp(TfLiteInterpreterOptions * options,TfLiteBuiltinOperator op,const TfLiteRegistration * registration,int32_t min_version,int32_t max_version)35 void TfLiteInterpreterOptionsAddBuiltinOp(
36     TfLiteInterpreterOptions* options, TfLiteBuiltinOperator op,
37     const TfLiteRegistration* registration, int32_t min_version,
38     int32_t max_version) {
39   options->mutable_op_resolver.AddBuiltin(
40       static_cast<tflite::BuiltinOperator>(op), registration, min_version,
41       max_version);
42 }
43 
TfLiteInterpreterCreateWithSelectedOps(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)44 TfLiteInterpreter* TfLiteInterpreterCreateWithSelectedOps(
45     const TfLiteModel* model,
46     const TfLiteInterpreterOptions* optional_options) {
47   tflite::MutableOpResolver resolver;
48   return tflite::internal::InterpreterCreateWithOpResolver(
49       model, optional_options, &resolver);
50 }
51 
TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions * options,const char * name,const TfLiteRegistration * registration,int32_t min_version,int32_t max_version)52 void TfLiteInterpreterOptionsAddCustomOp(TfLiteInterpreterOptions* options,
53                                          const char* name,
54                                          const TfLiteRegistration* registration,
55                                          int32_t min_version,
56                                          int32_t max_version) {
57   options->mutable_op_resolver.AddCustom(name, registration, min_version,
58                                          max_version);
59 }
60 
TfLiteInterpreterOptionsSetOpResolver(TfLiteInterpreterOptions * options,const TfLiteRegistration * (* find_builtin_op)(void * user_data,TfLiteBuiltinOperator op,int version),const TfLiteRegistration * (* find_custom_op)(void * user_data,const char * op,int version),void * op_resolver_user_data)61 void TfLiteInterpreterOptionsSetOpResolver(
62     TfLiteInterpreterOptions* options,
63     const TfLiteRegistration* (*find_builtin_op)(void* user_data,
64                                                  TfLiteBuiltinOperator op,
65                                                  int version),
66     const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op,
67                                                 int version),
68     void* op_resolver_user_data) {
69   options->op_resolver_callbacks.find_builtin_op = find_builtin_op;
70   options->op_resolver_callbacks.find_custom_op = find_custom_op;
71   options->op_resolver_callbacks.user_data = op_resolver_user_data;
72 }
73 
TfLiteInterpreterOptionsSetOpResolverV1(TfLiteInterpreterOptions * options,const TfLiteRegistration_V1 * (* find_builtin_op_v1)(void * user_data,TfLiteBuiltinOperator op,int version),const TfLiteRegistration_V1 * (* find_custom_op_v1)(void * user_data,const char * op,int version),void * op_resolver_user_data)74 void TfLiteInterpreterOptionsSetOpResolverV1(
75     TfLiteInterpreterOptions* options,
76     const TfLiteRegistration_V1* (*find_builtin_op_v1)(void* user_data,
77                                                        TfLiteBuiltinOperator op,
78                                                        int version),
79     const TfLiteRegistration_V1* (*find_custom_op_v1)(void* user_data,
80                                                       const char* op,
81                                                       int version),
82     void* op_resolver_user_data) {
83   options->op_resolver_callbacks.find_builtin_op_v1 = find_builtin_op_v1;
84   options->op_resolver_callbacks.find_custom_op_v1 = find_custom_op_v1;
85   options->op_resolver_callbacks.user_data = op_resolver_user_data;
86 }
87 
TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions * options,bool enable)88 void TfLiteInterpreterOptionsSetUseNNAPI(TfLiteInterpreterOptions* options,
89                                          bool enable) {
90   options->use_nnapi = enable;
91 }
92 
TfLiteInterpreterOptionsSetEnableDelegateFallback(TfLiteInterpreterOptions * options,bool enable)93 void TfLiteInterpreterOptionsSetEnableDelegateFallback(
94     TfLiteInterpreterOptions* options, bool enable) {
95   options->enable_delegate_fallback = enable;
96 }
97 
TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter * interpreter,bool allow_buffer_handle_output)98 void TfLiteSetAllowBufferHandleOutput(const TfLiteInterpreter* interpreter,
99                                       bool allow_buffer_handle_output) {
100   interpreter->impl->SetAllowBufferHandleOutput(allow_buffer_handle_output);
101 }
102 
TfLiteInterpreterModifyGraphWithDelegate(const TfLiteInterpreter * interpreter,TfLiteDelegate * delegate)103 TfLiteStatus TfLiteInterpreterModifyGraphWithDelegate(
104     const TfLiteInterpreter* interpreter, TfLiteDelegate* delegate) {
105   return interpreter->impl->ModifyGraphWithDelegate(delegate);
106 }
107 
TfLiteInterpreterGetInputTensorIndex(const TfLiteInterpreter * interpreter,int32_t input_index)108 int32_t TfLiteInterpreterGetInputTensorIndex(
109     const TfLiteInterpreter* interpreter, int32_t input_index) {
110   return interpreter->impl->inputs()[input_index];
111 }
112 
TfLiteInterpreterGetOutputTensorIndex(const TfLiteInterpreter * interpreter,int32_t output_index)113 int32_t TfLiteInterpreterGetOutputTensorIndex(
114     const TfLiteInterpreter* interpreter, int32_t output_index) {
115   return interpreter->impl->outputs()[output_index];
116 }
117 
TfLiteInterpreterGetSignatureCount(const TfLiteInterpreter * interpreter)118 int32_t TfLiteInterpreterGetSignatureCount(
119     const TfLiteInterpreter* interpreter) {
120   return static_cast<int32_t>(interpreter->impl->signature_keys().size());
121 }
122 
TfLiteInterpreterGetSignatureKey(const TfLiteInterpreter * interpreter,int32_t signature_index)123 const char* TfLiteInterpreterGetSignatureKey(
124     const TfLiteInterpreter* interpreter, int32_t signature_index) {
125   int32_t signature_count = TfLiteInterpreterGetSignatureCount(interpreter);
126   if (signature_index < 0 || signature_index >= signature_count) {
127     return nullptr;
128   }
129   return interpreter->impl->signature_keys()[signature_index]->c_str();
130 }
131 
TfLiteInterpreterGetSignatureRunner(const TfLiteInterpreter * interpreter,const char * signature_key)132 TfLiteSignatureRunner* TfLiteInterpreterGetSignatureRunner(
133     const TfLiteInterpreter* interpreter, const char* signature_key) {
134   tflite::SignatureRunner* signature_runner =
135       interpreter->impl->GetSignatureRunner(signature_key);
136   if (!signature_runner) return nullptr;
137   return new TfLiteSignatureRunner{signature_runner};
138 }
139 
TfLiteSignatureRunnerGetInputCount(const TfLiteSignatureRunner * signature_runner)140 size_t TfLiteSignatureRunnerGetInputCount(
141     const TfLiteSignatureRunner* signature_runner) {
142   return signature_runner->impl->input_size();
143 }
144 
TfLiteSignatureRunnerGetInputName(const TfLiteSignatureRunner * signature_runner,const int32_t input_index)145 const char* TfLiteSignatureRunnerGetInputName(
146     const TfLiteSignatureRunner* signature_runner, const int32_t input_index) {
147   int32_t input_count = TfLiteSignatureRunnerGetInputCount(signature_runner);
148   if (input_index < 0 || input_index >= input_count) {
149     return nullptr;
150   }
151   return signature_runner->impl->input_names()[input_index];
152 }
153 
TfLiteSignatureRunnerResizeInputTensor(TfLiteSignatureRunner * signature_runner,const char * input_name,const int * input_dims,int32_t input_dims_size)154 TfLiteStatus TfLiteSignatureRunnerResizeInputTensor(
155     TfLiteSignatureRunner* signature_runner, const char* input_name,
156     const int* input_dims, int32_t input_dims_size) {
157   std::vector<int> dims{input_dims, input_dims + input_dims_size};
158   return signature_runner->impl->ResizeInputTensorStrict(input_name, dims);
159 }
160 
TfLiteSignatureRunnerAllocateTensors(TfLiteSignatureRunner * signature_runner)161 TfLiteStatus TfLiteSignatureRunnerAllocateTensors(
162     TfLiteSignatureRunner* signature_runner) {
163   return signature_runner->impl->AllocateTensors();
164 }
165 
TfLiteSignatureRunnerGetInputTensor(TfLiteSignatureRunner * signature_runner,const char * input_name)166 TfLiteTensor* TfLiteSignatureRunnerGetInputTensor(
167     TfLiteSignatureRunner* signature_runner, const char* input_name) {
168   return signature_runner->impl->input_tensor(input_name);
169 }
170 
TfLiteSignatureRunnerInvoke(TfLiteSignatureRunner * signature_runner)171 TfLiteStatus TfLiteSignatureRunnerInvoke(
172     TfLiteSignatureRunner* signature_runner) {
173   return signature_runner->impl->Invoke();
174 }
175 
TfLiteSignatureRunnerGetOutputCount(const TfLiteSignatureRunner * signature_runner)176 size_t TfLiteSignatureRunnerGetOutputCount(
177     const TfLiteSignatureRunner* signature_runner) {
178   return signature_runner->impl->output_size();
179 }
180 
TfLiteSignatureRunnerGetOutputName(const TfLiteSignatureRunner * signature_runner,int32_t output_index)181 const char* TfLiteSignatureRunnerGetOutputName(
182     const TfLiteSignatureRunner* signature_runner, int32_t output_index) {
183   int32_t output_count = TfLiteSignatureRunnerGetOutputCount(signature_runner);
184   if (output_index < 0 || output_index >= output_count) {
185     return nullptr;
186   }
187   return signature_runner->impl->output_names()[output_index];
188 }
189 
TfLiteSignatureRunnerGetOutputTensor(const TfLiteSignatureRunner * signature_runner,const char * output_name)190 const TfLiteTensor* TfLiteSignatureRunnerGetOutputTensor(
191     const TfLiteSignatureRunner* signature_runner, const char* output_name) {
192   return signature_runner->impl->output_tensor(output_name);
193 }
194 
TfLiteSignatureRunnerDelete(TfLiteSignatureRunner * signature_runner)195 void TfLiteSignatureRunnerDelete(TfLiteSignatureRunner* signature_runner) {
196   delete signature_runner;
197 }
198 
199 }  // extern "C"
200