xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/split_v.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <vector>
18 
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/common.h"
21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/internal/types.h"
26 #include "tensorflow/lite/kernels/kernel_util.h"
27 
28 namespace tflite {
29 namespace ops {
30 namespace builtin {
31 namespace split_v {
32 
33 struct OpContext {
OpContexttflite::ops::builtin::split_v::OpContext34   OpContext(TfLiteContext* context, TfLiteNode* node) {
35     params = reinterpret_cast<TfLiteSplitVParams*>(node->builtin_data);
36     input = GetInput(context, node, 0);
37     size_splits = GetInput(context, node, 1);
38     axis = GetInput(context, node, 2);
39   }
40   TfLiteSplitVParams* params;
41   const TfLiteTensor* input;
42   const TfLiteTensor* size_splits;
43   const TfLiteTensor* axis;
44 };
45 
UseDynamicOutputTensors(TfLiteContext * context,TfLiteNode * node)46 TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
47   for (int i = 0; i < NumOutputs(node); ++i) {
48     TfLiteTensor* tensor;
49     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
50     SetTensorToDynamic(tensor);
51   }
52   return kTfLiteOk;
53 }
54 
55 template <typename T>
GetSizeSplitsVector(const TfLiteTensor * size_splits,std::vector<int64_t> * size_splits_vector)56 void GetSizeSplitsVector(const TfLiteTensor* size_splits,
57                          std::vector<int64_t>* size_splits_vector) {
58   const auto num_elements = NumElements(size_splits);
59   for (int i = 0; i < num_elements; ++i) {
60     size_splits_vector->push_back(GetTensorData<T>(size_splits)[i]);
61   }
62 }
63 
ResizeOutputTensors(TfLiteContext * context,TfLiteNode * node,const TfLiteTensor * input,const TfLiteTensor * size_splits,const TfLiteTensor * axis)64 TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
65                                  const TfLiteTensor* input,
66                                  const TfLiteTensor* size_splits,
67                                  const TfLiteTensor* axis) {
68   int axis_value = GetTensorData<int>(axis)[0];
69   if (axis_value < 0) {
70     axis_value += NumDimensions(input);
71   }
72 
73   std::vector<int64_t> size_splits_vector;
74   if (size_splits->type == kTfLiteInt32) {
75     GetSizeSplitsVector<int32_t>(size_splits, &size_splits_vector);
76   } else if (size_splits->type == kTfLiteInt64) {
77     GetSizeSplitsVector<int64_t>(size_splits, &size_splits_vector);
78   } else {
79     TF_LITE_KERNEL_LOG(context, "size_splits only support type int32|int64.");
80     return kTfLiteError;
81   }
82 
83   int minus_one_index = -1;
84   int64_t size_splits_sum = 0;
85 
86   for (int i = 0; i < size_splits_vector.size(); ++i) {
87     if (size_splits_vector.at(i) == -1) {
88       if (minus_one_index == -1) {
89         minus_one_index = i;
90       } else {
91         TF_LITE_KERNEL_LOG(context,
92                            "The size_splits contains more than one -1.");
93         return kTfLiteError;
94       }
95     } else {
96       size_splits_sum += size_splits_vector.at(i);
97     }
98   }
99 
100   TF_LITE_ENSURE(context, axis_value >= 0);
101   TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
102   const int input_size = SizeOfDimension(input, axis_value);
103 
104   if (minus_one_index != -1) {
105     if (size_splits_sum > input_size) {
106       TF_LITE_KERNEL_LOG(
107           context,
108           "The sum of size_splits must be less than the dimension of value.");
109     } else {
110       size_splits_vector[minus_one_index] = input_size - size_splits_sum;
111     }
112   } else if (size_splits_sum != input_size) {
113     TF_LITE_KERNEL_LOG(
114         context,
115         "The size_splits must sum to the dimension of value along axis.");
116   }
117 
118   for (int i = 0; i < NumOutputs(node); ++i) {
119     TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
120     output_dims->data[axis_value] = size_splits_vector.at(i);
121     TfLiteTensor* output;
122     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
123     TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
124   }
125 
126   return kTfLiteOk;
127 }
128 
Prepare(TfLiteContext * context,TfLiteNode * node)129 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
130   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
131 
132   OpContext op_context(context, node);
133 
134   TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
135 
136   auto input_type = op_context.input->type;
137   TF_LITE_ENSURE(context,
138                  input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
139                      input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
140                      input_type == kTfLiteInt64 || input_type == kTfLiteInt8);
141   for (int i = 0; i < NumOutputs(node); ++i) {
142     TfLiteTensor* tensor;
143     TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
144     tensor->type = input_type;
145   }
146 
147   auto size_splits = op_context.size_splits;
148   TF_LITE_ENSURE_EQ(context, NumDimensions(size_splits), 1);
149   TF_LITE_ENSURE_EQ(context, NumOutputs(node), NumElements(size_splits));
150 
151   // If we know the contents of the 'size_splits' tensor and the 'axis' tensor,
152   // resize all outputs. Otherwise, wait until Eval().
153   if (IsConstantTensor(op_context.size_splits) &&
154       IsConstantTensor(op_context.axis)) {
155     return ResizeOutputTensors(context, node, op_context.input,
156                                op_context.size_splits, op_context.axis);
157   } else {
158     return UseDynamicOutputTensors(context, node);
159   }
160 }
161 
Eval(TfLiteContext * context,TfLiteNode * node)162 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
163   OpContext op_context(context, node);
164 
165   // When the 'size_splits' and the 'axis' tensor is non-const we can't resize
166   // output tensors in Prepare(), and we have to do it now.
167   if (!IsConstantTensor(op_context.axis) ||
168       !IsConstantTensor(op_context.size_splits)) {
169     TF_LITE_ENSURE_OK(
170         context, ResizeOutputTensors(context, node, op_context.input,
171                                      op_context.size_splits, op_context.axis));
172   }
173 
174   int axis_value = GetTensorData<int>(op_context.axis)[0];
175 
176   // Use split function to build the outputs since they share the same logic.
177 #define TF_LITE_SPLIT_V(scalar)                                     \
178   VectorOfTensors<scalar> all_outputs(*context, *node->outputs);    \
179   tflite::SplitParams op_params;                                    \
180   op_params.num_split = NumOutputs(node);                           \
181   op_params.axis = axis_value;                                      \
182   reference_ops::Split(op_params, GetTensorShape(op_context.input), \
183                        GetTensorData<scalar>(op_context.input),     \
184                        all_outputs.shapes(), all_outputs.data());
185   switch (op_context.input->type) {
186     case kTfLiteFloat32: {
187       TF_LITE_SPLIT_V(float);
188       break;
189     }
190     case kTfLiteUInt8: {
191       TF_LITE_SPLIT_V(uint8_t);
192       break;
193     }
194     case kTfLiteInt16: {
195       TF_LITE_SPLIT_V(int16_t);
196       break;
197     }
198     case kTfLiteInt32: {
199       TF_LITE_SPLIT_V(int32_t);
200       break;
201     }
202     case kTfLiteInt64: {
203       TF_LITE_SPLIT_V(int64_t);
204       break;
205     }
206     case kTfLiteInt8: {
207       TF_LITE_SPLIT_V(int8_t);
208       break;
209     }
210     default:
211       TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
212                          TfLiteTypeGetName(op_context.input->type));
213       return kTfLiteError;
214   }
215 #undef TF_LITE_SPLIT_V
216 
217   return kTfLiteOk;
218 }
219 
220 }  // namespace split_v
221 
Register_SPLIT_V()222 TfLiteRegistration* Register_SPLIT_V() {
223   static TfLiteRegistration r = {nullptr, nullptr, split_v::Prepare,
224                                  split_v::Eval};
225   return &r;
226 }
227 
228 }  // namespace builtin
229 }  // namespace ops
230 }  // namespace tflite
231