/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_

#include <algorithm>
#include <cstdint>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {
namespace batch_matmul {

// Determine which dimension is the broadcast dimension.
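// For example, broadcast_dim(1, 4) and broadcast_dim(4, 1) both return 4,
// and broadcast_dim(4, 4) returns 4; mismatched non-unit dimensions trip the
// DCHECK below.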
inline int broadcast_dim(int lhs_dim, int rhs_dim) {
  if (lhs_dim == rhs_dim) return lhs_dim;
  if (lhs_dim == 1) return rhs_dim;
  TFLITE_DCHECK_EQ(rhs_dim, 1);
  return lhs_dim;
}

// Compute the "extent" for iterating on this dimension.
// If we are broadcasting, then don't advance (i.e. return 0).
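// For example, for a shape of {2, 1, 3, 4, 5}: extent(shape, 1) == 0 because
// dimension 1 is broadcast, while extent(shape, 2) == 4 * 5 == 20, the number
// of elements to step over to advance one slice along dimension 2.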
inline int extent(const RuntimeShape& shape, int x) {
  if (shape.Dims(x) == 1) {
    return 0;
  }
  int prod = 1;
  for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
    prod *= shape.Dims(i);
  }
  return prod;
}

}  // namespace batch_matmul

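// Reference (unquantized) batched matrix multiply. Both shapes are padded to
// five dimensions: the first three are broadcast batch dimensions and the
// last two describe the per-batch matrices. As the indexing in the inner
// loops shows, each LHS block is read as [lhs_rows, accum_depth], each RHS
// block as [rhs_cols, accum_depth] (accumulation dimension contiguous), and
// each output block is written as [rhs_cols, lhs_rows]. Accumulation happens
// in Tout.
//
// A minimal call sketch with hypothetical, single-batch shapes (a 2x3 LHS
// against a 4-column RHS with accumulation depth 3); the buffers follow the
// layouts described above:
//
//   float lhs[2 * 3] = {...};  // 2 rows x 3 accumulation steps
//   float rhs[4 * 3] = {...};  // 4 output columns x 3 accumulation steps
//   float out[4 * 2];
//   BatchMatMul(RuntimeShape({1, 1, 1, 2, 3}), lhs,
//               RuntimeShape({1, 1, 1, 3, 4}), rhs,
//               RuntimeShape({1, 1, 1, 2, 4}), out);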
template <typename Ta, typename Tb, typename Tout>
inline void BatchMatMul(const RuntimeShape& lhs_shape, const Ta* lhs_data,
                        const RuntimeShape& rhs_shape, const Tb* rhs_data,
                        const RuntimeShape& output_shape, Tout* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const Ta* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const Tb* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const Ta* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const Tb* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const Ta* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const Tb* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        Tout* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                       b1 * batch_dim2 + b2) *
                                          lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            Tout total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total += static_cast<Tout>(lhs_ptr2[accum_depth * i + k]) *
                       static_cast<Tout>(rhs_ptr2[j * accum_depth + k]);
            }
            int idx = lhs_rows * j + i;
            out_ptr[idx] = total;
          }
        }
      }
    }
  }
}

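// Hybrid quantized batched matrix multiply: an int8 LHS (treated as the
// weights; see the row-sum caching below) against an int8 RHS, accumulated
// into a float output. `scaling_factors` and `input_offset` carry one entry
// per RHS column, and `row_sums` caches the per-row sums of the LHS so the
// RHS zero point can be corrected without re-reading the LHS. When
// `compute_row_sums` is null or points to true, the row sums are
// (re)computed here and the flag is then cleared.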
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape, float* output_data,
                        bool* compute_row_sums) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

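  // (Re)compute the cached LHS row sums only when requested (or when no flag
  // is provided); callers can reuse the cached sums across invocations.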
  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          for (int i = 0; i < lhs_rows; ++i) {
            int32_t total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total +=
                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
            }
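            // Subtract the RHS zero point's contribution to the dot product
            // using the cached LHS row sum, then scale the result into float.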
            int32_t row_sum = woff_ptr2[i];
            total -= row_sum * batch_offset;
            int idx = lhs_rows * j + i;
            out_ptr[idx] += batch_scaling_factor * total;
          }
        }
      }
    }
  }
}

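// Fully quantized batched matrix multiply with an integer accumulator type
// AccumT. Operand zero points come from `params`: the LHS values are offset
// by `weights_offset` and the RHS values by `input_offset`. The accumulator
// is then requantized using the output multiplier/shift, output offset, and
// activation clamp from `params`.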
template <typename T, typename AccumT>
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const T* lhs_data,
                        const RuntimeShape& rhs_shape, const T* rhs_data,
                        const RuntimeShape& output_shape, T* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32_t input_offset = params.input_offset;
  const int32_t filter_offset = params.weights_offset;
  const int32_t output_offset = params.output_offset;
  const int32_t output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32_t output_activation_min = params.quantized_activation_min;
  const int32_t output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        T* out_ptr = output_data +
                     ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
                         lhs_rows * rhs_cols;

        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            AccumT total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              AccumT lhs_val = lhs_ptr2[accum_depth * i + k];
              AccumT rhs_val = rhs_ptr2[accum_depth * j + k];
              total += (lhs_val + filter_offset) * (rhs_val + input_offset);
            }
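            // Requantize the accumulator to the output scale, add the output
            // zero point, and clamp to the activation range.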
            int32_t total_scaled = MultiplyByQuantizedMultiplier(
                total, output_multiplier, output_shift);
            total_scaled += output_offset;
            total_scaled = std::max(total_scaled, output_activation_min);
            total_scaled = std::min(total_scaled, output_activation_max);
            const int idx = lhs_rows * j + i;
            out_ptr[idx] = static_cast<T>(total_scaled);
          }
        }
      }
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_