xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/select.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
17 
18 #include <cmath>
19 
20 #include "ruy/profiler/instrumentation.h"  // from @ruy
21 #include "tensorflow/lite/kernels/internal/common.h"
22 #include "tensorflow/lite/kernels/internal/types.h"
23 
24 namespace tflite {
25 namespace reference_ops {
26 
27 template <typename D, typename T>
Select(const RuntimeShape & input_condition_shape,const D * input_condition_data,const RuntimeShape & input_x_shape,const T * input_x_data,const RuntimeShape & input_y_shape,const T * input_y_data,const RuntimeShape & output_shape,T * output_data)28 void Select(const RuntimeShape& input_condition_shape,
29             const D* input_condition_data, const RuntimeShape& input_x_shape,
30             const T* input_x_data, const RuntimeShape& input_y_shape,
31             const T* input_y_data, const RuntimeShape& output_shape,
32             T* output_data) {
33   ruy::profiler::ScopeLabel label("Select");
34   int64_t flatsize;
35   // Allow select operator executions on mixed scalar tensors and one element
36   // tensors.
37   if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
38       input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1) {
39     flatsize = 1;
40   } else {
41     flatsize = MatchingFlatSize(input_condition_shape, input_x_shape,
42                                 input_y_shape, output_shape);
43   }
44   for (int64_t i = 0; i < flatsize; ++i) {
45     output_data[i] =
46         input_condition_data[i] ? input_x_data[i] : input_y_data[i];
47   }
48 }
49 
50 template <typename D, typename T>
RankOneSelect(const RuntimeShape & input_condition_shape,const D * input_condition_data,const RuntimeShape & input_x_shape,const T * input_x_data,const RuntimeShape & input_y_shape,const T * input_y_data,const RuntimeShape & output_shape,T * output_data)51 void RankOneSelect(const RuntimeShape& input_condition_shape,
52                    const D* input_condition_data,
53                    const RuntimeShape& input_x_shape, const T* input_x_data,
54                    const RuntimeShape& input_y_shape, const T* input_y_data,
55                    const RuntimeShape& output_shape, T* output_data) {
56   ruy::profiler::ScopeLabel label("Select/RankOneSelect");
57   const int64_t outer_size = input_condition_shape.FlatSize();
58   int64_t inner_size;
59   if (input_condition_shape.DimensionsCount() == 0) {
60     inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
61   } else {
62     TFLITE_DCHECK_EQ(
63         MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
64         outer_size);
65     inner_size =
66         MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
67   }
68 
69   int64_t offset = 0;
70   for (int64_t i = 0; i < outer_size; i++) {
71     const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
72     memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
73     offset += inner_size;
74   }
75 }
76 
77 template <typename D, typename T>
BroadcastSelect5DSlow(const RuntimeShape & input_condition_shape,const D * input_condition_data,const RuntimeShape & input_x_shape,const T * input_x_data,const RuntimeShape & input_y_shape,const T * input_y_data,const RuntimeShape & output_shape,T * output_data)78 void BroadcastSelect5DSlow(const RuntimeShape& input_condition_shape,
79                            const D* input_condition_data,
80                            const RuntimeShape& input_x_shape,
81                            const T* input_x_data,
82                            const RuntimeShape& input_y_shape,
83                            const T* input_y_data,
84                            const RuntimeShape& output_shape, T* output_data) {
85   ruy::profiler::ScopeLabel label("Select/BroadcastSelectSlow");
86   TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 5);
87   TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 5);
88   TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 5);
89   TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 5);
90 
91   NdArrayDesc<5> desc_condition;
92   NdArrayDesc<5> desc_x;
93   NdArrayDesc<5> desc_y;
94   NdArrayDesc<5> desc_output;
95   const RuntimeShape extended_output_shape =
96       RuntimeShape::ExtendedShape(5, output_shape);
97   CopyDimsToDesc(extended_output_shape, &desc_output);
98   NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
99                                       input_y_shape, &desc_condition, &desc_x,
100                                       &desc_y);
101 
102   // In Tensorflow, the dimensions are canonically named (batch_number, row,
103   // col, channel), with extents (batches, height, width, depth), with the
104   // trailing dimension changing most rapidly (channels has the smallest
105   // stride, typically 1 element).
106   //
107   // In generated C code, we store arrays with the dimensions reversed. The
108   // first dimension has smallest stride.
109   //
110   // We name our variables by their Tensorflow convention, but generate C code
111   // nesting loops such that the innermost loop has the smallest stride for
112   // the best cache behavior.
113   for (int n = 0; n < desc_output.extents[0]; ++n) {
114     int out_idx_n = desc_output.extents[1] * n;
115     int cond_idx_n = desc_condition.strides[0] * n;
116     int in_idx1_n = desc_x.strides[0] * n;
117     int in_idx2_n = desc_y.strides[0] * n;
118     for (int b = 0; b < desc_output.extents[1]; ++b) {
119       int out_idx_b = (out_idx_n + b) * desc_output.extents[2];
120       int cond_idx_b = cond_idx_n + desc_condition.strides[1] * b;
121       int in_idx1_b = in_idx1_n + desc_x.strides[1] * b;
122       int in_idx2_b = in_idx2_n + desc_y.strides[1] * b;
123       for (int y = 0; y < desc_output.extents[2]; ++y) {
124         int out_idx_y = (out_idx_b + y) * desc_output.extents[3];
125         int cond_idx_y = cond_idx_b + desc_condition.strides[2] * y;
126         int in_idx1_y = in_idx1_b + desc_x.strides[2] * y;
127         int in_idx2_y = in_idx2_b + desc_y.strides[2] * y;
128         for (int x = 0; x < desc_output.extents[3]; ++x) {
129           int out_idx = (out_idx_y + x) * desc_output.extents[4];
130           int cond_idx = cond_idx_y + desc_condition.strides[3] * x;
131           int in_idx1 = in_idx1_y + desc_x.strides[3] * x;
132           int in_idx2 = in_idx2_y + desc_y.strides[3] * x;
133           for (int c = 0; c < desc_output.extents[4]; ++c) {
134             output_data[out_idx] = input_condition_data[cond_idx]
135                                        ? input_x_data[in_idx1]
136                                        : input_y_data[in_idx2];
137             out_idx++;
138             cond_idx += desc_condition.strides[4];
139             in_idx1 += desc_x.strides[4];
140             in_idx2 += desc_y.strides[4];
141           }
142         }
143       }
144     }
145   }
146 }
147 
148 }  // namespace reference_ops
149 }  // namespace tflite
150 
151 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
152