xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/serialization/writer_lib_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/serialization/writer_lib.h"
17 
18 #include <cstdlib>
19 #include <fstream>
20 #include <memory>
21 #include <numeric>
22 #include <sstream>
23 #include <string>
24 #include <tuple>
25 #include <vector>
26 
27 #include <gmock/gmock.h>
28 #include <gtest/gtest.h>
29 #include "tensorflow/lite/c/builtin_op_data.h"
30 #include "tensorflow/lite/c/c_api_types.h"
31 #include "tensorflow/lite/c/common.h"
32 #include "tensorflow/lite/interpreter.h"
33 #include "tensorflow/lite/kernels/register.h"
34 #include "tensorflow/lite/kernels/subgraph_test_util.h"
35 #include "tensorflow/lite/model.h"
36 #include "tensorflow/lite/schema/schema_generated.h"
37 #include "tensorflow/lite/testing/util.h"
38 
39 namespace tflite {
40 
41 using subgraph_test_util::CheckIntTensor;
42 using subgraph_test_util::FillIntTensor;
43 
// Returns a path for `file_name` inside the test temporary directory
// (TEST_TMPDIR), falling back to the current directory when the variable is
// unset. A '/' separator is inserted explicitly: TEST_TMPDIR normally has no
// trailing slash, so plain concatenation would create a sibling file with a
// mangled name instead of a file inside the directory.
std::string CreateFilePath(const std::string& file_name) {
  const char* tmp_dir = getenv("TEST_TMPDIR");
  return std::string(tmp_dir ? tmp_dir : ".") + "/" + file_name;
}
48 
49 // The bool param indicates whether we use SubgraphWriter(true) or
50 // ModelWriter(false) for the test
51 class SingleSubgraphTest : public ::testing::TestWithParam<bool> {
52  protected:
WriteToFile(Interpreter * interpreter,const std::string & filename,bool use_subgraph_writer)53   void WriteToFile(Interpreter* interpreter, const std::string& filename,
54                    bool use_subgraph_writer) {
55     if (use_subgraph_writer) {
56       SubgraphWriter writer(&interpreter->primary_subgraph());
57       CHECK_EQ(writer.Write(filename), kTfLiteOk);
58     } else {
59       ModelWriter writer(interpreter);
60       CHECK_EQ(writer.Write(filename), kTfLiteOk);
61     }
62   }
63 };
64 
// Verifies that both writers reject an empty destination filename and a null
// output-buffer pointer instead of crashing.
TEST_P(SingleSubgraphTest, InvalidDestinations) {
  // Build a minimal single-op ADD model: inputs {0, 1} -> output {2}, with
  // tensor 1 a read-only constant backed by `foo`.
  Interpreter interpreter;
  interpreter.AddTensors(3);
  float foo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<char*>(foo), sizeof(foo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  // malloc'd params are handed to AddNodeWithParameters; presumably the
  // interpreter takes ownership (same pattern throughout this file).
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  // Check if invalid filename is handled gracefully.
  if (GetParam()) {
    SubgraphWriter writer(&interpreter.primary_subgraph());
    CHECK_EQ(writer.Write(""), kTfLiteError);
  } else {
    ModelWriter writer(&interpreter);
    CHECK_EQ(writer.Write(""), kTfLiteError);
  }

  // Check if invalid buffer is handled gracefully.
  size_t size;
  if (GetParam()) {
    SubgraphWriter writer(&interpreter.primary_subgraph());
    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
  } else {
    ModelWriter writer(&interpreter);
    CHECK_EQ(writer.GetBuffer(nullptr, &size), kTfLiteError);
  }
}
107 
// Builds a float ADD model (inputs {0, 1} -> output {2}, tensor 1 constant),
// serializes it with SubgraphWriter or ModelWriter per the test param, then
// re-imports the flatbuffer and allocates tensors.
TEST_P(SingleSubgraphTest, FloatModelTest) {
  Interpreter interpreter;
  interpreter.AddTensors(3);
  float foo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<char*>(foo), sizeof(foo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  const std::string test_file = CreateFilePath("test_float.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  // Guard against dereferencing a null model if the serialized file could not
  // be read back, and surface InterpreterBuilder failures explicitly.
  ASSERT_TRUE(model);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
140 
141 // Tests writing only a portion of the subgraph.
TEST_P(SingleSubgraphTest, CustomInputOutputTest) {
  // Graph: ADD(0, 1) -> 2, RELU(2) -> 3. Only the RELU node is serialized,
  // with tensor 2 as the custom input and tensor 3 as the output.
  Interpreter interpreter;
  interpreter.AddTensors(4);
  constexpr float kFoo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({3});

  // Add two ops: Add and Relu
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
  interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);

  // Only write the second op.
  const std::string test_file = CreateFilePath("test_custom.tflite");
  SubgraphWriter writer(&interpreter.primary_subgraph());
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                        /*execution_plan=*/{1}),
            kTfLiteOk);
  writer.SetUnusedTensors({0, 1});
  // Fail fast if serialization fails rather than dereferencing a null model
  // below.
  ASSERT_EQ(writer.Write(test_file), kTfLiteOk);

  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  ASSERT_TRUE(model);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
188 
// Verifies that SetCustomInputOutput rejects input/output selections that are
// inconsistent with the requested execution plan.
TEST_P(SingleSubgraphTest, CustomInputOutputErrorCasesTest) {
  Interpreter interpreter;
  interpreter.AddTensors(5);
  constexpr float kFoo[] = {1, 2, 3};
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadOnly(
      1, kTfLiteFloat32, "b", {3}, TfLiteQuantization(),
      reinterpret_cast<const char*>(kFoo), sizeof(kFoo));
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(3, kTfLiteFloat32, "d", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(4, kTfLiteFloat32, "e", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({4});

  // Add three ops: ADD(0, 1) -> 2, RELU(2) -> 3, RELU6(3) -> 4.
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  const TfLiteRegistration* reg2 = resolver.FindOp(BuiltinOperator_RELU, 1);
  interpreter.AddNodeWithParameters({2}, {3}, nullptr, 0, nullptr, reg2);

  const TfLiteRegistration* reg3 = resolver.FindOp(BuiltinOperator_RELU6, 1);
  interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, reg3);

  SubgraphWriter writer(&interpreter.primary_subgraph());

  // Test wrong input: tensor 2 is produced by node 0, which is also in the
  // execution plan, so it cannot double as a subgraph input.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{2}, /*outputs=*/{3},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteError);
  // Test wrong output: tensor 4 is only produced by node 2, which is absent
  // from the execution plan {0, 1}.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{4},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteError);
  // Test a valid case.
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0, 1}, /*outputs=*/{3},
                                        /*execution_plan=*/{0, 1}),
            kTfLiteOk);
}
239 
240 // Tests if SetCustomInputOutput handles variable tensors correctly.
TEST_P(SingleSubgraphTest, CustomInputOutputVariableTensorTest) {
  // Graph: ADD(input 0, variable tensor 1) -> output 2. Exercises
  // SetCustomInputOutput on a subgraph containing a variable tensor.
  Interpreter interpreter;
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;

  // Create tensors.
  interpreter.AddTensors(3);
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "b", {3},
                                           TfLiteQuantization(),
                                           /*is_variable=*/true);
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
                                           TfLiteQuantization());
  interpreter.SetInputs({0});
  interpreter.SetOutputs({2});
  interpreter.SetVariables({1});

  // Create an Add node.
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  interpreter.AddNodeWithParameters({0, 1}, {2}, nullptr, 0,
                                    reinterpret_cast<void*>(builtin_data),
                                    resolver.FindOp(BuiltinOperator_ADD, 1));

  // Write model to file, failing fast if serialization fails.
  const std::string test_file = CreateFilePath("test_variables.tflite");
  SubgraphWriter writer(&interpreter.primary_subgraph());
  EXPECT_EQ(writer.SetCustomInputOutput(/*inputs=*/{0}, /*outputs=*/{2},
                                        /*execution_plan=*/{0}),
            kTfLiteOk);
  ASSERT_EQ(writer.Write(test_file), kTfLiteOk);

  // Read model and test; guard against a null model before dereferencing.
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  ASSERT_TRUE(model);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
283 
// Round-trips a uint8 per-tensor-quantized ADD model through the writer.
TEST_P(SingleSubgraphTest, PerTensorQuantizedModelTest) {
  Interpreter interpreter;
  interpreter.AddTensors(3);
  // All three tensors share the same asymmetric quantization:
  // scale 1/256, zero point 128.
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteUInt8, "a", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetTensorParametersReadWrite(
      1, kTfLiteUInt8, "b", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetTensorParametersReadWrite(
      2, kTfLiteUInt8, "c", {3}, TfLiteQuantizationParams({1 / 256., 128}));
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  // Serialize (writer type per test param), then re-import and allocate.
  const std::string test_file = CreateFilePath("test_uint8.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  builder(&new_interpreter);
  CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
}
314 
// Checks that an op registered at a non-default version (BROADCAST_TO v2)
// round-trips through serialization with its version preserved.
TEST_P(SingleSubgraphTest, OpVersioningTest) {
  Interpreter interpreter;
  interpreter.AddTensors(3);
  interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {1, 4},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(1, kTfLiteInt32, "b", {2},
                                           TfLiteQuantization());
  interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {4, 4},
                                           TfLiteQuantization());
  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});

  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  const TfLiteRegistration* reg =
      resolver.FindOp(BuiltinOperator_BROADCAST_TO, 2);
  interpreter.AddNodeWithParameters(/*inputs=*/{0, 1}, /*outputs=*/{2},
                                    /*init_data=*/nullptr, /*init_data_size=*/0,
                                    /*builtin_data=*/nullptr, reg);

  // Use a file name distinct from FloatModelTest's ("test_float.tflite") to
  // avoid cross-test interference when tests share TEST_TMPDIR.
  const std::string test_file = CreateFilePath("test_op_versioning.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  // Guard against dereferencing a null model if the round-trip failed.
  ASSERT_TRUE(model);
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  ASSERT_EQ(builder(&new_interpreter), kTfLiteOk);
  CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);

  // The deserialized op must keep builtin code BROADCAST_TO at version 2.
  ASSERT_EQ(new_interpreter->nodes_size(), 1);
  TfLiteRegistration output_reg =
      new_interpreter->node_and_registration(0)->second;
  ASSERT_EQ(output_reg.builtin_code, BuiltinOperator_BROADCAST_TO);
  CHECK_EQ(output_reg.version, 2);
}
349 
// Verifies that dynamic shape signatures (-1 dimensions) survive
// serialization and deserialization.
TEST_P(SingleSubgraphTest, DynamicShapeTest) {
  // Build a model with a single Add op.
  Interpreter interpreter;

  // Each tensor has concrete dims {1, 3} but a shape signature of {-1, 3},
  // i.e. a dynamic leading dimension.
  interpreter.AddTensors(3);
  std::vector<int> dims = {1, 3};
  std::vector<int> dims_signature = {-1, 3};
  interpreter.SetTensorParametersReadWrite(
      0, kTfLiteFloat32, "a", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);
  interpreter.SetTensorParametersReadWrite(
      1, kTfLiteFloat32, "b", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);
  interpreter.SetTensorParametersReadWrite(
      2, kTfLiteFloat32, "c", dims, TfLiteQuantizationParams{1.0, 0},
      /*is_variable=*/false, &dims_signature);

  interpreter.SetInputs({0, 1});
  interpreter.SetOutputs({2});
  const char* initial_data = "";
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  TfLiteAddParams* builtin_data =
      reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
  builtin_data->activation = kTfLiteActNone;
  builtin_data->pot_scale_int16 = false;
  const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
  interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
                                    reinterpret_cast<void*>(builtin_data), reg);

  // Export interpreter and import back.
  const std::string test_file = CreateFilePath("test_dynamic_shape.tflite");
  WriteToFile(&interpreter, test_file, GetParam());
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  builder(&new_interpreter);
  CHECK_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);

  // Check shape signature in new interpreter: tensor 0 must still carry the
  // {-1, 3} signature written above.
  TfLiteTensor* tensor0 = new_interpreter->tensor(0);
  CHECK_NOTNULL(tensor0->dims_signature);
  TfLiteIntArrayView shape_view(tensor0->dims_signature);
  CHECK_EQ(shape_view.size(), 2);
  CHECK_EQ(shape_view[0], -1);
}
396 
INSTANTIATE_TEST_SUITE_P(Writer, SingleSubgraphTest, ::testing::Bool());

