/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_

#include <stdint.h>
#include <sys/types.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <type_traits>

#include "Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/add.h"
#include "tensorflow/lite/kernels/internal/reference/add_n.h"
#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/cast.h"
#include "tensorflow/lite/kernels/internal/reference/ceil.h"
#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/div.h"
#include "tensorflow/lite/kernels/internal/reference/elu.h"
#include "tensorflow/lite/kernels/internal/reference/exp.h"
#include "tensorflow/lite/kernels/internal/reference/fill.h"
#include "tensorflow/lite/kernels/internal/reference/floor.h"
#include "tensorflow/lite/kernels/internal/reference/floor_div.h"
#include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/gather.h"
#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
#include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
#include "tensorflow/lite/kernels/internal/reference/logistic.h"
#include "tensorflow/lite/kernels/internal/reference/lstm_cell.h"
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/neg.h"
#include "tensorflow/lite/kernels/internal/reference/pad.h"
#include "tensorflow/lite/kernels/internal/reference/pooling.h"
#include "tensorflow/lite/kernels/internal/reference/prelu.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/quantize.h"
#include "tensorflow/lite/kernels/internal/reference/reduce.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
#include "tensorflow/lite/kernels/internal/reference/round.h"
#include "tensorflow/lite/kernels/internal/reference/select.h"
#include "tensorflow/lite/kernels/internal/reference/slice.h"
#include "tensorflow/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/reference/tanh.h"
#include "tensorflow/lite/kernels/internal/reference/transpose.h"
#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
namespace tflite {

namespace reference_ops {

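// Elementwise activation clamps. Relu clamps each element to [0, +inf),
// Relu0To1 to [0, 1], Relu1 to [-1, 1], and Relu6 to [0, 6]. For example,
// Relu maps {-2, 0, 3} to {0, 0, 3}.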
template <typename T>
inline void Relu(const RuntimeShape& input_shape, const T* input_data,
                 const RuntimeShape& output_shape, T* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T lower = 0;
    const T clamped = val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

template <typename T>
inline void Relu0To1(const RuntimeShape& input_shape, const T* input_data,
                     const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu0To1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = 0;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

template <typename T>
inline void Relu1(const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Relu1 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T upper = 1;
    const T lower = -1;
    const T clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
                  const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("Relu6 (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const float val = input_data[i];
    const float upper = 6;
    const float lower = 0;
    const float clamped = val > upper ? upper : val < lower ? lower : val;
    output_data[i] = clamped;
  }
}

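// Quantized ReluX: recenters each element by input_offset, rescales it with
// the quantized output_multiplier/output_shift, re-adds output_offset, and
// clamps to [quantized_activation_min, quantized_activation_max].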
template <typename T>
inline void ReluX(const tflite::ReluParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    const int32 val = static_cast<int32_t>(input_data[i]);
    int32 clamped = params.output_offset +
                    MultiplyByQuantizedMultiplier(val - params.input_offset,
                                                  params.output_multiplier,
                                                  params.output_shift);
    clamped = std::max(params.quantized_activation_min, clamped);
    clamped = std::min(params.quantized_activation_max, clamped);
    output_data[i] = static_cast<T>(clamped);
  }
}

template <typename T>
inline void ReluX(const tflite::ActivationParams& params,
                  const RuntimeShape& input_shape, const T* input_data,
                  const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  const T max_value = params.quantized_activation_max;
  const T min_value = params.quantized_activation_min;
  for (int i = 0; i < flat_size; ++i) {
    const T val = input_data[i];
    const T clamped = val > max_value   ? max_value
                      : val < min_value ? min_value
                                        : val;
    output_data[i] = clamped;
  }
}

// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
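// BroadcastMulFivefold multiplies two uint8 tensors whose broadcast pattern
// has been collapsed into the five entries of params.broadcast_shape (see
// ProcessBroadcastShapes in process_broadcast_shapes.h). When the second
// input is the fast-broadcasting one, the inputs and their offsets are
// swapped so the same loop structure can be reused.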
inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& unswitched_input1_shape,
                                 const uint8* unswitched_input1_data,
                                 const RuntimeShape& unswitched_input2_shape,
                                 const uint8* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 uint8* output_data) {
  ArithmeticParams switched_params = unswitched_params;
  switched_params.input1_offset = unswitched_params.input2_offset;
  switched_params.input2_offset = unswitched_params.input1_offset;

  const bool use_unswitched =
      unswitched_params.broadcast_category ==
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;

  const ArithmeticParams& params =
      use_unswitched ? unswitched_params : switched_params;
  const uint8* input1_data =
      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
  const uint8* input2_data =
      use_unswitched ? unswitched_input2_data : unswitched_input1_data;

  // Fivefold nested loops. The second input resets its position for each
  // iteration of the second loop. The first input resets its position at the
  // beginning of the fourth loop. The innermost loop is an elementwise Mul of
  // sections of the arrays.
  uint8* output_data_ptr = output_data;
  const uint8* input1_data_ptr = input1_data;
  const uint8* input2_data_reset = input2_data;
  int y0 = params.broadcast_shape[0];
  int y1 = params.broadcast_shape[1];
  int y2 = params.broadcast_shape[2];
  int y3 = params.broadcast_shape[3];
  int y4 = params.broadcast_shape[4];
  for (int i0 = 0; i0 < y0; ++i0) {
    const uint8* input2_data_ptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          MulElementwise(y4, params, input1_data_ptr, input2_data_ptr,
                         output_data_ptr);
          input2_data_ptr += y4;
          output_data_ptr += y4;
        }
        input1_data_ptr += y4;
      }
    }
    input2_data_reset = input2_data_ptr;
  }
}

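// Multiplies two int16 tensors interpreted as Q0.15 fixed-point values in
// [-1, 1). For example, raw inputs of 16384 (0.5) times 16384 (0.5) yield a
// raw output of roughly 8192 (0.25).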
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16");

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    output_data[i] = unclamped_result.raw();
  }
}

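// Same Q0.15 multiply as above, but the result is requantized to uint8: the
// raw product is divided by 2^8 with rounding, clamped to the quantized
// activation range, and shifted by output_offset.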
inline void Mul(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
  int32 output_offset = params.output_offset;
  int32 output_activation_min = params.quantized_activation_min;
  int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;

    F0 unclamped_result =
        F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
    int16 rescaled_result =
        gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
    int16 clamped_result =
        std::min<int16>(output_activation_max - output_offset, rescaled_result);
    clamped_result =
        std::max<int16>(output_activation_min - output_offset, clamped_result);
    output_data[i] = output_offset + clamped_result;
  }
}

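// Computes input1 - input2 on Q0.15 data. At least one of the two inputs must
// have a zero shift; the other is right-shifted by its (non-positive) shift
// into the same scale before the saturating subtraction, and the result is
// clamped to the quantized activation range.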
inline void Sub16(const ArithmeticParams& params,
                  const RuntimeShape& input1_shape, const int16_t* input1_data,
                  const RuntimeShape& input2_shape, const int16_t* input2_data,
                  const RuntimeShape& output_shape, int16_t* output_data) {
  ruy::profiler::ScopeLabel label("Sub/Int16");
  const int input1_shift = params.input1_shift;
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  const int16 output_activation_min = params.quantized_activation_min;
  const int16 output_activation_max = params.quantized_activation_max;

  TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
  TFLITE_DCHECK_LE(input1_shift, 0);
  TFLITE_DCHECK_LE(params.input2_shift, 0);
  const int16* not_shift_input = input1_shift == 0 ? input1_data : input2_data;
  const int16* shift_input = input1_shift == 0 ? input2_data : input1_data;
  const int input_right_shift =
      input1_shift == 0 ? -params.input2_shift : -input1_shift;

  if (input1_shift == 0) {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(input_ready_scaled, scaled_input);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  } else {
    // F0 uses 0 integer bits, range [-1, 1].
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int i = 0; i < flat_size; ++i) {
      F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
      F0 scaled_input = F0::FromRaw(
          gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
      F0 result = SaturatingSub(scaled_input, input_ready_scaled);
      const int16 raw_output = result.raw();
      const int16 clamped_output = std::min(
          output_activation_max, std::max(output_activation_min, raw_output));
      output_data[i] = clamped_output;
    }
  }
}

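// Pack stacks `inputs_count` tensors of identical shape along a new `axis`.
// The output is laid out as [outer_size, inputs_count, copy_size], so slice k
// of input i lands at offset (k * inputs_count + i) * copy_size.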
template <typename Scalar>
void Pack(const PackParams& params, const RuntimeShape* const* input_shapes,
          const Scalar* const* input_data, const RuntimeShape& output_shape,
          Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Pack");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  int inputs_count = params.inputs_count;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = params.axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  for (int i = 0; i < inputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      const Scalar* input_ptr = input_data[i] + copy_size * k;
      int loc = k * inputs_count * copy_size + i * copy_size;
      memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
    }
  }
}

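// Unpack is the inverse of Pack: it slices the input along `axis` into
// params.num_split outputs, each dropping that axis.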
template <typename Scalar>
void Unpack(const UnpackParams& params, const RuntimeShape& input_shape,
            const Scalar* input_data, const RuntimeShape& output_shape,
            Scalar* const* output_datas) {
  ruy::profiler::ScopeLabel label("Unpack");
  const int dimensions = input_shape.DimensionsCount();
  const int outputs_count = params.num_split;

  int outer_size = 1;
  int axis = params.axis;
  if (axis < 0) {
    axis += dimensions;
  }
  TFLITE_DCHECK_GE(axis, 0);
  TFLITE_DCHECK_LT(axis, dimensions);
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; ++i) {
    copy_size *= input_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);

  for (int i = 0; i < outputs_count; ++i) {
    for (int k = 0; k < outer_size; k++) {
      Scalar* output_ptr = output_datas[i] + copy_size * k;
      int loc = k * outputs_count * copy_size + i * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

template <typename Scalar>
void PackWithScaling(const PackParams& params,
                     const RuntimeShape* const* input_shapes,
                     const uint8* const* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  ruy::profiler::ScopeLabel label("PackWithScaling");
  const int dimensions = output_shape.DimensionsCount();
  int axis = params.axis;
  const int32* input_zeropoint = params.input_zeropoint;
  const float* input_scale = params.input_scale;
  int inputs_count = params.inputs_count;
  const int32 output_zeropoint = params.output_zeropoint;
  const float output_scale = params.output_scale;

  int outer_size = 1;
  for (int i = 0; i < axis; i++) {
    outer_size *= output_shape.Dims(i);
  }
  int copy_size = 1;
  for (int i = axis + 1; i < dimensions; i++) {
    copy_size *= output_shape.Dims(i);
  }
  TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);

  Scalar* output_ptr = output_data;
  const float inverse_output_scale = 1.f / output_scale;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < inputs_count; ++i) {
      if (input_zeropoint[i] == output_zeropoint &&
          input_scale[i] == output_scale) {
        memcpy(output_ptr, input_data[i] + k * copy_size,
               copy_size * sizeof(Scalar));
      } else {
        assert(false);
        const float scale = input_scale[i] * inverse_output_scale;
        const float bias = -input_zeropoint[i] * scale;
        auto input_ptr = input_data[i];
        for (int j = 0; j < copy_size; ++j) {
          const int32_t value =
              static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) +
              output_zeropoint;
          output_ptr[j] =
              static_cast<uint8_t>(std::max(std::min(255, value), 0));
        }
      }
      output_ptr += copy_size;
    }
  }
}

template <typename Scalar>
void DepthConcatenation(const ConcatenationParams& params,
                        const RuntimeShape* const* input_shapes,
                        const Scalar* const* input_data,
                        const RuntimeShape& output_shape, Scalar* output_data) {
  ruy::profiler::ScopeLabel label("DepthConcatenation");
  auto params_copy = params;
  params_copy.axis = 3;
  Concatenation(params_copy, input_shapes, input_data, output_shape,
                output_data);
}

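// Split partitions the input along `axis` into params.num_split outputs; the
// output sizes along that axis must sum to the input's size along it, and all
// other dimensions must match.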
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
           const Scalar* input_data, const RuntimeShape* const* output_shapes,
           Scalar* const* output_data) {
  ruy::profiler::ScopeLabel label("Split");
  const int split_dimensions = input_shape.DimensionsCount();
  int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
  int outputs_count = params.num_split;
  TFLITE_DCHECK_LT(axis, split_dimensions);

  int64_t split_size = 0;
  for (int i = 0; i < outputs_count; i++) {
    TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
    for (int j = 0; j < split_dimensions; j++) {
      if (j != axis) {
        MatchingDim(*output_shapes[i], j, input_shape, j);
      }
    }
    split_size += output_shapes[i]->Dims(axis);
  }
  TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }
  // For all output arrays,
  // FlatSize() = outer_size * Dims(axis) * base_inner_size;
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_shape.Dims(i);
  }

  const Scalar* input_ptr = input_data;
  for (int k = 0; k < outer_size; k++) {
    for (int i = 0; i < outputs_count; ++i) {
      const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
      memcpy(output_data[i] + k * copy_size, input_ptr,
             copy_size * sizeof(Scalar));
      input_ptr += copy_size;
    }
  }
}

inline int NodeOffset(int b, int h, int w, int height, int width) {
  return (b * height + h) * width + w;
}

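// Local response normalization across the depth (last) dimension:
//   output[c] = input[c] * (bias + alpha * sum(input[c']^2))^(-beta)
// where c' runs over the window [c - range, c + range) clamped to [0, depth).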
inline void LocalResponseNormalization(
    const tflite::LocalResponseNormalizationParams& op_params,
    const RuntimeShape& input_shape, const float* input_data,
    const RuntimeShape& output_shape, float* output_data) {
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size =
      MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth =
      MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i) {
    for (int c = 0; c < depth; ++c) {
      const int begin_input_c = std::max(0, c - op_params.range);
      const int end_input_c = std::min(depth, c + op_params.range);
      float accum = 0.f;
      for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
        const float input_val = input_data[i * depth + input_c];
        accum += input_val * input_val;
      }
      const float multiplier =
          std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
      output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
    }
  }
}

inline void Dequantize(const RuntimeShape& input_shape,
                       const Eigen::half* input_data,
                       const RuntimeShape& output_shape, float* output_data) {
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = static_cast<float>(input_data[i]);
  }
}

inline void FakeQuant(const tflite::FakeQuantParams& op_params,
                      const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
  ruy::profiler::ScopeLabel label("FakeQuant");
  float rmin = op_params.minmax.min;
  float rmax = op_params.minmax.max;
  int num_bits = op_params.num_bits;
  // 0 should always be a representable value. Let's assume that the initial
  // min,max range contains 0.
  TFLITE_DCHECK_LE(rmin, 0.0f);
  TFLITE_DCHECK_GE(rmax, 0.0f);
  TFLITE_DCHECK_LT(rmin, rmax);

  // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
  int quant_min = 0;
  int quant_max = (1 << num_bits) - 1;
  float nudged_min, nudged_max, nudged_scale;
  NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
                         &nudged_max, &nudged_scale);
  const int flat_size = MatchingFlatSize(input_shape, output_shape);
  FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
                    output_data, flat_size);
}

// Common subroutine for both `GatherNd` and `GatherNdString`.
struct GatherNdHelperResult {
  int n_slices;
  int slice_size;
  int indices_nd;
  std::vector<int> dims_to_count;
};

// Returns the common values used by both `GatherNd` and `GatherNdString`.
inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
                                           const RuntimeShape& indices_shape) {
  GatherNdHelperResult ret;
  ret.n_slices = 1;
  ret.slice_size = 1;
  const int indices_dims = indices_shape.DimensionsCount();
  ret.indices_nd = indices_shape.Dims(indices_dims - 1);
  const int params_dims = params_shape.DimensionsCount();
  for (int i = 0; i < indices_dims - 1; ++i) {
    ret.n_slices *= indices_shape.Dims(i);
  }
  if (ret.n_slices == 0) return ret;

  for (int i = ret.indices_nd; i < params_dims; ++i) {
    ret.slice_size *= params_shape.Dims(i);
  }

  int remain_flat_size = params_shape.FlatSize();
  ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
  for (int i = 0; i < ret.indices_nd; ++i) {
    ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
    remain_flat_size = ret.dims_to_count[i];
  }

  return ret;
}

// Implements GatherNd.
// Returns an error if any of the indices_data would cause an out of bounds
// memory read.
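// For example, with params of shape {2, 3} and indices {{0, 1}, {1, 2}}
// (indices_nd == 2, slice_size == 1), the output is
// {params[0][1], params[1][2]}.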
template <typename ParamsT, typename IndicesT = int32>
inline TfLiteStatus GatherNd(const RuntimeShape& params_shape,
                             const ParamsT* params_data,
                             const RuntimeShape& indices_shape,
                             const IndicesT* indices_data,
                             const RuntimeShape& output_shape,
                             ParamsT* output_data) {
  ruy::profiler::ScopeLabel label("GatherNd");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  for (int i = 0; i < res.n_slices; ++i) {
    int64_t from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    if (from_pos < 0 || from_pos + res.slice_size > params_shape.FlatSize()) {
      return kTfLiteError;
    }
    std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
                sizeof(ParamsT) * res.slice_size);
  }
  return kTfLiteOk;
}

#ifndef TF_LITE_STATIC_MEMORY
// Implements GatherNd on strings.
// Returns an error if any of the indices_data would cause an out of bounds
// memory read.
template <typename IndicesT = int32>
inline TfLiteStatus GatherNdString(const RuntimeShape& params_shape,
                                   const TfLiteTensor* params_data,
                                   const RuntimeShape& indices_shape,
                                   const IndicesT* indices_data,
                                   const RuntimeShape& output_shape,
                                   TfLiteTensor* output_data) {
  ruy::profiler::ScopeLabel label("GatherNdString");

  const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
  DynamicBuffer buffer;
  for (int i = 0; i < res.n_slices; ++i) {
    int64_t from_pos = 0;
    for (int j = 0; j < res.indices_nd; ++j) {
      from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
    }
    if (from_pos < 0 || from_pos + res.slice_size > params_shape.FlatSize()) {
      return kTfLiteError;
    }
    for (int j = 0; j < res.slice_size; ++j) {
      buffer.AddString(GetString(params_data, from_pos + j));
    }
  }
  buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
  return kTfLiteOk;
}
#endif

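// Implements ScatterNd. The output is zero-initialized and each updates slice
// is accumulated (+=) at the position named by the corresponding index, so
// duplicate indices sum their updates. Returns an error if an index would
// cause an out of bounds write or if updates has too few elements.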
template <typename IndicesT, typename UpdatesT>
inline TfLiteStatus ScatterNd(const RuntimeShape& indices_shape,
                              const IndicesT* indices_data,
                              const RuntimeShape& updates_shape,
                              const UpdatesT* updates_data,
                              const RuntimeShape& output_shape,
                              UpdatesT* output_data) {
  ruy::profiler::ScopeLabel label("ScatterNd");

  int n_slices = 1;
  int slice_size = 1;
  const int outer_dims = indices_shape.DimensionsCount() - 1;
  const int indices_nd = indices_shape.Dims(outer_dims);
  const int updates_dims = updates_shape.DimensionsCount();
  for (int i = 0; i < outer_dims; ++i) {
    n_slices *= indices_shape.Dims(i);
  }
  for (int i = outer_dims; i < updates_dims; ++i) {
    slice_size *= updates_shape.Dims(i);
  }

  int output_flat_size = output_shape.FlatSize();
  int remain_flat_size = output_flat_size;
  std::vector<int> dims_to_count(indices_nd, 0);
  for (int i = 0; i < indices_nd; ++i) {
    dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
    remain_flat_size = dims_to_count[i];
  }

  if (n_slices * slice_size > updates_shape.FlatSize()) {
    return kTfLiteError;
  }
  memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
  for (int i = 0; i < n_slices; ++i) {
    int to_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      IndicesT idx = indices_data[i * indices_nd + j];
      to_pos += idx * dims_to_count[j];
    }
    if (to_pos < 0 || to_pos + slice_size > output_flat_size) {
      return kTfLiteError;
    }
    for (int j = 0; j < slice_size; j++) {
      output_data[to_pos + j] += updates_data[i * slice_size + j];
    }
  }
  return kTfLiteOk;
}

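// Minimum and Maximum below treat the second input as a scalar: only
// input2_data[0] is read, and each element of the first input is compared
// against it.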
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto min_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape& output_shape,
             T* output_data) {
  const int flat_size = MatchingFlatSize(input1_shape, output_shape);

  auto max_value = input2_data[0];
  for (int i = 0; i < flat_size; i++) {
    output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
  }
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T>
inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape&, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  // Drop shape of second input: not needed.
  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

template <typename T1, typename T2, typename T3>
void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
            const T3* input2_data, const RuntimeShape& output_shape,
            T2* output_data) {
  ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
            std::greater<T1>());
}

// Convenience version that allows, for example, generated-code calls to be
// the same as other binary ops.
template <typename T1, typename T2, typename T3>
inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
                   const RuntimeShape& input2_shape, const T3* input2_data,
                   const RuntimeShape& output_shape, T2* output_data) {
  // Drop shape of second input: not needed.
  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
}

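// Writes the row-major coordinates of every non-zero element of the condition
// tensor into output_data, one cond_rank-sized coordinate tuple per match
// (this mirrors the single-argument form of tf.where).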
template <typename D, typename T>
void SelectTrueCoords(const RuntimeShape& input_condition_shape,
                      const D* input_condition_data, T* output_data) {
  const size_t size = input_condition_shape.FlatSize();
  if (size == 0) {
    // Dimension is zero, in which case we don't need to output.
    return;
  }
  const size_t cond_rank = input_condition_shape.DimensionsCount();

  std::vector<int> dims_to_count(cond_rank, 0);
  int cur_flat_size = size;
  for (int i = 0; i < cond_rank; ++i) {
    dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
    cur_flat_size = dims_to_count[i];
  }

  int output_index = 0;
  for (int i = 0; i < size; ++i) {
    if (input_condition_data[i] != D(0)) {
      // Insert the coordinate of the current item (row major) into output.
      int flat_index = i;
      for (int j = 0; j < cond_rank; ++j) {
        int coord_j = flat_index / dims_to_count[j];
        output_data[output_index * cond_rank + j] = coord_j;
        flat_index %= dims_to_count[j];
      }
      output_index++;
    }
  }
}

// For ease of implementation, `indices` is always a vector of size-4 vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
                          const T* values, T default_value,
                          bool value_is_scalar,
                          const RuntimeShape& unextended_output_shape,
                          T* output_data) {
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);
  const int value_count = indices.size();

  // First fill the output_data with the default value.
  const int num_elements = output_shape.FlatSize();
  for (int i = 0; i < num_elements; ++i) {
    output_data[i] = default_value;
  }

  // Special-case a scalar value to avoid checking the boolean condition
  // within the loop every time.
  if (value_is_scalar) {
    for (int i = 0; i < value_count; ++i) {
      const std::vector<TI>& index = indices[i];
      TFLITE_DCHECK_EQ(index.size(), 4);
      const T value = *values;  // just use the first value.
      output_data[Offset(output_shape, index[0], index[1], index[2],
                         index[3])] = value;
    }
    return;
  }

  // Go through the values and indices to fill the sparse values.
  for (int i = 0; i < value_count; ++i) {
    const std::vector<TI>& index = indices[i];
    TFLITE_DCHECK_EQ(index.size(), 4);
    const T value = values[i];
    output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
        value;
  }
}

template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
                const RuntimeShape& input2_shape, const T* input2_data,
                const RuntimeShape& output_shape, T* output_data) {
  const int flat_size =
      MatchingFlatSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = std::pow(input1_data[i], input2_data[i]);
  }
}

template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
                               const T* input1_data,
                               const RuntimeShape& unextended_input2_shape,
                               const T* input2_data,
                               const RuntimeShape& unextended_output_shape,
                               T* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = std::pow(in1_val, in2_val);
        }
      }
    }
  }
}

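// Reverses the input along a single `axis`, copying inner slices of
// `copy_size` elements from back to front within each outer block.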
template <typename Scalar>
void Reverse(int axis, const RuntimeShape& input_shape,
             const Scalar* input_data, const RuntimeShape& output_shape,
             Scalar* output_data) {
  ruy::profiler::ScopeLabel label("Reverse");

  int outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_axis = input_shape.Dims(axis);
  for (int i = 0; i < outer_size; ++i) {
    for (int j = 0; j < dims_at_axis; ++j) {
      const int start_pos = (i * dims_at_axis + j) * copy_size;
      Scalar* output_ptr = output_data + start_pos;
      int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
      memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
    }
  }
}

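// ReverseSequence reverses the input along seq_dim, but for each index b
// along batch_dim only the first seq_lengths[b] elements are reversed; the
// remainder is copied through unchanged (TensorFlow reverse_sequence
// semantics).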
template <typename Scalar, typename TS>
void ReverseSequence(const TS* seq_lengths, const int seq_dim,
                     const int batch_dim, const RuntimeShape& input_shape,
                     const Scalar* input_data, const RuntimeShape& output_shape,
                     Scalar* output_data) {
  ruy::profiler::ScopeLabel label("ReverseSequence");

  int outer_size = 1;
  int outer_dim = std::min(batch_dim, seq_dim);
  int medium_dim = std::max(batch_dim, seq_dim);
  for (int i = 0; i < outer_dim; ++i) {
    outer_size *= input_shape.Dims(i);
  }

  int medium_size = 1;
  for (int i = outer_dim + 1; i < medium_dim; ++i) {
    medium_size *= input_shape.Dims(i);
  }

  int copy_size = 1;
  for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i) {
    copy_size *= input_shape.Dims(i);
  }

  const int dims_at_outer_dim = input_shape.Dims(outer_dim);
  const int dims_at_medium_dim = input_shape.Dims(medium_dim);

  Scalar* output_ptr;
  if (batch_dim > seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            int sl = seq_lengths[q] - 1;
            if (j > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos_base =
                  (i * dims_at_outer_dim + sl - j) * medium_size;
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  } else if (batch_dim < seq_dim) {
    for (int i = 0; i < outer_size; ++i) {
      for (int j = 0; j < dims_at_outer_dim; ++j) {
        const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        int sl = seq_lengths[j] - 1;
        const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
        for (int p = 0; p < medium_size; ++p) {
          for (int q = 0; q < dims_at_medium_dim; ++q) {
            const int in_pos =
                ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
            const Scalar* in_ptr = input_data + in_pos;
            if (q > sl) {
              output_ptr = output_data + in_pos;
            } else {
              const int out_pos =
                  ((out_pos_base + p) * dims_at_medium_dim + sl - q) *
                  copy_size;
              output_ptr = output_data + out_pos;
            }
            memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
          }
        }
      }
    }
  }
}

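// SegmentSum accumulates slice i of the input (along dimension 0) into output
// slice segment_ids_data[i]; the output must have at least max(segment_id) + 1
// rows and is zeroed first.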
template <typename T>
inline void SegmentSum(const RuntimeShape& input_shape, const T* input_data,
                       const RuntimeShape& segment_ids_shape,
                       const int32_t* segment_ids_data,
                       const RuntimeShape& output_shape, T* output_data) {
  const int segment_flat_size =
      MatchingFlatSizeSkipDim(input_shape, 0, output_shape);

  memset(output_data, 0, sizeof(T) * output_shape.FlatSize());

  for (int i = 0; i < input_shape.Dims(0); i++) {
    int output_index = segment_ids_data[i];
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] +=
          input_data[i * segment_flat_size + j];
    }
  }
}

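// Generic unsorted segment reduction. Op<T> supplies the identity value
// (kInitialValue) used to initialize the output and the binary combine
// applied per element; input slices with a negative segment id are skipped.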
template <typename T, template <typename T2> typename Op>
inline void UnsortedSegmentRef(const RuntimeShape& input_shape,
                               const T* input_data,
                               const RuntimeShape& segment_ids_shape,
                               const int32_t* segment_ids_data,
                               const RuntimeShape& output_shape,
                               T* output_data) {
  for (int i = 0; i < output_shape.FlatSize(); ++i) {
    output_data[i] = Op<T>::kInitialValue;
  }
  Op<T> op;
  int segment_flat_size = 1;
  for (int i = 1; i < output_shape.DimensionsCount(); ++i) {
    segment_flat_size *= output_shape.Dims(i);
  }
  for (int i = 0; i < segment_ids_shape.FlatSize(); i++) {
    int output_index = segment_ids_data[i];
    if (output_index < 0) continue;
    for (int j = 0; j < segment_flat_size; ++j) {
      output_data[output_index * segment_flat_size + j] =
          op(output_data[output_index * segment_flat_size + j],
             input_data[i * segment_flat_size + j]);
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_