/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <type_traits>

#include "Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/add_n.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/cast.h"
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/kernels/internal/reference/elu.h"
#include "tensorflow/lite/kernels/internal/reference/exp.h"
#include "tensorflow/lite/kernels/internal/reference/fill.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/kernels/internal/reference/floor_div.h"
#include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/gather.h"
#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
#include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/lstm_cell.h"
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/neg.h"
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/select.h"
#include "tensorflow/lite/kernels/internal/reference/slice.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"

namespace tflite {

namespace reference_ops {

template <typename T>
inline void Relu(const RuntimeShape& input_shape, const T* input_data,
                 const RuntimeShape& output_shape, T* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T lower = 0;
    const T clamped = val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

template <typename T>
inline void Relu0To1(const RuntimeShape& input_shape, const T* input_data,
                     const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu0To1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = 0;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

template <typename T>
inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = -1;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("Relu6 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const float val = input_data[i];
    const float upper = 6;
    const float lower = 0;
    const float clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

template <typename T>
inline void ReluX(const tflite::ReluParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const int32 val = static_cast<int32_t>(input_data[i]);
    int32 clamped = params.output_offset +
                    MultiplyByQuantizedMultiplier(val - params.input_offset,
                                                  params.output_multiplier,
                                                  params.output_shift);
    clamped = std::max(params.quantized_activation_min, clamped);
    clamped = std::min(params.quantized_activation_max, clamped);
    output_data[i] = static_cast<T>(clamped);
  }
}
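
// Illustrative reading (assumed, typical quantization setup): with
// params.input_offset and params.output_offset holding the input and output
// zero points, and (params.output_multiplier, params.output_shift) encoding
// the ratio of the two scales, the loop above re-expresses each input value
// in the output's quantized domain and then clamps it to the requested
// activation range. When input and output share the same scale and zero
// point, this reduces to a plain clamp of the raw quantized values.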

template <typename T>
inline void ReluX(const tflite::ActivationParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  const T max_value = params.quantized_activation_max;
  const T min_value = params.quantized_activation_min;
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T clamped = val > max_value   ? max_value
                      : val < min_value ? min_value
                                        : val;
    output_data[i] = clamped;
  }
}

// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  ArithmeticParams switched_params = unswitched_params;
  switched_params.input1_offset = unswitched_params.input2_offset;
  switched_params.input2_offset = unswitched_params.input1_offset;

  const bool use_unswitched =
      unswitched_params.broadcast_category ==
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;

  const ArithmeticParams& params =
      use_unswitched ? unswitched_params : switched_params;
  const uint8* input1_data =
      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
  const uint8* input2_data =
      use_unswitched ? unswitched_input2_data : unswitched_input1_data;

  // Fivefold nested loops. The second input resets its position for each
  // iteration of the second loop. The first input resets its position at the
  // beginning of the fourth loop. The innermost loop is an elementwise Mul of
  // sections of the arrays.
  uint8* output_data_ptr = output_data;
  const uint8* input1_data_ptr = input1_data;
  const uint8* input2_data_reset = input2_data;
  int y0 = params.broadcast_shape[0];
  int y1 = params.broadcast_shape[1];
  int y2 = params.broadcast_shape[2];
  int y3 = params.broadcast_shape[3];
  int y4 = params.broadcast_shape[4];
  for (int i0 = 0; i0 < y0; ++i0) {
    const uint8* input2_data_ptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
                         output_data_ptr);
          input2_data_ptr += y4;
          output_data_ptr += y4;
        }
        input1_data_ptr += y4;
      }
    }
    input2_data_reset = input2_data_ptr;
  }
}
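
// Reading the loops off directly (illustrative, derived from the pointer
// arithmetic above rather than from any external spec): input1 effectively
// holds y0 * y1 * y2 * y4 elements and is broadcast across y3, input2 holds
// y0 * y2 * y3 * y4 elements and is broadcast across y1, and the output holds
// y0 * y1 * y2 * y3 * y4 elements. For example, with an assumed
// broadcast_shape of {1, 2, 1, 3, 4}, an 8-element operand is multiplied
// against a 12-element operand to produce 24 outputs.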

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    output_data[i] = unclamped_result.raw();
  }
}
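
// Worked example (illustrative): in the Q0.15 format used above, a real value
// v is stored as round(v * 2^15), so 0.5 is the raw value 16384. Multiplying
// raw 16384 by raw 16384 through gemmlowp's saturating rounding doubling
// high-mul yields raw 8192, i.e. 0.25, as expected for 0.5 * 0.5.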

inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
  int32 output_offset = params.output_offset;
  int32 output_activation_min = params.quantized_activation_min;
  int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result = std::min<int16>(
        output_activation_max - output_offset, rescaled_result);
    clamped_result = std::max<int16>(output_activation_min - output_offset,
                                     clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

inline void Sub16(const ArithmeticParams& params,
                  const RuntimeShape& input1_shape, const int16_t* input1_data,
                  const RuntimeShape& input2_shape, const int16_t* input2_data,
                  const RuntimeShape& output_shape, int16_t* output_data) {
  ruy::profiler::ScopeLabel label("Sub/Int16");
  const int input1_shift = params.input1_shift;
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  const int16 output_activation_min = params.quantized_activation_min;
  const int16 output_activation_max = params.quantized_activation_max;

  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
  TFLITE_DCHECK_LE(input1_shift, 0);
  TFLITE_DCHECK_LE(params.input2_shift, 0);
  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
  const int input_right_shift =
      input1_shift == 0 ? -params.input2_shift : -input1_shift;

  if (input1_shift == 0) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(input_ready_scaled, scaled_input);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  } else {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(scaled_input, input_ready_scaled);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  }
}

template <typename Scalar>
void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
          const Scalar* const* input_data, const RuntimeShape& output_shape,
          Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Pack");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  int inputs_count = params.inputs_count;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = params.axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  for (int i = 0; i < inputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      const Scalar* input_ptr = input_data[i] + copy_size * k;
      int loc = k * inputs_count * copy_size + i * copy_size;
      memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
    }
  }
}
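
// Example (illustrative shapes): packing inputs_count = 2 tensors of shape
// [2, 3] along axis = 1 produces an output of shape [2, 2, 3]. Here
// outer_size = 2 and copy_size = 3, so row k of input i lands at flat offset
// (k * 2 + i) * 3 in the output, interleaving the inputs along the new axis.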

template <typename Scalar>
void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape& output_shape,
            Scalar* const* output_datas) {
  ruy::profiler::ScopeLabel label("Unpack");
  const int dimensions = input_shape.DimensionsCount();
  const int outputs_count = params.num_split;

  int outer_size = 1;
  int axis = params.axis;
  if (axis < 0) {
    axis += dimensions;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, dimensions);
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; ++i) {
    copy_size *= input_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);

  for (int i = 0; i < outputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      Scalar* output_ptr = output_datas[i] + copy_size * k;
      int loc = k * outputs_count * copy_size + i * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}
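
// Example (illustrative shapes): unpacking a [2, 2, 3] input along axis = 1
// with num_split = 2 yields two [2, 3] outputs, the exact inverse of the Pack
// example above; each output gathers the slices at its own index along the
// unpacked axis.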

template <typename Scalar>
void PackWithScaling(const PackParams& params,
                     const RuntimeShape* const* input_shapes,
                     const uint8* const* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("PackWithScaling");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  const int32* input_zeropoint = params.input_zeropoint;
  const float* input_scale = params.input_scale;
  int inputs_count = params.inputs_count;
  const int32 output_zeropoint = params.output_zeropoint;
  const float output_scale = params.output_scale;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  Scalar* output_ptr = output_data;
  const float inverse_output_scale = 1.f / output_scale;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < inputs_count; ++i) {
      if (input_zeropoint[i] == output_zeropoint &&
          input_scale[i] == output_scale) {
        memcpy(output_ptr, input_data[i] + k * copy_size,
               copy_size * sizeof(Scalar));
      } else {
        assert(false);
        const float scale = input_scale[i] * inverse_output_scale;
        const float bias = -input_zeropoint[i] * scale;
        auto input_ptr = input_data[i];
        for (int j = 0; j < copy_size; ++j) {
          const int32_t value =
              static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
              output_zeropoint;
          output_ptr[j] =
              static_cast<uint8_t>(std::max(std::min(255, value), 0));
        }
      }
      output_ptr += copy_size;
    }
  }
}

template <typename Scalar>
void DepthConcatenation(const ConcatenationParams& params,
                        const RuntimeShape* const* input_shapes,
                        const Scalar* const* input_data,
                        const RuntimeShape& output_shape, Scalar* output_data) {
  ruy::profiler::ScopeLabel label("DepthConcatenation");
  auto params_copy = params;
  params_copy.axis = 3;
  Concatenation(params_copy, input_shapes, input_data, output_shape,
                output_data);
}

template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
           const Scalar* input_data, const RuntimeShape* const* output_shapes,
           Scalar* const* output_data) {
  ruy::profiler::ScopeLabel label("Split");
  const int split_dimensions = input_shape.DimensionsCount();
  int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
  int outputs_count = params.num_split;
  TFLITE_DCHECK_LT(axis, split_dimensions);

  int64_t split_size = 0;
  for (int i = 0; i < outputs_count; i++) {
    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
    for (int j = 0; j < split_dimensions; j++) {
      if (j != axis) {
        MatchingDim(*output_shapes[i], j, input_shape, j);
      }
    }
    split_size += output_shapes[i]->Dims(axis);
  }
  TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  // For all output arrays,
  // FlatSize() = outer_size * Dims(axis) * base_inner_size;
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_shape.Dims(i);
  }

  const Scalar* input_ptr = input_data;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < outputs_count; ++i) {
      const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
      memcpy(output_data[i] + k * copy_size, input_ptr,
             copy_size * sizeof(Scalar));
      input_ptr += copy_size;
    }
  }
}
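
// Example (illustrative shapes): splitting a [2, 6] input along axis = 1 into
// num_split = 2 outputs of shape [2, 3] gives outer_size = 2 and
// base_inner_size = 1, so for each of the two rows the first three elements
// go to output 0 and the next three go to output 1.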

inline int NodeOffset(int b, int h, int w, int height, int width) {
  return (b * height + h) * width + w;
}

inline void LocalResponseNormalization(
    const tflite::LocalResponseNormalizationParams& op_params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    for (int c = 0; c < depth; ++c) {
      const int begin_input_c = std::max(0, c - op_params.range);
      const int end_input_c = std::min(depth, c + op_params.range);
      float accum = 0.f;
      for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
        const float input_val = input_data[i * depth + input_c];
        accum += input_val * input_val;
      }
      const float multiplier =
          std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
      output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
    }
  }
}
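
// In formula form, the loop above computes, for each position i and channel c,
//   output[i, c] = input[i, c] * (bias + alpha * sum(input[i, c']^2))^(-beta),
// where c' ranges over the channels in the window around c defined by
// op_params.range, clamped to the valid channel interval [0, depth).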

inline void Dequantize(const RuntimeShape& input_shape,
                       const Eigen::half* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = static_cast<float>(input_data[i]);
  }
}

inline void FakeQuant(const tflite::FakeQuantParams& op_params,
                      const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("FakeQuant");
  float rmin = op_params.minmax.min;
  float rmax = op_params.minmax.max;
  int num_bits = op_params.num_bits;
  // 0 should always be a representable value. Let's assume that the initial
  // min,max range contains 0.
  TFLITE_DCHECK_LE(rmin, 0.0f);
  TFLITE_DCHECK_GE(rmax, 0.0f);
  TFLITE_DCHECK_LT(rmin, rmax);

  // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
  int quant_min = 0;
  int quant_max = (1 << num_bits) - 1;
  float nudged_min, nudged_max, nudged_scale;
  NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
                         &nudged_max, &nudged_scale);
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
                    output_data, flat_size);
}

// Common subroutine for both `GatherNd` and `GatherNdString`.
struct GatherNdHelperResult {
  int n_slices;
  int slice_size;
  int indices_nd;
  std::vector<int> dims_to_count;
};

// Returns common values used by both `GatherNd` and `GatherNdString`.
inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
                                           const RuntimeShape& indices_shape) {
  GatherNdHelperResult ret;
  ret.n_slices = 1;
  ret.slice_size = 1;
  const int indices_dims = indices_shape.DimensionsCount();
  ret.indices_nd = indices_shape.Dims(indices_dims - 1);
  const int params_dims = params_shape.DimensionsCount();
  for (int i = 0; i < indices_dims - 1; ++i) {
    ret.n_slices *= indices_shape.Dims(i);
  }
  if (ret.n_slices == 0) return ret;

  for (int i = ret.indices_nd; i < params_dims; ++i) {
    ret.slice_size *= params_shape.Dims(i);
  }

  int remain_flat_size = params_shape.FlatSize();
  ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
  for (int i = 0; i < ret.indices_nd; ++i) {
    ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
    remain_flat_size = ret.dims_to_count[i];
  }

  return ret;
}
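
// Worked example (illustrative shapes): with params_shape = [4, 5, 6] and
// indices_shape = [2, 3, 2], the last indices dimension gives indices_nd = 2,
// so there are n_slices = 2 * 3 = 6 slices of slice_size = 6 elements each,
// and dims_to_count = {120 / 4, 30 / 5} = {30, 6}; an index pair (a, b)
// therefore addresses the flat offset a * 30 + b * 6 in params.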

// Implements GatherNd.
// Returns an error if any of the indices_data would cause an out of bounds
// memory read.
template <typename ParamsT, typename IndicesT = int32>
inline TfLiteStatus GatherNd(const RuntimeShape& params_shape,
                             const ParamsT* params_data,
                             const RuntimeShape& indices_shape,
                             const IndicesT* indices_data,
                             const RuntimeShape& output_shape,
                             ParamsT* output_data) {
  ruy::profiler::ScopeLabel label("GatherNd");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  for (int i = 0; i < res.n_slices; ++i) {
    int64_t from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    if (from_pos < 0 || from_pos + res.slice_size > params_shape.FlatSize()) {
      return kTfLiteError;
    }
    std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
                sizeof(ParamsT) * res.slice_size);
  }
  return kTfLiteOk;
}

#ifndef TF_LITE_STATIC_MEMORY
// Implements GatherNd on strings.
// Returns an error if any of the indices_data would cause an out of bounds
// memory read.
template <typename IndicesT = int32>
inline TfLiteStatus GatherNdString(const RuntimeShape& params_shape,
                                   const TfLiteTensor* params_data,
                                   const RuntimeShape& indices_shape,
                                   const IndicesT* indices_data,
                                   const RuntimeShape& output_shape,
                                   TfLiteTensor* output_data) {
  ruy::profiler::ScopeLabel label("GatherNdString");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  DynamicBuffer buffer;
  for (int i = 0; i < res.n_slices; ++i) {
    int64_t from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    if (from_pos < 0 || from_pos + res.slice_size > params_shape.FlatSize()) {
      return kTfLiteError;
    }
    for (int j = 0; j < res.slice_size; ++j) {
      buffer.AddString(GetString(params_data, from_pos + j));
    }
  }
  buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
  return kTfLiteOk;
}
#endif

template <typename IndicesT, typename UpdatesT>
inline TfLiteStatus ScatterNd(const RuntimeShape& indices_shape,
                              const IndicesT* indices_data,
                              const RuntimeShape& updates_shape,
                              const UpdatesT* updates_data,
                              const RuntimeShape& output_shape,
                              UpdatesT* output_data) {
  ruy::profiler::ScopeLabel label("ScatterNd");

  int n_slices = 1;
  int slice_size = 1;
  const int outer_dims = indices_shape.DimensionsCount() - 1;
  const int indices_nd = indices_shape.Dims(outer_dims);
  const int updates_dims = updates_shape.DimensionsCount();
  for (int i = 0; i < outer_dims; ++i) {
    n_slices *= indices_shape.Dims(i);
  }
  for (int i = outer_dims; i < updates_dims; ++i) {
    slice_size *= updates_shape.Dims(i);
  }

  int output_flat_size = output_shape.FlatSize();
  int remain_flat_size = output_flat_size;
  std::vector<int> dims_to_count(indices_nd, 0);
  for (int i = 0; i < indices_nd; ++i) {
    dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
    remain_flat_size = dims_to_count[i];
  }

  if (n_slices * slice_size > updates_shape.FlatSize()) {
    return kTfLiteError;
  }
  memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
  for (int i = 0; i < n_slices; ++i) {
    int to_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      IndicesT idx = indices_data[i * indices_nd + j];
      to_pos += idx * dims_to_count[j];
    }
    if (to_pos < 0 || to_pos + slice_size > output_flat_size) {
      return kTfLiteError;
    }
    for (int j = 0; j < slice_size; j++) {
      output_data[to_pos + j] += updates_data[i * slice_size + j];
    }
  }
  return kTfLiteOk;
}
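
// Example (illustrative shapes): with indices_shape = [2, 1], indices {0, 2},
// updates_shape = [2, 4] and output_shape = [3, 4], there are n_slices = 2
// slices of slice_size = 4 and dims_to_count = {4}; the first updates row is
// accumulated into output row 0 and the second into output row 2, with every
// other output element left at zero.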

template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto min_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto max_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            std::greater<T1>());
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
                      const D* input_condition_data, T* output_data) {
  const size_t size = input_condition_shape.FlatSize();
  if (size == 0) {
    // The condition tensor is empty, so there is nothing to output.
    return;
  }
  const size_t cond_rank = input_condition_shape.DimensionsCount();

  std::vector<int> dims_to_count(cond_rank, 0);
  int cur_flat_size = size;
  for (int i = 0; i < cond_rank; ++i) {
    dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
    cur_flat_size = dims_to_count[i];
  }

  int output_index = 0;
  for (int i = 0; i < size; ++i) {
    if (input_condition_data[i] != D(0)) {
      // Insert the coordinate of the current item (row major) into output.
      int flat_index = i;
      for (int j = 0; j < cond_rank; ++j) {
        int coord_j = flat_index / dims_to_count[j];
        output_data[output_index * cond_rank + j] = coord_j;
        flat_index %= dims_to_count[j];
      }
      output_index++;
    }
  }
}
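
// Example (illustrative values): for a [2, 2] condition tensor holding
// {true, false, false, true}, dims_to_count = {2, 1} and the true elements
// sit at flat indices 0 and 3, so the output receives the row-major
// coordinates {0, 0} followed by {1, 1}.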

// For easy implementation, the indices are always given as a vector of size-4
// vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
                          const T* values, T default_value,
                          bool value_is_scalar,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);
  const int value_count = indices.size();

  // First, fill output_data with the default value.
  const int num_elements = output_shape.FlatSize();
  for (int i = 0; i < num_elements; ++i) {
    output_data[i] = default_value;
  }

  // Handle the scalar-value case separately to avoid checking the boolean
  // condition inside the loop on every iteration.
  if (value_is_scalar) {
    for (int i = 0; i < value_count; ++i) {
      const std::vector<TI>& index = indices[i];
      TFLITE_DCHECK_EQ(index.size(), 4);
      const T value = *values;  // just use the first value.
      output_data[Offset(output_shape, index[0], index[1], index[2],
                         index[3])] = value;
    }
    return;
  }

  // Go through the values and indices to fill in the sparse values.
  for (int i = 0; i < value_count; ++i) {
    const std::vector<TI>& index = indices[i];
    TFLITE_DCHECK_EQ(index.size(), 4);
    const T value = values[i];
    output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
        value;
  }
}
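
// Example (illustrative values): with unextended_output_shape = [2, 3]
// (extended to [1, 1, 2, 3]), indices = {{0, 0, 0, 2}, {0, 0, 1, 0}},
// values = {5, 7} and default_value = 0, the dense output is
// {0, 0, 5, 7, 0, 0} in row-major order.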

template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
                const RuntimeShape& input2_shape, const T* input2_data,
                const RuntimeShape& output_shape, T* output_data) {
  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = std::pow(input1_data[i], input2_data[i]);
  }
}

template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
                               const T* input1_data,
                               const RuntimeShape& unextended_input2_shape,
                               const T* input2_data,
                               const RuntimeShape& unextended_output_shape,
                               T* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = std::pow(in1_val, in2_val);
        }
      }
    }
  }
}

template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
             const Scalar* input_data, const RuntimeShape& output_shape,
             Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Reverse");

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_axis = input_shape.Dims(axis);
  for (int i = 0; i < outer_size; ++i) {
    for (int j = 0; j < dims_at_axis; ++j) {
      const int start_pos = (i * dims_at_axis + j) * copy_size;
      Scalar* output_ptr = output_data + start_pos;
      int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}
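
// Example (illustrative values): reversing axis = 1 of a [2, 3] input
// {1, 2, 3, 4, 5, 6} copies each row back-to-front, producing
// {3, 2, 1, 6, 5, 4}.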

template <typename Scalar, typename TS>
void ReverseSequence(const TS* seq_lengths, const int seq_dim,
                     const int batch_dim, const RuntimeShape& input_shape,
                     const Scalar* input_data, const RuntimeShape& output_shape,
                     Scalar* output_data) {
  ruy::profiler::ScopeLabel label("ReverseSequence");

  int outer_size = 1;
  int outer_dim = std::min(batch_dim, seq_dim);
  int medium_dim = std::max(batch_dim, seq_dim);
  for (int i = 0; i < outer_dim; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int medium_size = 1;
  for (int i = outer_dim + 1; i < medium_dim; ++i) {
    medium_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_outer_dim = input_shape.Dims(outer_dim);
  const int dims_at_medium_dim = input_shape.Dims(medium_dim);

  Scalar* output_ptr;
  if (batch_dim > seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            int sl = seq_lengths[q] - 1;
            if (j > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos_base =
                  (i * dims_at_outer_dim + sl - j) * medium_size;
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  } else if (batch_dim < seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        int sl = seq_lengths[j] - 1;
        const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            if (q > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
                  copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  }
}

template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
                       const RuntimeShape& segment_ids_shape,
                       const int32_t* segment_ids_data,
                       const RuntimeShape& output_shape, T* output_data) {
  const int segment_flat_size =
      MatchingFlatSizeSkipDim(input_shape, 0, output_shape);

  memset(output_data, 0, sizeof(T) * output_shape.FlatSize());

  for (int i = 0; i < input_shape.Dims(0); i++) {
    int output_index = segment_ids_data[i];
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] +=
          input_data[i * segment_flat_size + j];
    }
  }
}
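
// Example (illustrative values): for a [3, 2] input {1, 2, 3, 4, 5, 6} with
// segment_ids = {0, 0, 1} and a [2, 2] output, rows 0 and 1 are summed into
// output row 0 and row 2 is copied into output row 1, giving {4, 6, 5, 6}.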

template <typename T, template <typename T2> typename Op>
inline void UnsortedSegmentRef(const RuntimeShape& input_shape,
                               const T* input_data,
                               const RuntimeShape& segment_ids_shape,
                               const int32_t* segment_ids_data,
                               const RuntimeShape& output_shape,
                               T* output_data) {
  for (int i = 0; i < output_shape.FlatSize(); ++i) {
    output_data[i] = Op<T>::kInitialValue;
  }
  Op<T> op;
  int segment_flat_size = 1;
  for (int i = 1; i < output_shape.DimensionsCount(); ++i) {
    segment_flat_size *= output_shape.Dims(i);
  }
  for (int i = 0; i < segment_ids_shape.FlatSize(); i++) {
    int output_index = segment_ids_data[i];
    if (output_index < 0) continue;
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] =
          op(output_data[output_index * segment_flat_size + j],
             input_data[i * segment_flat_size + j]);
    }
  }
}
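
// Sketch of the Op interface expected above (illustrative only; the concrete
// functors live in the calling kernels): Op<T> must expose a static
// kInitialValue used to seed the output and a binary operator() used to fold
// in each input element, e.g. a hypothetical max-reduction functor:
//
//   template <typename T>
//   struct MaxOp {
//     static constexpr T kInitialValue = std::numeric_limits<T>::lowest();
//     T operator()(T a, T b) const { return a > b ? a : b; }
//   };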

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_