xref: /aosp_15_r20/system/chre/apps/tflm_demo/src/model.cc (revision 84e339476a462649f82315436d70fd732297a399)
1*84e33947SAndroid Build Coastguard Worker /*
2*84e33947SAndroid Build Coastguard Worker  * Copyright (C) 2020 The Android Open Source Project
3*84e33947SAndroid Build Coastguard Worker  *
4*84e33947SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*84e33947SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*84e33947SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*84e33947SAndroid Build Coastguard Worker  *
8*84e33947SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*84e33947SAndroid Build Coastguard Worker  *
10*84e33947SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*84e33947SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*84e33947SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*84e33947SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*84e33947SAndroid Build Coastguard Worker  * limitations under the License.
15*84e33947SAndroid Build Coastguard Worker  */
16*84e33947SAndroid Build Coastguard Worker 
17*84e33947SAndroid Build Coastguard Worker #include "model.h"
18*84e33947SAndroid Build Coastguard Worker 
19*84e33947SAndroid Build Coastguard Worker #include "sine_model_data.h"
20*84e33947SAndroid Build Coastguard Worker #include "tensorflow/lite/micro/kernels/micro_ops.h"
21*84e33947SAndroid Build Coastguard Worker #include "tensorflow/lite/micro/micro_interpreter.h"
22*84e33947SAndroid Build Coastguard Worker #include "tensorflow/lite/micro/micro_log.h"
23*84e33947SAndroid Build Coastguard Worker #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
24*84e33947SAndroid Build Coastguard Worker #include "tensorflow/lite/schema/schema_generated.h"
25*84e33947SAndroid Build Coastguard Worker 
26*84e33947SAndroid Build Coastguard Worker //  The following registration code is generated. Check the following commit for
27*84e33947SAndroid Build Coastguard Worker //  details.
28*84e33947SAndroid Build Coastguard Worker //  https://github.com/tensorflow/tensorflow/commit/098556c3a96e1d51f79606c0834547cb2aa20908
29*84e33947SAndroid Build Coastguard Worker 
30*84e33947SAndroid Build Coastguard Worker namespace {
RegisterSelectedOps(::tflite::MicroMutableOpResolver * resolver)31*84e33947SAndroid Build Coastguard Worker void RegisterSelectedOps(::tflite::MicroMutableOpResolver *resolver) {
32*84e33947SAndroid Build Coastguard Worker   resolver->AddBuiltin(
33*84e33947SAndroid Build Coastguard Worker       ::tflite::BuiltinOperator_FULLY_CONNECTED,
34*84e33947SAndroid Build Coastguard Worker       // For now the op version is not supported in the generated code, so this
35*84e33947SAndroid Build Coastguard Worker       // version still needs to added manually.
36*84e33947SAndroid Build Coastguard Worker       ::tflite::ops::micro::Register_FULLY_CONNECTED(), 1, 4);
37*84e33947SAndroid Build Coastguard Worker }
38*84e33947SAndroid Build Coastguard Worker }  // namespace
39*84e33947SAndroid Build Coastguard Worker 
40*84e33947SAndroid Build Coastguard Worker namespace demo {
run(float x_val)41*84e33947SAndroid Build Coastguard Worker float run(float x_val) {
42*84e33947SAndroid Build Coastguard Worker   const tflite::Model *model = tflite::GetModel(g_sine_model_data);
43*84e33947SAndroid Build Coastguard Worker   // TODO(wangtz): Check for schema version.
44*84e33947SAndroid Build Coastguard Worker 
45*84e33947SAndroid Build Coastguard Worker   tflite::MicroMutableOpResolver resolver;
46*84e33947SAndroid Build Coastguard Worker   RegisterSelectedOps(&resolver);
47*84e33947SAndroid Build Coastguard Worker   constexpr int kTensorAreanaSize = 2 * 1024;
48*84e33947SAndroid Build Coastguard Worker   uint8_t tensor_arena[kTensorAreanaSize];
49*84e33947SAndroid Build Coastguard Worker 
50*84e33947SAndroid Build Coastguard Worker   tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
51*84e33947SAndroid Build Coastguard Worker                                        kTensorAreanaSize);
52*84e33947SAndroid Build Coastguard Worker   interpreter.AllocateTensors();
53*84e33947SAndroid Build Coastguard Worker 
54*84e33947SAndroid Build Coastguard Worker   TfLiteTensor *input = interpreter.input(0);
55*84e33947SAndroid Build Coastguard Worker   TfLiteTensor *output = interpreter.output(0);
56*84e33947SAndroid Build Coastguard Worker   input->data.f[0] = x_val;
57*84e33947SAndroid Build Coastguard Worker   TfLiteStatus invoke_status = interpreter.Invoke();
58*84e33947SAndroid Build Coastguard Worker   if (invoke_status != kTfLiteOk) {
59*84e33947SAndroid Build Coastguard Worker     MicroPrintf("Internal error: invoke failed.");
60*84e33947SAndroid Build Coastguard Worker     return 0.0;
61*84e33947SAndroid Build Coastguard Worker   }
62*84e33947SAndroid Build Coastguard Worker   float y_val = output->data.f[0];
63*84e33947SAndroid Build Coastguard Worker   return y_val;
64*84e33947SAndroid Build Coastguard Worker }
65*84e33947SAndroid Build Coastguard Worker 
66*84e33947SAndroid Build Coastguard Worker }  // namespace demo
67