/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/flex/test_util.h"

#include "absl/memory/memory.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace flex {
namespace testing {

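// Runs the interpreter; returns true iff invocation succeeds.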
bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }

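// Serializes |values| into the TF Lite string tensor at |tensor_index| via a
// DynamicBuffer, keeping the tensor's existing shape.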
void FlexModelTest::SetStringValues(int tensor_index,
                                    const std::vector<string>& values) {
  DynamicBuffer dynamic_buffer;
  for (const string& s : values) {
    dynamic_buffer.AddString(s.data(), s.size());
  }
  dynamic_buffer.WriteToTensor(interpreter_->tensor(tensor_index),
                               /*new_shape=*/nullptr);
}

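// Reads every string stored in the string tensor at |tensor_index| and
// returns them in order.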
std::vector<string> FlexModelTest::GetStringValues(int tensor_index) const {
  std::vector<string> result;

  TfLiteTensor* tensor = interpreter_->tensor(tensor_index);
  auto num_strings = GetStringCount(tensor);
  for (size_t i = 0; i < num_strings; ++i) {
    auto ref = GetString(tensor, i);
    result.push_back(string(ref.str, ref.len));
  }

  return result;
}

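// Resizes the input tensor at |tensor_index| to |values| and reallocates all
// tensors; fails the test if either step does not return kTfLiteOk.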
void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) {
  ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
}

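// Returns the dimensions of the tensor at |tensor_index| as a vector.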
std::vector<int> FlexModelTest::GetShape(int tensor_index) {
  std::vector<int> result;
  auto* dims = interpreter_->tensor(tensor_index)->dims;
  result.reserve(dims->size);
  for (int i = 0; i < dims->size; ++i) {
    result.push_back(dims->data[i]);
  }
  return result;
}

TfLiteType FlexModelTest::GetType(int tensor_index) {
  return interpreter_->tensor(tensor_index)->type;
}

bool FlexModelTest::IsDynamicTensor(int tensor_index) {
  return interpreter_->tensor(tensor_index)->allocation_type == kTfLiteDynamic;
}

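// Adds |num_tensors| read-write tensors of the given |type| and |dims| to the
// interpreter and registers |inputs| and |outputs| as the graph's I/O.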
void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
                               const std::vector<int>& outputs, TfLiteType type,
                               const std::vector<int>& dims) {
  interpreter_->AddTensors(num_tensors);
  for (int i = 0; i < num_tensors; ++i) {
    TfLiteQuantizationParams quant;
    CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
                                                        /*name=*/"",
                                                        /*dims=*/dims, quant),
             kTfLiteOk);
  }

  CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk);
  CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk);
}

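// Registers the tensor at |tensor_index| as a read-only constant with shape
// |values|, backed by the caller-owned |buffer| of |bytes| bytes.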
void FlexModelTest::SetConstTensor(int tensor_index,
                                   const std::vector<int>& values,
                                   TfLiteType type, const char* buffer,
                                   size_t bytes) {
  TfLiteQuantizationParams quant;
  CHECK_EQ(interpreter_->SetTensorParametersReadOnly(tensor_index, type,
                                                     /*name=*/"",
                                                     /*dims=*/values, quant,
                                                     buffer, bytes),
           kTfLiteOk);
}

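// Adds a node that runs the builtin MUL kernel defined inline below: prepare
// resizes the output to match the first input, and invoke multiplies the two
// float inputs element-wise.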
void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs,
                                   const std::vector<int>& outputs) {
  ++next_op_index_;

  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_MUL;
  reg.prepare = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* o = &context->tensors[node->outputs->data[0]];
    return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims));
  };
  reg.invoke = [](TfLiteContext* context, TfLiteNode* node) {
    auto* i0 = &context->tensors[node->inputs->data[0]];
    auto* i1 = &context->tensors[node->inputs->data[1]];
    auto* o = &context->tensors[node->outputs->data[0]];
    for (int i = 0; i < o->bytes / sizeof(float); ++i) {
      o->data.f[i] = i0->data.f[i] * i1->data.f[i];
    }
    return kTfLiteOk;
  };

  CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0,
                                               nullptr, &reg),
           kTfLiteOk);
}

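// Adds a Flex (TF select) op of kind |op|. The "T" type attribute is derived
// from the first input tensor; the remaining attributes and the TF op name
// are chosen per |op| before delegating to the private AddTfOp overload
// below.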
void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  tf_ops_.push_back(next_op_index_);
  ++next_op_index_;

  auto attr = [](const string& key, const string& value) {
    return " attr{ key: '" + key + "' value {" + value + "}}";
  };

  string type_attribute;
  switch (interpreter_->tensor(inputs[0])->type) {
    case kTfLiteInt32:
      type_attribute = attr("T", "type: DT_INT32");
      break;
    case kTfLiteFloat32:
      type_attribute = attr("T", "type: DT_FLOAT");
      break;
    case kTfLiteString:
      type_attribute = attr("T", "type: DT_STRING");
      break;
    case kTfLiteBool:
      type_attribute = attr("T", "type: DT_BOOL");
      break;
    default:
      // TODO(b/113613439): Use nodedef string utilities to properly handle all
      // types.
      LOG(FATAL) << "Type not supported";
      break;
  }

  if (op == kUnpack) {
    string attributes =
        type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
    AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs);
  } else if (op == kIdentity) {
    string attributes = type_attribute;
    AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs);
  } else if (op == kAdd) {
    string attributes = type_attribute;
    AddTfOp("FlexAdd", "Add", attributes, inputs, outputs);
  } else if (op == kMul) {
    string attributes = type_attribute;
    AddTfOp("FlexMul", "Mul", attributes, inputs, outputs);
  } else if (op == kRfft) {
    AddTfOp("FlexRFFT", "RFFT", "", inputs, outputs);
  } else if (op == kImag) {
    AddTfOp("FlexImag", "Imag", "", inputs, outputs);
  } else if (op == kLoopCond) {
    string attributes = type_attribute;
    AddTfOp("FlexLoopCond", "LoopCond", attributes, inputs, outputs);
  } else if (op == kNonExistent) {
    AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
  } else if (op == kIncompatibleNodeDef) {
    // "Cast" op is created without attributes - making it incompatible.
    AddTfOp("FlexCast", "Cast", "", inputs, outputs);
  }
}

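// Registers a custom op named |tflite_name| whose options buffer is a
// flexbuffer vector of [op name, serialized NodeDef], the layout the flex
// delegate parses. |nodedef_str| is parsed as a text-format NodeDef with
// "op: '<tf_name>'" appended.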
void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name,
                            const string& nodedef_str,
                            const std::vector<int>& inputs,
                            const std::vector<int>& outputs) {
  static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr};
  reg.builtin_code = BuiltinOperator_CUSTOM;
  reg.custom_name = tflite_name;

  tensorflow::NodeDef nodedef;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
      nodedef_str + " op: '" + tf_name + "'", &nodedef));
  string serialized_nodedef;
  CHECK(nodedef.SerializeToString(&serialized_nodedef));
  flexbuffers::Builder fbb;
  fbb.Vector([&]() {
    fbb.String(nodedef.op());
    fbb.String(serialized_nodedef);
  });
  fbb.Finish();

  flexbuffers_.push_back(fbb.GetBuffer());
  auto& buffer = flexbuffers_.back();
  CHECK_EQ(interpreter_->AddNodeWithParameters(
               inputs, outputs, reinterpret_cast<const char*>(buffer.data()),
               buffer.size(), nullptr, &reg),
           kTfLiteOk);
}

}  // namespace testing
}  // namespace flex
}  // namespace tflite