// Parameterization for the Reshape-layer serialization test below.
struct ReshapeTestPattern {
  // Number of input tensors wired to the RESHAPE op (1 or 2).
  int num_inputs;
  // Whether the builtin TfLiteReshapeParams carry a valid target shape.
  bool is_param_valid;
  // Whether the shape tensor uses non-flattened dims ({3, 1} instead of {3}),
  // emulating buggy TOCO-generated models.
  bool has_buggy_non_flatten_shape;
};

class ReshapeLayerTest : public ::testing::TestWithParam<ReshapeTestPattern> {};
406 
TEST_P(ReshapeLayerTest,ReshapeLayerTest)407 TEST_P(ReshapeLayerTest, ReshapeLayerTest) {
408   const auto param = GetParam();
409   Interpreter interpreter;
410   const int total_tensors = param.num_inputs + 1;
411   interpreter.AddTensors(total_tensors);
412   int output_shape[] = {1, 2, 3};
413   interpreter.SetTensorParametersReadWrite(/*tensor_index=*/0, kTfLiteFloat32,
414                                            /*name=*/"a", /*dims=*/{6},
415                                            TfLiteQuantization());
416   ASSERT_LE(param.num_inputs, 2);
417   if (param.num_inputs == 2) {
418     // Some TOCO generated models have buggy shape arguments, which are required
419     // to be flatten, for example, dims={3, 1} instead of dims={3}.
420     if (param.has_buggy_non_flatten_shape) {
421       interpreter.SetTensorParametersReadOnly(
422           /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3, 1},
423           TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
424           sizeof(output_shape));
425     } else {
426       interpreter.SetTensorParametersReadOnly(
427           /*tensor_index=*/1, kTfLiteInt32, /*name=*/"b", /*dims=*/{3},
428           TfLiteQuantization(), reinterpret_cast<char*>(output_shape),
429           sizeof(output_shape));
430     }
431   }
432   interpreter.SetTensorParametersReadWrite(/*tensor_index=*/total_tensors - 1,
433                                            kTfLiteFloat32, /*name=*/"c",
434                                            /*dims=*/{3}, TfLiteQuantization());
435 
436   std::vector<int> input_tensors(param.num_inputs);
437   std::iota(input_tensors.begin(), input_tensors.end(), 0);
438 
439   interpreter.SetInputs(input_tensors);
440   interpreter.SetOutputs({total_tensors - 1});
441   const char* initial_data = "";
442   tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
443   TfLiteReshapeParams* builtin_data = reinterpret_cast<TfLiteReshapeParams*>(
444       malloc(sizeof(TfLiteReshapeParams)));
445   memset(builtin_data, 0, sizeof(TfLiteReshapeParams));
446   if (param.is_param_valid) {
447     builtin_data->num_dimensions = 3;
448     for (int dim = 0; dim < builtin_data->num_dimensions; ++dim) {
449       builtin_data->shape[dim] = output_shape[dim];
450     }
451   }
452   const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_RESHAPE, 1);
453   interpreter.AddNodeWithParameters(input_tensors,
454                                     /*outputs=*/{total_tensors - 1},
455                                     initial_data, /*init_data_size=*/0,
456                                     reinterpret_cast<void*>(builtin_data), reg);
457 
458   SubgraphWriter writer(&interpreter.primary_subgraph());
459   std::stringstream ss;
460   ss << CreateFilePath("test_reshape_") << param.num_inputs
461      << param.is_param_valid << ".tflite";
462   std::string filename = ss.str();
463   writer.Write(filename);
464   std::unique_ptr<FlatBufferModel> model =
465       FlatBufferModel::BuildFromFile(filename.c_str());
466   InterpreterBuilder builder(*model, resolver);
467   std::unique_ptr<Interpreter> new_interpreter;
468   builder(&new_interpreter);
469   ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
470 }
471 
// Instantiates ReshapeLayerTest over combinations of input count, builtin
// param validity, and the buggy non-flattened shape case; each instance is
// named after its parameter values.
INSTANTIATE_TEST_SUITE_P(
    Writer, ReshapeLayerTest,
    ::testing::Values(ReshapeTestPattern{/*num_inputs=*/2,
                                         /*is_param_valid=*/true,
                                         /*has_buggy_non_flatten_shape=*/false},
                      ReshapeTestPattern{/*num_inputs=*/2,
                                         /*is_param_valid=*/false,
                                         /*has_buggy_non_flatten_shape=*/false},
                      ReshapeTestPattern{/*num_inputs=*/1,
                                         /*is_param_valid=*/true,
                                         /*has_buggy_non_flatten_shape=*/false},
                      ReshapeTestPattern{/*num_inputs=*/2,
                                         /*is_param_valid=*/true,
                                         /*has_buggy_non_flatten_shape=*/true}),
    [](const ::testing::TestParamInfo<ReshapeLayerTest::ParamType>& info) {
      std::stringstream ss;
      ss << "num_inputs_" << info.param.num_inputs << "_valid_param_"
         << info.param.is_param_valid << "_buggy_shape_"
         << info.param.has_buggy_non_flatten_shape;
      std::string name = ss.str();
      return name;
    });
494 
495 class WhileTest : public subgraph_test_util::ControlFlowOpTest {
496  protected:
NewCustomAlloc(size_t num_bytes,int required_alignment)497   TfLiteCustomAllocation NewCustomAlloc(size_t num_bytes,
498                                         int required_alignment) {
499     // Extra memory to ensure alignment.
500     char* new_alloc = new char[num_bytes + required_alignment];
501     char* new_underlying_buffer_aligned_ptr = reinterpret_cast<char*>(
502         AlignTo(required_alignment, reinterpret_cast<intptr_t>(new_alloc)));
503     custom_alloc_buffers_.emplace_back(new_alloc);
504 
505     return TfLiteCustomAllocation(
506         {new_underlying_buffer_aligned_ptr, num_bytes});
507   }
508 
AlignTo(size_t alignment,intptr_t offset)509   intptr_t AlignTo(size_t alignment, intptr_t offset) {
510     return offset % alignment == 0 ? offset
511                                    : offset + (alignment - offset % alignment);
512   }
513 
514   std::vector<std::unique_ptr<char[]>> custom_alloc_buffers_;
515 };
516 
517 // The test builds a model that produces the i-th number of
518 // triangular number sequence: 1, 3, 6, 10, 15, 21, 28.
TEST_F(WhileTest, TestTriangularNumberSequence) {
  const int kSeqNumber = 4;
  const int kExpectedValue = 15;

  // Build a WHILE model: subgraph 1 is the (i <= kSeqNumber) condition,
  // subgraph 2 the accumulating loop body, and the primary subgraph hosts the
  // WHILE op itself.
  interpreter_ = std::make_unique<Interpreter>();
  AddSubgraphs(2);
  builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
  builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
  builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());

  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});

  // Use custom allocation for second input, to ensure things work well for
  // non-traditional allocation types.
  auto alloc =
      NewCustomAlloc(interpreter_->tensor(interpreter_->inputs()[1])->bytes,
                     kDefaultTensorAlignment);
  auto* input_data = reinterpret_cast<int*>(alloc.data);
  input_data[0] = 1;
  interpreter_->SetCustomAllocationForTensor(interpreter_->inputs()[1], alloc);

  // Run the original model and verify both outputs: loop counter and
  // accumulated (triangular-number) value.
  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
  CheckIntTensor(output2, {1}, {kExpectedValue});

  // Now serialize & deserialize model into a new Interpreter.
  ModelWriter writer(interpreter_.get());
  const std::string test_file = CreateFilePath("test_while.tflite");
  writer.Write(test_file);
  std::unique_ptr<FlatBufferModel> model =
      FlatBufferModel::BuildFromFile(test_file.c_str());
  tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates resolver;
  InterpreterBuilder builder(*model, resolver);
  std::unique_ptr<Interpreter> new_interpreter;
  builder(&new_interpreter);

  // Check deserialized model.
  // NOTE(review): input/output indices from the original interpreter_ are
  // reused on new_interpreter — assumes serialization preserves tensor
  // ordering; confirm against writer_lib's guarantees.
  new_interpreter->ResizeInputTensor(interpreter_->inputs()[0], {1});
  new_interpreter->ResizeInputTensor(interpreter_->inputs()[1], {1});
  ASSERT_EQ(new_interpreter->AllocateTensors(), kTfLiteOk);
  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[0]), {1});
  FillIntTensor(new_interpreter->tensor(interpreter_->inputs()[1]), {1});
  ASSERT_EQ(new_interpreter->Invoke(), kTfLiteOk);
  output1 = new_interpreter->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
  output2 = new_interpreter->tensor(interpreter_->outputs()[1]);
  CheckIntTensor(output2, {1}, {kExpectedValue});
}
572 
573 // Verifies the ModelWriters constructing from  an interpreter or subgraphs
574 // produce the same results.
TEST_F(WhileTest, TestModelWriterFromSubgraphs) {
  const int kSeqNumber = 4;
  const int kExpectedValue = 15;

  // Build and run the same triangular-number WHILE model as in
  // TestTriangularNumberSequence above.
  interpreter_ = std::make_unique<Interpreter>();
  AddSubgraphs(2);
  builder_->BuildLessEqualCondSubgraph(interpreter_->subgraph(1), kSeqNumber);
  builder_->BuildAccumulateLoopBodySubgraph(interpreter_->subgraph(2));
  builder_->BuildWhileSubgraph(&interpreter_->primary_subgraph());

  interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
  interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {1});
  ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
  FillIntTensor(interpreter_->tensor(interpreter_->inputs()[0]), {1});

  // Use custom allocation for second input, to ensure things work well for
  // non-traditional allocation types.
  auto alloc =
      NewCustomAlloc(interpreter_->tensor(interpreter_->inputs()[1])->bytes,
                     kDefaultTensorAlignment);
  auto* input_data = reinterpret_cast<int*>(alloc.data);
  input_data[0] = 1;
  interpreter_->SetCustomAllocationForTensor(interpreter_->inputs()[1], alloc);

  ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
  TfLiteTensor* output1 = interpreter_->tensor(interpreter_->outputs()[0]);
  CheckIntTensor(output1, {1}, {kSeqNumber + 1});
  TfLiteTensor* output2 = interpreter_->tensor(interpreter_->outputs()[1]);
  CheckIntTensor(output2, {1}, {kExpectedValue});

  // Serializes the model using the interpreter.
  ModelWriter writer_1(interpreter_.get());
  const std::string test_file_1 = CreateFilePath("test_while_1.tflite");
  writer_1.Write(test_file_1);

  // Serializes the model using subgraphs.
  std::vector<Subgraph*> subgraphs;
  for (int i = 0; i < interpreter_->subgraphs_size(); ++i) {
    subgraphs.push_back(interpreter_->subgraph(i));
  }
  ModelWriter writer_2(subgraphs);
  const std::string test_file_2 = CreateFilePath("test_while_2.tflite");
  writer_2.Write(test_file_2);

  // The results from both ModelWriters should be the same: compare the two
  // serialized files byte for byte.
  std::ifstream file_ifs_1(test_file_1, std::ios::in);
  std::ostringstream model_content_1;
  model_content_1 << file_ifs_1.rdbuf();

  std::ifstream file_ifs_2(test_file_2, std::ios::in);
  std::ostringstream model_content_2;
  model_content_2 << file_ifs_2.rdbuf();

  // A non-empty file guards against both writers having failed identically.
  EXPECT_FALSE(model_content_1.str().empty());
  EXPECT_EQ(model_content_1.str(), model_content_2.str());
}
631 
632 }  // namespace tflite
633