1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
17
#include <cstddef>
#include <cstring>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
20
21 namespace tflite {
22 namespace reference_ops {
23 template <int N>
BroadcastImpl(const NdArrayDesc<N> & input_desc,const char * input_data,const NdArrayDesc<N> & output_desc,char * output_data,int indexes[N],int dim,const int last_broadcasting_dim,const int type_size)24 void BroadcastImpl(const NdArrayDesc<N>& input_desc, const char* input_data,
25 const NdArrayDesc<N>& output_desc, char* output_data,
26 int indexes[N], int dim, const int last_broadcasting_dim,
27 const int type_size) {
28 // Copy data from input to output.
29 if (dim == last_broadcasting_dim) {
30 int copy_size = output_desc.strides[dim] * type_size;
31 const char* data_src =
32 input_data + SubscriptToIndex(input_desc, indexes) * type_size;
33 char* data_dst =
34 output_data + SubscriptToIndex(output_desc, indexes) * type_size;
35 for (int i = 0; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
36 memcpy(data_dst, data_src, copy_size);
37 }
38 return;
39 }
40
41 // Recursive call to find the next broadcasting.
42 for (indexes[dim] = 0; indexes[dim] < input_desc.extents[dim];
43 ++indexes[dim]) {
44 BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes,
45 dim + 1, last_broadcasting_dim, type_size);
46 }
47
48 // Duplicate data in output tensor.
49 indexes[dim] = 0;
50 if (input_desc.extents[dim] != output_desc.extents[dim]) {
51 int copy_size = output_desc.strides[dim] * type_size;
52 char* data_src =
53 output_data + SubscriptToIndex(output_desc, indexes) * type_size;
54 char* data_dst = data_src + copy_size;
55 for (int i = 1; i < output_desc.extents[dim]; ++i, data_dst += copy_size) {
56 memcpy(data_dst, data_src, copy_size);
57 }
58 }
59 }
60
61 template <int N>
BroadcastTo(const RuntimeShape & unextended_input_shape,const char * input_data,const RuntimeShape & unextended_output_shape,char * output_data,TfLiteType data_type)62 inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
63 const char* input_data,
64 const RuntimeShape& unextended_output_shape,
65 char* output_data, TfLiteType data_type) {
66 NdArrayDesc<N> input_desc;
67 NdArrayDesc<N> output_desc;
68 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_input_shape),
69 &input_desc);
70 CopyDimsToDesc(RuntimeShape::ExtendedShape(N, unextended_output_shape),
71 &output_desc);
72
73 // Get the last dimension has broadcasting. At this dimension, the data is
74 // copied from input tensor to output tensor.
75 int last_broadcast_dim = -1;
76 for (int i = N - 1; i >= 0; --i) {
77 if (input_desc.extents[i] != output_desc.extents[i]) {
78 last_broadcast_dim = i;
79 break;
80 }
81 }
82
83 // If non-broadcasting, just copy data from input to output tensor.
84 if (last_broadcast_dim == -1) {
85 memcpy(output_data, input_data,
86 unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
87 return;
88 }
89
90 // Broadcasting using memcpy.
91 int indexes[N] = {0};
92 BroadcastImpl<N>(input_desc, input_data, output_desc, output_data, indexes, 0,
93 last_broadcast_dim, TfLiteTypeGetSize(data_type));
94 }
95 } // namespace reference_ops
96 } // namespace tflite
97 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
98