/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_

#include <cmath>
#include <cstring>  // for memcpy, used by RankOneSelect below.

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
namespace reference_ops {

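// Element-wise select: for each flat index i, writes x[i] when condition[i]
// is true and y[i] otherwise. All four shapes must have matching flat sizes,
// or all be scalar/single-element tensors.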
27 template <typename D, typename T>
Select(const RuntimeShape & input_condition_shape,const D * input_condition_data,const RuntimeShape & input_x_shape,const T * input_x_data,const RuntimeShape & input_y_shape,const T * input_y_data,const RuntimeShape & output_shape,T * output_data)28 void Select(const RuntimeShape& input_condition_shape,
29 const D* input_condition_data, const RuntimeShape& input_x_shape,
30 const T* input_x_data, const RuntimeShape& input_y_shape,
31 const T* input_y_data, const RuntimeShape& output_shape,
32 T* output_data) {
33 ruy::profiler::ScopeLabel label("Select");
34 int64_t flatsize;
  // Allow the select operator to run on a mix of scalar tensors and
  // single-element tensors.
  if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
      input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1) {
    flatsize = 1;
  } else {
    flatsize = MatchingFlatSize(input_condition_shape, input_x_shape,
                                input_y_shape, output_shape);
  }
  for (int64_t i = 0; i < flatsize; ++i) {
    output_data[i] =
        input_condition_data[i] ? input_x_data[i] : input_y_data[i];
  }
}

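// Select where the condition is a rank-1 tensor (or a scalar) broadcast along
// the first dimension of x and y: each condition element picks an entire
// inner slice from either x or y.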
template <typename D, typename T>
void RankOneSelect(const RuntimeShape& input_condition_shape,
                   const D* input_condition_data,
                   const RuntimeShape& input_x_shape, const T* input_x_data,
                   const RuntimeShape& input_y_shape, const T* input_y_data,
                   const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Select/RankOneSelect");
  const int64_t outer_size = input_condition_shape.FlatSize();
  int64_t inner_size;
  if (input_condition_shape.DimensionsCount() == 0) {
    inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
  } else {
    TFLITE_DCHECK_EQ(
        MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
        outer_size);
    inner_size =
        MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
  }

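  // Copy one contiguous inner slice per condition element, taking the slice
  // from x when the condition is true and from y otherwise.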
  int64_t offset = 0;
  for (int64_t i = 0; i < outer_size; i++) {
    const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
    memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
    offset += inner_size;
  }
}

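// General select with broadcasting of condition, x, and y against the output
// shape, for tensors of up to 5 dimensions. "Slow" because it visits the
// output element by element through explicit nested loops.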
template <typename D, typename T>
void BroadcastSelect5DSlow(const RuntimeShape& input_condition_shape,
                           const D* input_condition_data,
                           const RuntimeShape& input_x_shape,
                           const T* input_x_data,
                           const RuntimeShape& input_y_shape,
                           const T* input_y_data,
                           const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Select/BroadcastSelectSlow");
  TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 5);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 5);

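  // Extend all shapes to 5-D and build per-tensor index descriptors whose
  // strides realize the broadcast (dimensions of size 1 get stride 0, so the
  // same element is reused across the broadcast dimension).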
  NdArrayDesc<5> desc_condition;
  NdArrayDesc<5> desc_x;
  NdArrayDesc<5> desc_y;
  NdArrayDesc<5> desc_output;
  const RuntimeShape extended_output_shape =
      RuntimeShape::ExtendedShape(5, output_shape);
  CopyDimsToDesc(extended_output_shape, &desc_output);
  NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape,
                                      input_y_shape, &desc_condition, &desc_x,
                                      &desc_y);

  // In TensorFlow, the dimensions are canonically named (batch_number, row,
  // col, channel), with extents (batches, height, width, depth), with the
  // trailing dimension changing most rapidly (channels has the smallest
  // stride, typically 1 element).
  //
  // In generated C code, we store arrays with the dimensions reversed. The
  // first dimension has the smallest stride.
  //
  // We name our variables by their TensorFlow convention, but generate C code
  // nesting loops such that the innermost loop has the smallest stride for
  // the best cache behavior.
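  //
  // The output index is built up Horner-style: at each nesting level the loop
  // coordinate is added to the running index, which is then scaled by the
  // extent of the next dimension. The condition, x, and y indices instead
  // accumulate their own descriptor strides, which are zero along broadcast
  // dimensions.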
  for (int n = 0; n < desc_output.extents[0]; ++n) {
    int out_idx_n = desc_output.extents[1] * n;
    int cond_idx_n = desc_condition.strides[0] * n;
    int in_idx1_n = desc_x.strides[0] * n;
    int in_idx2_n = desc_y.strides[0] * n;
    for (int b = 0; b < desc_output.extents[1]; ++b) {
      int out_idx_b = (out_idx_n + b) * desc_output.extents[2];
      int cond_idx_b = cond_idx_n + desc_condition.strides[1] * b;
      int in_idx1_b = in_idx1_n + desc_x.strides[1] * b;
      int in_idx2_b = in_idx2_n + desc_y.strides[1] * b;
      for (int y = 0; y < desc_output.extents[2]; ++y) {
        int out_idx_y = (out_idx_b + y) * desc_output.extents[3];
        int cond_idx_y = cond_idx_b + desc_condition.strides[2] * y;
        int in_idx1_y = in_idx1_b + desc_x.strides[2] * y;
        int in_idx2_y = in_idx2_b + desc_y.strides[2] * y;
        for (int x = 0; x < desc_output.extents[3]; ++x) {
          int out_idx = (out_idx_y + x) * desc_output.extents[4];
          int cond_idx = cond_idx_y + desc_condition.strides[3] * x;
          int in_idx1 = in_idx1_y + desc_x.strides[3] * x;
          int in_idx2 = in_idx2_y + desc_y.strides[3] * x;
          for (int c = 0; c < desc_output.extents[4]; ++c) {
            output_data[out_idx] = input_condition_data[cond_idx]
                                       ? input_x_data[in_idx1]
                                       : input_y_data[in_idx2];
            out_idx++;
            cond_idx += desc_condition.strides[4];
            in_idx1 += desc_x.strides[4];
            in_idx2 += desc_y.strides[4];
          }
        }
      }
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SELECT_H_