xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/test_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/flex/test_util.h"
17 
18 #include "absl/memory/memory.h"
19 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
20 #include "tensorflow/lite/string_type.h"
21 
22 namespace tflite {
23 namespace flex {
24 namespace testing {
25 
Invoke()26 bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
27 
SetStringValues(int tensor_index,const std::vector<string> & values)28 void FlexModelTest::SetStringValues(int tensor_index,
29                                     const std::vector<string>& values) {
30   DynamicBuffer dynamic_buffer;
31   for (const string& s : values) {
32     dynamic_buffer.AddString(s.data(), s.size());
33   }
34   dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
35                                /*new_shape=*/nullptr);
36 }
37 
GetStringValues(int tensor_index) const38 std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
39   std::vector<string> result;
40 
41   TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
42   auto num_strings = GetStringCount(tensor);
43   for (size_t i = 0; i < num_strings; ++i) {
44     auto ref = GetString(tensor, i);
45     result.push_back(string(ref.str, ref.len));
46   }
47 
48   return result;
49 }
50 
SetShape(int tensor_index,const std::vector<int> & values)51 void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
52   ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
53   ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
54 }
55 
GetShape(int tensor_index)56 std::vector<int> FlexModelTest::GetShape(int tensor_index) {
57   std::vector<int> result;
58   auto* dims = interpreter_->tensor(tensor_index)->dims;
59   result.reserve(dims->size);
60   for (int i = 0; i < dims->size; ++i) {
61     result.push_back(dims->data[i]);
62   }
63   return result;
64 }
65 
GetType(int tensor_index)66 TfLiteType FlexModelTest::GetType(int tensor_index) {
67   return interpreter_->tensor(tensor_index)->type;
68 }
69 
IsDynamicTensor(int tensor_index)70 bool FlexModelTest::IsDynamicTensor(int tensor_index) {
71   return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
72 }
73 
AddTensors(int num_tensors,const std::vector<int> & inputs,const std::vector<int> & outputs,TfLiteType type,const std::vector<int> & dims)74 void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
75                                const std::vector<int>& outputs, TfLiteType type,
76                                const std::vector<int>& dims) {
77   interpreter_->AddTensors(num_tensors);
78   for (int i = 0; i < num_tensors; ++i) {
79     TfLiteQuantizationParams quant;
80     CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
81                                                         /*name=*/"",
82                                                         /*dims=*/dims, quant),
83              kTfLiteOk);
84   }
85 
86   CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
87   CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
88 }
89 
SetConstTensor(int tensor_index,const std::vector<int> & values,TfLiteType type,const char * buffer,size_t bytes)90 void FlexModelTest::SetConstTensor(int tensor_index,
91                                    const std::vector<int>& values,
92                                    TfLiteType type, const char* buffer,
93                                    size_t bytes) {
94   TfLiteQuantizationParams quant;
95   CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
96                                                      /*name=*/"",
97                                                      /*dims=*/values, quant,
98                                                      buffer, bytes),
99            kTfLiteOk);
100 }
101 
AddTfLiteMulOp(const std::vector<int> & inputs,const std::vector<int> & outputs)102 void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
103                                    const std::vector<int>& outputs) {
104   ++next_op_index_;
105 
106   static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
107   reg.builtin_code = BuiltinOperator_MUL;
108   reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
109     auto* i0 = &context->tensors[node->inputs->data[0]];
110     auto* o = &context->tensors[node->outputs->data[0]];
111     return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
112   };
113   reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
114     auto* i0 = &context->tensors[node->inputs->data[0]];
115     auto* i1 = &context->tensors[node->inputs->data[1]];
116     auto* o = &context->tensors[node->outputs->data[0]];
117     for (int i = 0; i < o->bytes / sizeof(float); ++i) {
118       o->data.f[i] = i0->data.f[i] * i1->data.f[i];
119     }
120     return kTfLiteOk;
121   };
122 
123   CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
124                                                nullptr, &reg),
125            kTfLiteOk);
126 }
127 
AddTfOp(TfOpType op,const std::vector<int> & inputs,const std::vector<int> & outputs)128 void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
129                             const std::vector<int>& outputs) {
130   tf_ops_.push_back(next_op_index_);
131   ++next_op_index_;
132 
133   auto attr = [](const string& key, const string& value) {
134     return " attr{ key: '" + key + "' value {" + value + "}}";
135   };
136 
137   string type_attribute;
138   switch (interpreter_->tensor(inputs[0])->type) {
139     case kTfLiteInt32:
140       type_attribute = attr("T", "type: DT_INT32");
141       break;
142     case kTfLiteFloat32:
143       type_attribute = attr("T", "type: DT_FLOAT");
144       break;
145     case kTfLiteString:
146       type_attribute = attr("T", "type: DT_STRING");
147       break;
148     case kTfLiteBool:
149       type_attribute = attr("T", "type: DT_BOOL");
150       break;
151     default:
152       // TODO(b/113613439): Use nodedef string utilities to properly handle all
153       // types.
154       LOG(FATAL) << "Type not supported";
155       break;
156   }
157 
158   if (op == kUnpack) {
159     string attributes =
160         type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
161     AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
162   } else if (op == kIdentity) {
163     string attributes = type_attribute;
164     AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
165   } else if (op == kAdd) {
166     string attributes = type_attribute;
167     AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
168   } else if (op == kMul) {
169     string attributes = type_attribute;
170     AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
171   } else if (op == kRfft) {
172     AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
173   } else if (op == kImag) {
174     AddTfOp("FlexImag", "Imag", "", inputs, outputs);
175   } else if (op == kLoopCond) {
176     string attributes = type_attribute;
177     AddTfOp("FlexLoopCond", "LoopCond", attributes, inputs, outputs);
178   } else if (op == kNonExistent) {
179     AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
180   } else if (op == kIncompatibleNodeDef) {
181     // "Cast" op is created without attributes - making it incompatible.
182     AddTfOp("FlexCast", "Cast", "", inputs, outputs);
183   }
184 }
185 
AddTfOp(const char * tflite_name,const string & tf_name,const string & nodedef_str,const std::vector<int> & inputs,const std::vector<int> & outputs)186 void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
187                             const string& nodedef_str,
188                             const std::vector<int>& inputs,
189                             const std::vector<int>& outputs) {
190   static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
191   reg.builtin_code = BuiltinOperator_CUSTOM;
192   reg.custom_name = tflite_name;
193 
194   tensorflow::NodeDef nodedef;
195   CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
196       nodedef_str + " op: '" + tf_name + "'", &nodedef));
197   string serialized_nodedef;
198   CHECK(nodedef.SerializeToString(&serialized_nodedef));
199   flexbuffers::Builder fbb;
200   fbb.Vector([&]() {
201     fbb.String(nodedef.op());
202     fbb.String(serialized_nodedef);
203   });
204   fbb.Finish();
205 
206   flexbuffers_.push_back(fbb.GetBuffer());
207   auto& buffer = flexbuffers_.back();
208   CHECK_EQ(interpreter_->AddNodeWithParameters(
209                inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
210                buffer.size(), nullptr, &reg),
211            kTfLiteOk);
212 }
213 
214 }  // namespace testing
215 }  // namespace flex
216 }  // namespace tflite
217