xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/concatenation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/concatenation.h"
16 
17 #include <stdint.h>
18 #include <limits>
19 
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/compatibility.h"
23 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
24 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
25 #include "tensorflow/lite/kernels/internal/tensor.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/internal/types.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 
30 namespace tflite {
31 namespace ops {
32 namespace builtin {
33 namespace concatenation {
34 
// This file has two implementations of Concatenation.
enum KernelType {
  kReference,         // Portable reference kernels (reference_ops).
  kGenericOptimized,  // Optimized kernels (optimized_ops).
};
40 
41 template <KernelType kernel_type>
EvalImpl(TfLiteContext * context,TfLiteNode * node,int axis,TfLiteTensor * output)42 TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, int axis,
43                       TfLiteTensor* output) {
44 // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
45 // allocate and populate these during Prepare().
46 // TODO(ycling): Activation function parameter is ignored. For now we don't have
47 // a model with a Concatenation with fused activation function.
48 #define TF_LITE_CONCATENATION(scalar)                                         \
49   {                                                                           \
50     VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
51     tflite::ConcatenationParams op_params;                                    \
52     op_params.axis = axis;                                                    \
53     op_params.inputs_count = node->inputs->size;                              \
54     if (kernel_type == kReference) {                                          \
55       reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
56                                    all_inputs.data(), GetTensorShape(output), \
57                                    GetTensorData<scalar>(output));            \
58     } else {                                                                  \
59       optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
60                                    all_inputs.data(), GetTensorShape(output), \
61                                    GetTensorData<scalar>(output));            \
62     }                                                                         \
63   }
64 
65 #define TF_LITE_CONCATENATION_QUANTIZED()                         \
66   {                                                               \
67     VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
68     tflite::ConcatenationParams op_params;                        \
69     op_params.axis = axis;                                        \
70     op_params.input_zeropoint = all_inputs.zero_point();          \
71     op_params.input_scale = all_inputs.scale();                   \
72     op_params.inputs_count = node->inputs->size;                  \
73     op_params.output_zeropoint = output->params.zero_point;       \
74     op_params.output_scale = output->params.scale;                \
75     if (kernel_type == kReference) {                              \
76       reference_ops::ConcatenationWithScaling(                    \
77           op_params, all_inputs.shapes(), all_inputs.data(),      \
78           GetTensorShape(output), GetTensorData<uint8>(output));  \
79     } else {                                                      \
80       optimized_ops::ConcatenationWithScaling(                    \
81           op_params, all_inputs.shapes(), all_inputs.data(),      \
82           GetTensorShape(output), GetTensorData<uint8>(output));  \
83     }                                                             \
84   }
85 
86   switch (output->type) {  // Already know in/outtypes are same.
87     case kTfLiteFloat32:
88       TF_LITE_CONCATENATION(float);
89       break;
90     case kTfLiteInt32:
91       TF_LITE_CONCATENATION(int32);
92       break;
93     case kTfLiteUInt8:
94       TF_LITE_CONCATENATION_QUANTIZED();
95       break;
96     case kTfLiteInt8:
97       TF_LITE_CONCATENATION(int8_t);
98       break;
99     case kTfLiteInt64:
100       TF_LITE_CONCATENATION(int64_t);
101       break;
102     case kTfLiteInt16:
103       TF_LITE_CONCATENATION(int16_t);
104       break;
105     case kTfLiteBool:
106       TF_LITE_CONCATENATION(bool);
107       break;
108     default:
109       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported currently.",
110                          TfLiteTypeGetName(output->type));
111       return kTfLiteError;
112   }
113 
114 #undef TF_LITE_CONCATENATION_QUANTIZED
115 #undef TF_LITE_CONCATENATION
116 
117   return kTfLiteOk;
118 }
119 
Prepare(TfLiteContext * context,TfLiteNode * node)120 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
121   auto* params =
122       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
123   int axis = params->axis;
124   int num_inputs = node->inputs->size;
125 
126   // The number of dimensions of the input tensors must match, and all
127   // dimensions except 'axis' must be equal.
128   const TfLiteTensor* t0;
129   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
130   TfLiteType input_type = t0->type;
131   if (axis < 0) axis += t0->dims->size;
132   TF_LITE_ENSURE(context, axis >= 0);
133   TF_LITE_ENSURE(context, axis < t0->dims->size);
134 
135   TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
136   TF_LITE_ENSURE(context,
137                  input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
138                      input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
139                      input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
140                      input_type == kTfLiteBool);
141 
142   // Output dimensions will match input dimensions, except 'axis', which
143   // will be the sum of inputs
144   int sum_axis = t0->dims->data[axis];
145   for (int i = 1; i < num_inputs; ++i) {
146     const TfLiteTensor* t;
147     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
148     TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
149     TF_LITE_ENSURE_EQ(context, t->type, input_type);
150     for (int d = 0; d < t0->dims->size; ++d) {
151       if (d == axis) {
152         // Avoid integer overflow in sum_axis below
153         TF_LITE_ENSURE(context, t->dims->data[axis] >= 0);
154         TF_LITE_ENSURE(context, t->dims->data[axis] <=
155                                     std::numeric_limits<int>::max() - sum_axis);
156         sum_axis += t->dims->data[axis];
157       } else {
158         TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
159       }
160     }
161   }
162 
163   TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
164   for (int d = 0; d < t0->dims->size; ++d) {
165     output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
166   }
167 
168   TfLiteTensor* output;
169   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
170   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);
171 
172   if (input_type == kTfLiteInt8) {
173     // Make sure there is no re-scaling needed for Int8 quantized kernel. This
174     // is a restriction we introduced to Int8 kernels.
175     VectorOfTensors<int8_t> all_inputs(*context, *node->inputs);
176     for (int i = 0; i < node->inputs->size; ++i) {
177       const TfLiteTensor* t;
178       TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
179       TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
180       TF_LITE_ENSURE_EQ(context, t->params.zero_point,
181                         output->params.zero_point);
182     }
183   }
184 
185   if (input_type == kTfLiteInt16) {
186     // Make sure that all Int16 inputs have a null zero-point.
187     for (int i = 0; i < node->inputs->size; ++i) {
188       const TfLiteTensor* t = GetInput(context, node, i);
189       TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0);
190     }
191     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
192   }
193 
194   // Check to see if we can calculate the output now.
195   bool all_inputs_at_prepare = true;
196   for (int i = 0; i < num_inputs; ++i) {
197     const TfLiteTensor* t;
198     TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
199     if (!IsConstantOrPersistentTensor(t)) {
200       all_inputs_at_prepare = false;
201       break;
202     }
203   }
204   if (all_inputs_at_prepare) {
205     SetTensorToPersistentRo(output);
206     context->ResizeTensor(context, output, output_size);
207     return EvalImpl<kReference>(context, node, axis, output);
208   }
209   return context->ResizeTensor(context, output, output_size);
210 }
211 
212 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)213 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
214   auto* params =
215       reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
216   int axis = params->axis;
217   TfLiteTensor* output;
218   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
219   if (IsConstantOrPersistentTensor(output)) {
220     // Output is computed in Prepare.
221     return kTfLiteOk;
222   }
223   if (axis < 0) axis += output->dims->size;
224 
225   return EvalImpl<kernel_type>(context, node, axis, output);
226 }
227 
228 #undef TF_LITE_MACRO_DISPATCH
229 
230 }  // namespace concatenation
231 
Register_CONCATENATION_REF()232 TfLiteRegistration* Register_CONCATENATION_REF() {
233   static TfLiteRegistration r = {
234       nullptr, nullptr, concatenation::Prepare,
235       concatenation::Eval<concatenation::kReference>};
236   return &r;
237 }
238 
Register_CONCATENATION_GENERIC_OPT()239 TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
240   static TfLiteRegistration r = {
241       nullptr, nullptr, concatenation::Prepare,
242       concatenation::Eval<concatenation::kGenericOptimized>};
243   return &r;
244 }
245 
// Default registration: currently aliases the optimized variant.
TfLiteRegistration* Register_CONCATENATION() {
  // TODO(ahentz): It turns out the two versions of Concatenation are almost
  // identical, so we should consider removing one.
  return Register_CONCATENATION_GENERIC_OPT();
}
251 
252 }  // namespace builtin
253 }  // namespace ops
254 }  // namespace tflite
255