/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace batch_to_space_nd {

// This file has two implementations of BatchToSpaceND.
enum KernelType {
  kReference,
  kGenericOptimized,
};

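// Holds pointers to the op's input and output tensors so that Prepare/Eval
// can access them without repeating the GetInput/GetOutput lookups.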
struct BatchToSpaceNDContext {
  BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    block_shape = GetInput(context, node, 1);
    crops = GetInput(context, node, 2);
    output = GetOutput(context, node, 0);
  }
  const TfLiteTensor* input;
  const TfLiteTensor* block_shape;
  const TfLiteTensor* crops;
  TfLiteTensor* output;
};

// Currently, only 3D NHC or 4D NHWC input/output tensors are supported.
// A 3D input is converted to 4D by inserting W=1, giving an NH1C layout.
// The 4D array needs to have exactly 2 spatial dimensions.
// TODO(ycling): Support arbitrary dimension in BatchToSpaceND.
const int kInputMinDimensionNum = 3;
const int kInputMaxDimensionNum = 4;

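// Computes the output shape from the block_shape and crops tensors:
//   output_batch        = input_batch / prod(block_shape)
//   output_spatial[dim] = input_spatial[dim] * block_shape[dim]
//                         - crops[dim][0] - crops[dim][1]
// For example, an input of shape [4, 2, 2, 1] with block_shape [2, 2] and
// zero crops resizes the output to [1, 4, 4, 1].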
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                BatchToSpaceNDContext* op_context) {
  TfLiteIntArray* input_size = op_context->input->dims;
  const int* block_shape = GetTensorData<int32>(op_context->block_shape);
  const int* crops = GetTensorData<int32>(op_context->crops);

  int spatial_dims_num = input_size->size - 2;
  // block_shape should be a 1D tensor with dimension [spatial_dims_num].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), 1);
  TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0],
                    spatial_dims_num);
  // crops should be a 2D tensor with dimension [spatial_dims_num, 2].
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), 2);
  TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[0],
                    spatial_dims_num);
  TF_LITE_ENSURE_EQ(context, op_context->crops->dims->data[1], 2);

  for (int i = 0; i < spatial_dims_num * 2; ++i) {
    TF_LITE_ENSURE(context, crops[i] >= 0);
  }

  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size);
  int output_batch_size = input_size->data[0];
  for (int dim = 0; dim < spatial_dims_num; ++dim) {
    // The batch size must be a multiple of block_shape[dim].
    TF_LITE_ENSURE(context, block_shape[dim] != 0);
    TF_LITE_ENSURE_EQ(context, output_batch_size % block_shape[dim], 0);
    output_batch_size = output_batch_size / block_shape[dim];
    output_size->data[dim + 1] = input_size->data[dim + 1] * block_shape[dim] -
                                 crops[dim * 2] - crops[dim * 2 + 1];
  }

  output_size->data[0] = output_batch_size;
  output_size->data[input_size->size - 1] =
      input_size->data[input_size->size - 1];

  return context->ResizeTensor(context, op_context->output, output_size);
}

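// Prepare checks ranks and types. When block_shape or crops are not constant
// tensors, the output shape cannot be computed here, so the output is marked
// dynamic and resized on each invocation in Eval instead.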
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  BatchToSpaceNDContext op_context(context, node);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) >= kInputMinDimensionNum);
  TF_LITE_ENSURE(context,
                 NumDimensions(op_context.input) <= kInputMaxDimensionNum);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);

  if (!IsConstantTensor(op_context.block_shape) ||
      !IsConstantTensor(op_context.crops)) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

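// Eval resizes a dynamic output if needed, then dispatches on the element
// type, using either reference_ops or optimized_ops depending on the
// kernel_type template parameter.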
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  BatchToSpaceNDContext op_context(context, node);

  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar)                        \
  type::BatchToSpaceND(GetTensorShape(op_context.input),               \
                       GetTensorData<scalar>(op_context.input),        \
                       GetTensorShape(op_context.block_shape),         \
                       GetTensorData<int32_t>(op_context.block_shape), \
                       GetTensorShape(op_context.crops),               \
                       GetTensorData<int32_t>(op_context.crops),       \
                       GetTensorShape(op_context.output),              \
                       GetTensorData<scalar>(op_context.output))
  switch (op_context.input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int8_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int8_t);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t);
      } else {
        TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t);
      }
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Type %d is currently not supported by BatchToSpace.",
                         op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_BATCH_TO_SPACE_ND
  return kTfLiteOk;
}

}  // namespace batch_to_space_nd

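// Registration entry points. The TfLiteRegistration initializers are
// {init, free, prepare, invoke}; this kernel keeps no per-node state, so init
// and free are left null.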
TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kReference>};
  return &r;
}

TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, batch_to_space_nd::Prepare,
      batch_to_space_nd::Eval<batch_to_space_nd::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_BATCH_TO_SPACE_ND() {
  return Register_BATCH_TO_SPACE_ND_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite