1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/tools/serialization/writer_lib.h"
17
18 #include <cstdlib>
19 #include <fstream>
20 #include <memory>
21 #include <numeric>
22 #include <sstream>
23 #include <string>
24 #include <tuple>
25 #include <vector>
26
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "tensorflow/lite/c/builtin_op_data.h"
30 #include "tensorflow/lite/c/c_api_types.h"
31 #include "tensorflow/lite/c/common.h"
32 #include "tensorflow/lite/interpreter.h"
33 #include "tensorflow/lite/kernels/register.h"
34 #include "tensorflow/lite/kernels/subgraph_test_util.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/schema/schema_generated.h"
37 #include "tensorflow/lite/testing/util.h"
38
39 namespace tflite {
40
41 using subgraph_test_util::CheckIntTensor;
42 using subgraph_test_util::FillIntTensor;
43
// Returns the path of `file_name` inside the directory named by the
// TEST_TMPDIR environment variable, or inside the current directory ("./")
// when the variable is not set.
//
// A path separator is inserted when the directory does not already end with
// one, so that "/tmp/dir" + "model.tflite" yields "/tmp/dir/model.tflite"
// instead of the sibling path "/tmp/dirmodel.tflite".
std::string CreateFilePath(const std::string& file_name) {
  const char* tmp_dir = getenv("TEST_TMPDIR");
  std::string path = tmp_dir ? tmp_dir : "./";
  // An empty TEST_TMPDIR keeps its previous meaning (a path relative to the
  // current working directory), so no separator is added in that case.
  if (!path.empty() && path.back() != '/' && path.back() != '\\') {
    path += '/';
  }
  path += file_name;
  return path;
}
48
49 // The bool param indicates whether we use SubgraphWriter(true) or
50 // ModelWriter(false) for the test
51 class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
52 protected:
WriteToFile(Interpreter * interpreter,const std::string & filename,bool use_subgraph_writer)53 void WriteToFile(Interpreter* interpreter, const std::string& filename,
54 bool use_subgraph_writer) {
55 if (use_subgraph_writer) {
56 SubgraphWriter writer(&interpreter->primary_subgraph());
57 CHECK_EQ(writer.Write(filename), kTfLiteOk);
58 } else {
59 ModelWriter writer(interpreter);
60 CHECK_EQ(writer.Write(filename), kTfLiteOk);
61 }
62 }
63 };
64
// Verifies that the writer selected by the test parameter reports
// kTfLiteError when asked to write to an empty filename or into a null output
// buffer.
TEST_P(SingleSubgraphTest, InvalidDestinations) {
  // Build a one-op model: c = Add(a, b), where b is a constant tensor.
  Interpreter interpreter;
  interpreter.AddTensors(3);
  float foo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<char*>(foo), sizeof(foo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  // NOTE(review): builtin_data is malloc'ed and not freed here; this relies on
  // AddNodeWithParameters taking ownership of it - confirm against the
  // Interpreter API documentation.
  TfLiteAddParams* builtin_data =
      static_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    static_cast<void*>(builtin_data), reg);

  // Check if invalid filename is handled gracefully. EXPECT_EQ is used so a
  // failure is reported by the test framework and the buffer check still runs.
  if (GetParam()) {
    SubgraphWriter writer(&interpreter.primary_subgraph());
    EXPECT_EQ(writer.Write(""), kTfLiteError);
  } else {
    ModelWriter writer(&interpreter);
    EXPECT_EQ(writer.Write(""), kTfLiteError);
  }

  // Check if invalid buffer is handled gracefully.
  size_t size = 0;
  if (GetParam()) {
    SubgraphWriter writer(&interpreter.primary_subgraph());
    EXPECT_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
  } else {
    ModelWriter writer(&interpreter);
    EXPECT_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
  }
}
107
// Round-trips a float model with a single Add op through the writer selected
// by the test parameter and checks that the reloaded model can allocate its
// tensors.
TEST_P(SingleSubgraphTest, FloatModelTest) {
  // Build the model: c = Add(a, b), where b is a constant tensor.
  Interpreter interpreter;
  interpreter.AddTensors(3);
  float foo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<char*>(foo), sizeof(foo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      static_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    static_cast<void*>(builtin_data), reg);

  // Export the interpreter and import it back.
  const std::string test_file = CreateFilePath("test_float.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  // Fail the test instead of dereferencing a null model or interpreter.
  ASSERT_NE(model, nullptr);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  ASSERT_NE(new_interpreter, nullptr);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
140
141 // Tests writing only a portion of the subgraph.
TEST_P(SingleSubgraphTest,CustomInputOutputTest)142 TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
143 Interpreter interpreter;
144 interpreter.AddTensors(4);
145 constexpr float kFoo[] = {1, 2, 3};
146 interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
147 TfLiteQuantization());
148 interpreter.SetTensorParametersReadOnly(
149 1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
150 reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
151 interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
152 TfLiteQuantization());
153 interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
154 TfLiteQuantization());
155 interpreter.SetInputs({0, 1});
156 interpreter.SetOutputs({3});
157
158 // Add two ops: Add and Relu
159 const char* initial_data = "";
160 tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
161 TfLiteAddParams* builtin_data =
162 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
163 builtin_data->activation = kTfLiteActNone;
164 builtin_data->pot_scale_int16 = false;
165 const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
166 interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
167 reinterpret_cast<void*>(builtin_data), reg);
168
169 const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
170 interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);
171
172 // Only write the second op.
173 const std::string test_file = CreateFilePath("test_custom.tflite");
174 SubgraphWriter writer(&interpreter.primary_subgraph());
175 EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
176 /*execution_plan=*/{1}),
177 kTfLiteOk);
178 writer.SetUnusedTensors({0, 1});
179 writer.Write(test_file);
180
181 std::unique_ptr<FlatBufferModel> model =
182 FlatBufferModel::BuildFromFile(test_file.c_str());
183 InterpreterBuilder builder(*model, resolver);
184 std::unique_ptr<Interpreter> new_interpreter;
185 builder(&new_interpreter);
186 ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
187 }
188
// Exercises SubgraphWriter::SetCustomInputOutput on a three-op chain
// (Add -> Relu -> Relu6) with an execution plan covering the first two ops:
// a mismatching input list and a mismatching output list must be rejected,
// while the matching pair must be accepted.
TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
  constexpr float kFoo[] = {1, 2, 3};

  Interpreter interpreter;
  interpreter.AddTensors(5);
  // Helper for the float read/write tensors, which all share shape {3}.
  const auto set_read_write = [&interpreter](int index, const char* name) {
    interpreter.SetTensorParametersReadWrite(index, kTfLiteFloat32, name, {3},
                                             TfLiteQuantization());
  };
  set_read_write(0, "a");
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
  set_read_write(2, "c");
  set_read_write(3, "d");
  set_read_write(4, "e");
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({4});

  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  // Op 0: Add(0, 1) -> 2.
  const char* add_init_data = "";
  auto* add_params =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  add_params->activation = kTfLiteActNone;
  add_params->pot_scale_int16 = false;
  interpreter.AddNodeWithParameters({0, 1}, {2}, add_init_data, 0,
                                    reinterpret_cast<void*>(add_params),
                                    resolver.FindOp(BuiltinOperator_ADD, 1));

  // Op 1: Relu(2) -> 3.
  interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr,
                                    resolver.FindOp(BuiltinOperator_RELU, 1));

  // Op 2: Relu6(3) -> 4.
  interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr,
                                    resolver.FindOp(BuiltinOperator_RELU6, 1));

  SubgraphWriter writer(&interpreter.primary_subgraph());

  // Test wrong input.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteError);
  // Test wrong output.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{4},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteError);
  // Test a valid case.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{3},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteOk);
}
239
240 // Tests if SetCustomInputOutput handles variable tensors correctly.
TEST_P(SingleSubgraphTest,CustomInputOutputVariableTensorTest)241 TEST_P(SingleSubgraphTest, CustomInputOutputVariableTensorTest) {
242 Interpreter interpreter;
243 tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
244
245 // Create tensors.
246 interpreter.AddTensors(3);
247 interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
248 TfLiteQuantization());
249 interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "b", {3},
250 TfLiteQuantization(),
251 /*is_variable=*/true);
252 interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
253 TfLiteQuantization());
254 interpreter.SetInputs({0});
255 interpreter.SetOutputs({2});
256 interpreter.SetVariables({1});
257
258 // Create an Add node.
259 TfLiteAddParams* builtin_data =
260 reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
261 builtin_data->activation = kTfLiteActNone;
262 builtin_data->pot_scale_int16 = false;
263 interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0,
264 reinterpret_cast<void*>(builtin_data),
265 resolver.FindOp(BuiltinOperator_ADD, 1));
266
267 // Write model to file.
268 const std::string test_file = CreateFilePath("test_variables.tflite");
269 SubgraphWriter writer(&interpreter.primary_subgraph());
270 EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0}, /*outputs=*/{2},
271 /*execution_plan=*/{0}),
272 kTfLiteOk);
273 writer.Write(test_file);
274
275 // Read model and test.
276 std::unique_ptr<FlatBufferModel> model =
277 FlatBufferModel::BuildFromFile(test_file.c_str());
278 InterpreterBuilder builder(*model, resolver);
279 std::unique_ptr<Interpreter> new_interpreter;
280 builder(&new_interpreter);
281 CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
282 }
283
// Round-trips a uint8 model whose tensors carry per-tensor quantization
// parameters (scale 1/256, zero point 128) through the writer selected by the
// test parameter and checks that the reloaded model can allocate its tensors.
TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
  Interpreter interpreter;
  interpreter.AddTensors(3);
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteUInt8, "a", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetTensorParametersReadWrite(
      1, kTfLiteUInt8, "b", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetTensorParametersReadWrite(
      2, kTfLiteUInt8, "c", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      static_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    static_cast<void*>(builtin_data), reg);

  // Export the interpreter and import it back.
  const std::string test_file = CreateFilePath("test_uint8.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  // Fail the test instead of dereferencing a null model or interpreter.
  ASSERT_NE(model, nullptr);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  ASSERT_NE(new_interpreter, nullptr);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
314
// Checks that the operator version survives serialization: a BROADCAST_TO op
// registered at version 2 must be reloaded as BROADCAST_TO version 2.
TEST_P(SingleSubgraphTest, OpVersioningTest) {
  Interpreter interpreter;
  interpreter.AddTensors(3);
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {1, 4},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "b", {2},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {4, 4},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});

  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  const TfLiteRegistration* reg =
      resolver.FindOp(BuiltinOperator_BROADCAST_TO, 2);
  interpreter.AddNodeWithParameters(/*inputs=*/{0, 1}, /*outputs=*/{2},
                                    /*init_data=*/nullptr, /*init_data_size=*/0,
                                    /*builtin_data=*/nullptr, reg);

  // Use a file name of its own rather than sharing "test_float.tflite" with
  // FloatModelTest.
  const std::string test_file = CreateFilePath("test_op_versioning.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  ASSERT_NE(model, nullptr);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  ASSERT_NE(new_interpreter, nullptr);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);

  // The reloaded model must contain exactly the one node that was written.
  ASSERT_EQ(new_interpreter->nodes_size(), 1);
  const auto* node_and_reg = new_interpreter->node_and_registration(0);
  ASSERT_NE(node_and_reg, nullptr);
  const TfLiteRegistration& output_reg = node_and_reg->second;
  ASSERT_EQ(output_reg.builtin_code, BuiltinOperator_BROADCAST_TO);
  EXPECT_EQ(output_reg.version, 2);
}
349
// Checks that a tensor's shape signature (dims_signature, here {-1, 3})
// survives a serialization round trip through the writer selected by the test
// parameter.
TEST_P(SingleSubgraphTest, DynamicShapeTest) {
  // Build a model with a single Add op.
  Interpreter interpreter;

  interpreter.AddTensors(3);
  std::vector<int> dims = {1, 3};
  std::vector<int> dims_signature = {-1, 3};
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "a", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);
  interpreter.SetTensorParametersReadWrite(
      1, kTfLiteFloat32, "b", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);
  interpreter.SetTensorParametersReadWrite(
      2, kTfLiteFloat32, "c", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);

  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      static_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    static_cast<void*>(builtin_data), reg);

  // Export interpreter and import back.
  const std::string test_file = CreateFilePath("test_dynamic_shape.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  ASSERT_NE(model, nullptr);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  ASSERT_NE(new_interpreter, nullptr);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);

  // Check shape signature in new interpreter. The pointer checks are fatal
  // because the view below reads through dims_signature.
  TfLiteTensor* tensor0 = new_interpreter->tensor(0);
  ASSERT_NE(tensor0, nullptr);
  ASSERT_NE(tensor0->dims_signature, nullptr);
  TfLiteIntArrayView shape_view(tensor0->dims_signature);
  ASSERT_EQ(shape_view.size(), 2);
  EXPECT_EQ(shape_view[0], -1);
}
396
397 INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());
398
399 struct ReshapeTestPattern {
400 int num_inputs;
401 bool is_param_valid;
402 bool has_buggy_non_flatten_shape;
403 };
404
405 class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
406
TEST_P(ReshapeLayerTest,ReshapeLayerTest)407 TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
408 const auto param = GetParam();
409 Interpreter interpreter;
410 const int total_tensors = param.num_inputs + 1;
411 interpreter.AddTensors(total_tensors);
412 int output_shape[] = {1, 2, 3};
413 interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
414 /*name=*/"a", /*dims=*/{6},
415 TfLiteQuantization());
416 ASSERT_LE(param.num_inputs, 2);
417 if (param.num_inputs == 2) {
418 // Some TOCO generated models have buggy shape arguments, which are required
419 // to be flatten, for example, dims={3, 1} instead of dims={3}.
420 if (param.has_buggy_non_flatten_shape) {
421 interpreter.SetTensorParametersReadOnly(
422 /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1},
423 TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
424 sizeof(output_shape));
425 } else {
426 interpreter.SetTensorParametersReadOnly(
427 /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
428 TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
429 sizeof(output_shape));
430 }
431 }
432 interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
433 kTfLiteFloat32, /*name=*/"c",
434 /*dims=*/{3}, TfLiteQuantization());
435
436 std::vector<int> input_tensors(param.num_inputs);
437 std::iota(input_tensors.begin(), input_tensors.end(), 0);
438
439 interpreter.SetInputs(input_tensors);
440 interpreter.SetOutputs({total_tensors - 1});
441 const char* initial_data = "";
442 tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
443 TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
444 malloc(sizeof(TfLiteReshapeParams)));
445 memset(builtin_data, 0, sizeof(TfLiteReshapeParams));
446 if (param.is_param_valid) {
447 builtin_data->num_dimensions = 3;
448 for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
449 builtin_data->shape[dim] = output_shape[dim];
450 }
451 }
452 const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
453 interpreter.AddNodeWithParameters(input_tensors,
454 /*outputs=*/{total_tensors - 1},
455 initial_data, /*init_data_size=*/0,
456 reinterpret_cast<void*>(builtin_data), reg);
457
458 SubgraphWriter writer(&interpreter.primary_subgraph());
459 std::stringstream ss;
460 ss << CreateFilePath("test_reshape_") << param.num_inputs
461 << param.is_param_valid << ".tflite";
462 std::string filename = ss.str();
463 writer.Write(filename);
464 std::unique_ptr<FlatBufferModel> model =
465 FlatBufferModel::BuildFromFile(filename.c_str());
466 InterpreterBuilder builder(*model, resolver);
467 std::unique_ptr<Interpreter> new_interpreter;
468 builder(&new_interpreter);
469 ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
470 }
471
472 INSTANTIATE_TEST_SUITE_P(
473 Writer, ReshapeLayerTest,
474 ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
475 /*is_param_valid=*/true,
476 /*has_buggy_non_flatten_shape=*/false},
477 ReshapeTestPattern{/*num_inputs=*/2,
478 /*is_param_valid=*/false,
479 /*has_buggy_non_flatten_shape=*/false},
480 ReshapeTestPattern{/*num_inputs=*/1,
481 /*is_param_valid=*/true,
482 /*has_buggy_non_flatten_shape=*/false},
483 ReshapeTestPattern{/*num_inputs=*/2,
484 /*is_param_valid=*/true,
485 /*has_buggy_non_flatten_shape=*/true}),
__anond14eaafb0102(const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) 486 [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
487 std::stringstream ss;
488 ss << "num_inputs_" << info.param.num_inputs << "_valid_param_"
489 << info.param.is_param_valid << "_buggy_shape_"
490 << info.param.has_buggy_non_flatten_shape;
491 std::string name = ss.str();
492 return name;
493 });
494
495 class WhileTest : public subgraph_test_util::ControlFlowOpTest {
496 protected:
NewCustomAlloc(size_t num_bytes,int required_alignment)497 TfLiteCustomAllocation NewCustomAlloc(size_t num_bytes,
498 int required_alignment) {
499 // Extra memory to ensure alignment.
500 char* new_alloc = new char[num_bytes + required_alignment];
501 char* new_underlying_buffer_aligned_ptr = reinterpret_cast<char*>(
502 AlignTo(required_alignment, reinterpret_cast<intptr_t>(new_alloc)));
503 custom_alloc_buffers_.emplace_back(new_alloc);
504
505 return TfLiteCustomAllocation(
506 {new_underlying_buffer_aligned_ptr, num_bytes});
507 }
508
AlignTo(size_t alignment,intptr_t offset)509 intptr_t AlignTo(size_t alignment, intptr_t offset) {
510 return offset % alignment == 0 ? offset
511 : offset + (alignment - offset % alignment);
512 }
513
514 std::vector<std::unique_ptr<char[]>> custom_alloc_buffers_;
515 };
516
517 // The test builds a model that produces the i-th number of
518 // triangular number sequence: 1, 3, 6, 10, 15, 21, 28.
TEST_F(WhileTest,TestTriangularNumberSequence)519 TEST_F(WhileTest, TestTriangularNumberSequence) {
520 const int kSeqNumber = 4;
521 const int kExpectedValue = 15;
522
523 interpreter_ = std::make_unique<Interpreter>();
524 AddSubgraphs(2);
525 builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
526 builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
527 builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
528
529 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
530 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
531 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
532 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
533
534 // Use custom allocation for second input, to ensure things work well for
535 // non-traditional allocation types.
536 auto alloc =
537 NewCustomAlloc(interpreter_->tensor(interpreter_->inputs()[1])->bytes,
538 kDefaultTensorAlignment);
539 auto* input_data = reinterpret_cast<int*>(alloc.data);
540 input_data[0] = 1;
541 interpreter_->SetCustomAllocationForTensor(interpreter_->inputs()[1], alloc);
542
543 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
544 TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
545 CheckIntTensor(output1, {1}, {kSeqNumber + 1});
546 TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
547 CheckIntTensor(output2, {1}, {kExpectedValue});
548
549 // Now serialize & deserialize model into a new Interpreter.
550 ModelWriter writer(interpreter_.get());
551 const std::string test_file = CreateFilePath("test_while.tflite");
552 writer.Write(test_file);
553 std::unique_ptr<FlatBufferModel> model =
554 FlatBufferModel::BuildFromFile(test_file.c_str());
555 tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
556 InterpreterBuilder builder(*model, resolver);
557 std::unique_ptr<Interpreter> new_interpreter;
558 builder(&new_interpreter);
559
560 // Check deserialized model.
561 new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
562 new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
563 ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
564 FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
565 FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
566 ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
567 output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
568 CheckIntTensor(output1, {1}, {kSeqNumber + 1});
569 output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
570 CheckIntTensor(output2, {1}, {kExpectedValue});
571 }
572
573 // Verifies the ModelWriters constructing from an interpreter or subgraphs
574 // produce the same results.
TEST_F(WhileTest,TestModelWriterFromSubgraphs)575 TEST_F(WhileTest, TestModelWriterFromSubgraphs) {
576 const int kSeqNumber = 4;
577 const int kExpectedValue = 15;
578
579 interpreter_ = std::make_unique<Interpreter>();
580 AddSubgraphs(2);
581 builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
582 builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
583 builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());
584
585 interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
586 interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
587 ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
588 FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});
589
590 // Use custom allocation for second input, to ensure things work well for
591 // non-traditional allocation types.
592 auto alloc =
593 NewCustomAlloc(interpreter_->tensor(interpreter_->inputs()[1])->bytes,
594 kDefaultTensorAlignment);
595 auto* input_data = reinterpret_cast<int*>(alloc.data);
596 input_data[0] = 1;
597 interpreter_->SetCustomAllocationForTensor(interpreter_->inputs()[1], alloc);
598
599 ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
600 TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
601 CheckIntTensor(output1, {1}, {kSeqNumber + 1});
602 TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
603 CheckIntTensor(output2, {1}, {kExpectedValue});
604
605 // Serializes the model using the interpreter.
606 ModelWriter writer_1(interpreter_.get());
607 const std::string test_file_1 = CreateFilePath("test_while_1.tflite");
608 writer_1.Write(test_file_1);
609
610 // Serializes the model using subgraphs.
611 std::vector<Subgraph*> subgraphs;
612 for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
613 subgraphs.push_back(interpreter_->subgraph(i));
614 }
615 ModelWriter writer_2(subgraphs);
616 const std::string test_file_2 = CreateFilePath("test_while_2.tflite");
617 writer_2.Write(test_file_2);
618
619 // The results from both ModelWriters should be the same.
620 std::ifstream file_ifs_1(test_file_1, std::ios::in);
621 std::ostringstream model_content_1;
622 model_content_1 << file_ifs_1.rdbuf();
623
624 std::ifstream file_ifs_2(test_file_2, std::ios::in);
625 std::ostringstream model_content_2;
626 model_content_2 << file_ifs_2.rdbuf();
627
628 EXPECT_FALSE(model_content_1.str().empty());
629 EXPECT_EQ(model_content_1.str(), model_content_2.str());
630 }
631
632 } // namespace tflite
633