xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
17 
#include <algorithm>
#include <cmath>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
23 
24 namespace tflite {
25 namespace reference_ops {
26 
// TODO(b/135760455): Move this method into an anonymous namespace in a .cc file.
// Returns `shape` promoted to rank 4 for SpaceToBatchND.
//
// A 4D shape is returned unchanged. Any other shape is treated as 3D
// (N, H, C) and expanded to (N, H, 1, C) by inserting a unit dimension at
// index 2. Dimensions 0-2 of a non-4D `shape` are read, so the caller must
// guarantee it has at least three dimensions.
inline RuntimeShape ExtendShapeSpaceToBatch(const RuntimeShape& shape) {
  if (shape.DimensionsCount() != 4) {
    // Start from all ones so the inserted dimension (index 2) is 1.
    RuntimeShape extended(4, 1);
    extended.SetDim(0, shape.Dims(0));
    extended.SetDim(1, shape.Dims(1));
    extended.SetDim(3, shape.Dims(2));  // last input dim moves to index 3
    return extended;
  }
  return shape;
}
38 
39 template <typename T>
SpaceToBatchND(const SpaceToBatchParams & params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const int32_t * block_shape_data,const RuntimeShape & unextended_input3_shape,const int32_t * paddings_data,const RuntimeShape & unextended_output_shape,T * output_data)40 inline void SpaceToBatchND(const SpaceToBatchParams& params,
41                            const RuntimeShape& unextended_input1_shape,
42                            const T* input1_data,
43                            const RuntimeShape& unextended_input2_shape,
44                            const int32_t* block_shape_data,
45                            const RuntimeShape& unextended_input3_shape,
46                            const int32_t* paddings_data,
47                            const RuntimeShape& unextended_output_shape,
48                            T* output_data) {
49   ruy::profiler::ScopeLabel label("SpaceToBatchND");
50   TFLITE_DCHECK_GE(unextended_input1_shape.DimensionsCount(), 3);
51   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
52   TFLITE_DCHECK_EQ(unextended_input1_shape.DimensionsCount(),
53                    unextended_output_shape.DimensionsCount());
54 
55   // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
56   const RuntimeShape input1_shape =
57       ExtendShapeSpaceToBatch(unextended_input1_shape);
58   const RuntimeShape output_shape =
59       ExtendShapeSpaceToBatch(unextended_output_shape);
60 
61   const int depth = input1_shape.Dims(3);
62   const int input_width = input1_shape.Dims(2);
63   const int input_height = input1_shape.Dims(1);
64   const int input_batch_size = input1_shape.Dims(0);
65 
66   const int output_width = output_shape.Dims(2);
67   const int output_height = output_shape.Dims(1);
68   const int output_batch_size = output_shape.Dims(0);
69 
70   const int block_shape_height = block_shape_data[0];
71   const int block_shape_width =
72       unextended_input1_shape.DimensionsCount() == 4 ? block_shape_data[1] : 1;
73   const int padding_top = paddings_data[0];
74   const int padding_left =
75       unextended_input1_shape.DimensionsCount() == 4 ? paddings_data[2] : 0;
76 
77   // For uint8 quantized, the correct padding "zero value" is the output offset.
78   const int32_t pad_value = params.output_offset;
79   for (int out_b = 0; out_b < output_batch_size; ++out_b) {
80     int input_batch = out_b % input_batch_size;
81     int shift_w = (out_b / input_batch_size) % block_shape_width;
82     int shift_h = (out_b / input_batch_size) / block_shape_width;
83     for (int out_h = 0; out_h < output_height; ++out_h) {
84       for (int out_w = 0; out_w < output_width; ++out_w) {
85         T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
86         if (out_h * block_shape_height + shift_h < padding_top ||
87             out_h * block_shape_height + shift_h >=
88                 padding_top + input_height ||
89             out_w * block_shape_width + shift_w < padding_left ||
90             out_w * block_shape_width + shift_w >= padding_left + input_width) {
91           // This may not execute correctly when pad_value != 0 and T != uint8.
92           memset(out, pad_value, depth * sizeof(T));
93         } else {
94           const T* in =
95               input1_data +
96               Offset(input1_shape, input_batch,
97                      (out_h * block_shape_height + shift_h) - padding_top,
98                      (out_w * block_shape_width + shift_w) - padding_left, 0);
99           memcpy(out, in, depth * sizeof(T));
100         }
101       }
102     }
103   }
104 }
105 
106 }  // namespace reference_ops
107 }  // namespace tflite
108 
109 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SPACE_TO_BATCH_ND_H_
110