xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/subgraph_test_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This module provides helper functions for testing the interaction between
17 // control flow ops and subgraphs.
18 // For convenience, we mostly only use `kTfLiteInt32` in this module.
19 
20 #ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
21 #define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
22 
23 #include <stdint.h>
24 
25 #include <memory>
26 #include <string>
27 #include <vector>
28 
29 #include <gtest/gtest.h>
30 #include "tensorflow/lite/core/subgraph.h"
31 #include "tensorflow/lite/interpreter.h"
32 #include "tensorflow/lite/interpreter_test_util.h"
33 
34 namespace tflite {
35 namespace subgraph_test_util {
36 
37 class SubgraphBuilder {
38  public:
39   ~SubgraphBuilder();
40 
41   // Build a subgraph with a single Add op.
42   // 2 inputs. 1 output.
43   void BuildAddSubgraph(Subgraph* subgraph);
44 
45   // Build a subgraph with a single Mul op.
46   // 2 inputs. 1 output.
47   void BuildMulSubgraph(Subgraph* subgraph);
48 
49   // Build a subgraph with a single Pad op.
50   // 2 inputs. 1 output.
51   void BuildPadSubgraph(Subgraph* subgraph);
52 
53   // Build a subgraph with a single If op.
54   // 3 inputs:
55   //   The 1st input is condition with boolean type.
56   //   The 2nd and 3rd inputs are feed input the branch subgraphs.
57   // 1 output.
58   void BuildIfSubgraph(Subgraph* subgraph);
59 
60   // Build a subgraph with a single Less op.
61   // The subgraph is used as the condition subgraph for testing `While` op.
62   // 2 inputs:
63   //   The 1st input is a counter with `kTfLiteInt32` type.
64   //   The 2nd input is ignored in this subgraph.
65   // 1 output with `kTfLiteBool` type.
66   //   Equivalent to (input < rhs).
67   void BuildLessEqualCondSubgraph(Subgraph* subgraph, int rhs);
68 
69   // An accumulate loop body subgraph. Used to produce triangle number
70   // sequence. 2 inputs and 2 outputs
71   //   Equivalent to (counter, value) -> (counter + 1, counter + 1 + value)
72   void BuildAccumulateLoopBodySubgraph(Subgraph* subgraph);
73 
74   // A pad loop body subgraph. When used in a loop it will repeatively enlarge
75   // the
76   //   tensor.
77   // 2 inputs and 2 outputs.
78   //   Equivalent to (counter, value) -> (counter + 1, tf.pad(value, padding))
79   // Note the padding is created as a constant tensor.
80   void BuildPadLoopBodySubgraph(Subgraph* subgraph,
81                                 const std::vector<int> padding);
82 
83   // Build a subgraph with a single While op.
84   // 2 inputs, 2 outputs.
85   void BuildWhileSubgraph(Subgraph* subgraph);
86 
87   // Build a subgraph that assigns a random value to a variable.
88   // No input/output.
89   void BuildAssignRandomValueToVariableSubgraph(Subgraph* graph);
90 
91   // Build a subgraph with CallOnce op and ReadVariable op.
92   // No input and 1 output.
93   void BuildCallOnceAndReadVariableSubgraph(Subgraph* graph);
94 
95   // Build a subgraph with CallOnce op, ReadVariable op and Add op.
96   // No input and 1 output.
97   void BuildCallOnceAndReadVariablePlusOneSubgraph(Subgraph* graph);
98 
99   // Build a subgraph with a single Less op.
100   // The subgraph is used as the condition subgraph for testing `While` op.
101   // 3 inputs:
102   //   The 1st and 2nd inputs are string tensors, which will be ignored.
103   //   The 3rd input is an integner value as a counter in this subgraph.
104   // 1 output with `kTfLiteBool` type.
105   //   Equivalent to (int_val < rhs).
106   void BuildLessEqualCondSubgraphWithDynamicTensor(Subgraph* subgraph, int rhs);
107 
108   // Build a subgraph with a single While op, which has 3 inputs and 3 outputs.
109   // This subgraph is used for creating/invoking dynamic allocated tensors based
110   // on string tensors.
111   //   Equivalent to (str1, str2, int_val) ->
112   //                 (str1, Fill(str1, int_val + 1), int_val + 1).
113   void BuildBodySubgraphWithDynamicTensor(Subgraph* subgraph);
114 
115   // Build a subgraph with a single While op, that contains 3 inputs and 3
116   // outputs (str1, str2, int_val).
117   void BuildWhileSubgraphWithDynamicTensor(Subgraph* subgraph);
118 
119  private:
120   void CreateConstantInt32Tensor(Subgraph* subgraph, int tensor_index,
121                                  const std::vector<int>& shape,
122                                  const std::vector<int>& data);
123   std::vector<void*> buffers_;
124 };
125 
126 class ControlFlowOpTest : public InterpreterTest {
127  public:
ControlFlowOpTest()128   ControlFlowOpTest() : builder_(new SubgraphBuilder) {}
129 
~ControlFlowOpTest()130   ~ControlFlowOpTest() override {
131     builder_.reset();
132   }
133 
134  protected:
135   std::unique_ptr<SubgraphBuilder> builder_;
136 };
137 
138 // Fill a `TfLiteTensor` with a 32-bits integer vector.
139 // Preconditions:
140 // * The tensor must have `kTfLiteInt32` type.
141 // * The tensor must be allocated.
142 // * The element count of the tensor must be equal to the length or
143 //   the vector.
144 void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data);
145 
146 // Fill a `TfLiteTensor` with a string value.
147 // Preconditions:
148 // * The tensor must have `kTfLitString` type.
149 void FillScalarStringTensor(TfLiteTensor* tensor, const std::string& data);
150 
151 // Check if the scalar string data of a tensor is as expected.
152 void CheckScalarStringTensor(const TfLiteTensor* tensor,
153                              const std::string& data);
154 
155 // Check if the shape and string data of a tensor is as expected.
156 void CheckStringTensor(const TfLiteTensor* tensor,
157                        const std::vector<int>& shape,
158                        const std::vector<std::string>& data);
159 
160 // Check if the shape and int32 data of a tensor is as expected.
161 void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
162                     const std::vector<int32_t>& data);
163 // Check if the shape and bool data of a tensor is as expected.
164 void CheckBoolTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
165                      const std::vector<bool>& data);
166 
167 }  // namespace subgraph_test_util
168 }  // namespace tflite
169 
170 #endif  // TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
171