/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_

#include <algorithm>
#include <cstdint>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace reference_ops {
namespace batch_matmul {

// Determine which dimension is the broadcast dimension.
inline int broadcast_dim(int lhs_dim, int rhs_dim) {
  if (lhs_dim == rhs_dim) return lhs_dim;
  if (lhs_dim == 1) return rhs_dim;
  TFLITE_DCHECK_EQ(rhs_dim, 1);
  return lhs_dim;
}

// Compute the "extent" for iterating on this dimension.
// If we are broadcasting, then don't advance (i.e., return 0).
inline int extent(const RuntimeShape& shape, int x) {
  if (shape.Dims(x) == 1) {
    return 0;
  }
  int prod = 1;
  for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
    prod *= shape.Dims(i);
  }
  return prod;
}

}  // namespace batch_matmul

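// Reference batch matrix multiplication for generic element types. Both
// operands are first extended to five dimensions: three leading batch
// dimensions, which may broadcast against each other, followed by the matrix
// dimensions. Within each batch the LHS is read row-major as
// [lhs_rows, accum_depth], the RHS is read with its accumulation dimension
// contiguous (rhs[j * accum_depth + k]), and each result element is written
// to out[lhs_rows * j + i].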
template <typename Ta, typename Tb, typename Tout>
inline void BatchMatMul(const RuntimeShape& lhs_shape, const Ta* lhs_data,
                        const RuntimeShape& rhs_shape, const Tb* rhs_data,
                        const RuntimeShape& output_shape, Tout* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

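  // Walk the three broadcast batch dimensions. An extent of 0 keeps the
  // corresponding operand pointer fixed, which is how broadcasting along a
  // batch dimension is realized.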
  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const Ta* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const Tb* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const Ta* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const Tb* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const Ta* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const Tb* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        Tout* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                       b1 * batch_dim2 + b2) *
                                          lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            Tout total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total += static_cast<Tout>(lhs_ptr2[accum_depth * i + k]) *
                       static_cast<Tout>(rhs_ptr2[j * accum_depth + k]);
            }
            int idx = lhs_rows * j + i;
            out_ptr[idx] = total;
          }
        }
      }
    }
  }
}

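// Hybrid batch matrix multiplication: int8 LHS (weights) against int8 RHS
// values quantized with per-column scaling factors and zero-point offsets,
// accumulating into a float output. `row_sums` holds the per-row sums of the
// LHS over the accumulation depth, used below to cancel the contribution of
// the RHS zero-points; they are (re)computed when `compute_row_sums` is null
// or points to true. Results are accumulated into `output_data` with `+=`, so
// the caller is expected to pass a zeroed output buffer.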
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape, float* output_data,
                        bool* compute_row_sums) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

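  // Strides for the per-column quantization parameters (scaling factors and
  // input offsets) and for the per-row weight sums. Each stride advances in
  // lockstep with its operand and stays 0 when that operand is broadcast.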
  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

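  // Precompute the sum of each LHS row over the accumulation depth (one set
  // of sums per weights matrix). These sums are needed in the inner loop to
  // correct for the RHS zero-point offsets, and are cached across calls via
  // the `compute_row_sums` flag.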
  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          for (int i = 0; i < lhs_rows; ++i) {
            int32_t total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              total +=
                  lhs_ptr2[accum_depth * i + k] * rhs_ptr2[j * accum_depth + k];
            }
            int32_t row_sum = woff_ptr2[i];
            total -= row_sum * batch_offset;
            int idx = lhs_rows * j + i;
            out_ptr[idx] += batch_scaling_factor * total;
          }
        }
      }
    }
  }
}

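// Fully quantized batch matrix multiplication. The offsets from `params` are
// added to the raw LHS (weights) and RHS (input) values while accumulating in
// AccumT; each accumulator is then rescaled with the quantized output
// multiplier and shift, offset by `output_offset`, and clamped to the
// activation range before being stored as T.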
template <typename T, typename AccumT>
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const T* lhs_data,
                        const RuntimeShape& rhs_shape, const T* rhs_data,
                        const RuntimeShape& output_shape, T* output_data) {
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  const int batch_dim0 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 = batch_matmul::broadcast_dim(
      extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = batch_matmul::extent(extended_lhs_shape, 0);
  const int lhs_ext1 = batch_matmul::extent(extended_lhs_shape, 1);
  const int lhs_ext2 = batch_matmul::extent(extended_lhs_shape, 2);
  const int rhs_ext0 = batch_matmul::extent(extended_rhs_shape, 0);
  const int rhs_ext1 = batch_matmul::extent(extended_rhs_shape, 1);
  const int rhs_ext2 = batch_matmul::extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32_t input_offset = params.input_offset;
  const int32_t filter_offset = params.weights_offset;
  const int32_t output_offset = params.output_offset;
  const int32_t output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32_t output_activation_min = params.quantized_activation_min;
  const int32_t output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const T* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const T* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const T* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const T* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const T* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const T* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        T* out_ptr = output_data +
                     ((b0 * batch_dim1 * batch_dim2) + b1 * batch_dim2 + b2) *
                         lhs_rows * rhs_cols;

        for (int j = 0; j < rhs_cols; ++j) {
          for (int i = 0; i < lhs_rows; ++i) {
            AccumT total = 0;
            for (int k = 0; k < accum_depth; ++k) {
              AccumT lhs_val = lhs_ptr2[accum_depth * i + k];
              AccumT rhs_val = rhs_ptr2[accum_depth * j + k];
              total += (lhs_val + filter_offset) * (rhs_val + input_offset);
            }
            int32_t total_scaled = MultiplyByQuantizedMultiplier(
                total, output_multiplier, output_shift);
            total_scaled += output_offset;
            total_scaled = std::max(total_scaled, output_activation_min);
            total_scaled = std::min(total_scaled, output_activation_max);
            const int idx = lhs_rows * j + i;
            out_ptr[idx] = static_cast<T>(total_scaled);
          }
        }
      }
    }
  }
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BATCH_MATMUL_H_