/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

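// Computes a broadcasting batch matrix multiplication on float tensors and
// dispatches each per-batch multiply to cpu_backend_gemm::Gemm. Both shapes
// are extended to rank 5: dimensions 0..2 are batch dimensions (a dimension
// of size 1 on one side is broadcast against the other side), and the last
// two dimensions of each operand define the per-batch multiply:
//   lhs: [lhs_rows, accum_depth], read row-major;
//   rhs: [accum_depth, rhs_cols], read column-major;
//   dst: [lhs_rows, rhs_cols], written column-major.
//
// A minimal calling sketch (illustrative only; the buffer names and the
// CpuBackendContext setup are assumptions, and the rhs buffer must already be
// laid out for the column-major read described above):
//
//   CpuBackendContext context;
//   BatchMatMul(lhs_shape, lhs_data, rhs_shape, rhs_data,
//               output_shape, output_data, &context);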
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
                        const RuntimeShape& rhs_shape, const float* rhs_data,
                        const RuntimeShape& output_shape, float* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

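  // Worked example (hypothetical shapes): with extended_lhs_shape
  // [2, 1, 3, M, N] and extended_rhs_shape [2, 4, 3, N, K]:
  //   batch_dim0 = 2, batch_dim1 = 4, batch_dim2 = 3.
  // Since the LHS has size 1 in batch dimension 1, extent(lhs, 1) == 0 and
  // the same LHS slice is reused across that dimension, while
  // extent(lhs, 0) == 1 * 3 * M * N and extent(lhs, 2) == M * N advance
  // through the LHS buffer as usual.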
  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  MatrixParams<float> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;

  MatrixParams<float> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;

  MatrixParams<float> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        GemmParams<float, float> gemm_params;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

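// Hybrid-quantized variant: the LHS holds int8 weights and the RHS holds
// int8-quantized activations with per-column scaling factors and input
// offsets. The int32 GEMM result is corrected for the activation offsets
// using cached row sums of the LHS, scaled back to float, and accumulated
// into output_data (so the caller is expected to have initialized the output
// buffer). `accum_scratch` must provide at least lhs_rows * rhs_cols int32
// entries. When `compute_row_sums` is non-null and false, the previously
// computed `row_sums` are reused; otherwise they are recomputed here.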
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape,
                        int32_t* accum_scratch, float* output_data,
                        bool* compute_row_sums, CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        GemmParams<int32_t, int32_t> gemm_params;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, accum_scratch, gemm_params, context);
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          int i = 0;
#ifdef USE_NEON
          const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor);
          const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor);
          const int32x4_t input_offset0 = vdupq_n_s32(-batch_offset);
          const int32x4_t input_offset1 = vdupq_n_s32(-batch_offset);
          for (; i < lhs_rows - 8; i += 8) {
            // Load the row sums.
            const int32x4_t row_sum0 = vld1q_s32(woff_ptr2 + i);
            const int32x4_t row_sum1 = vld1q_s32(woff_ptr2 + i + 4);
            // Load the accumulated values.
            int idx = lhs_rows * j + i;
            const int32x4_t scratch_val0 = vld1q_s32(accum_scratch + idx);
            const int32x4_t scratch_val1 = vld1q_s32(accum_scratch + idx + 4);
            const int32x4_t dotprod0 =
                vmlaq_s32(scratch_val0, row_sum0, input_offset0);
            const int32x4_t dotprod1 =
                vmlaq_s32(scratch_val1, row_sum1, input_offset1);
            const float32x4_t float_val0 = vcvtq_f32_s32(dotprod0);
            const float32x4_t float_val1 = vcvtq_f32_s32(dotprod1);
            const float32x4_t result0 = vmlaq_f32(vld1q_f32(out_ptr + idx),
                                                  float_val0, scaling_factor0);
            const float32x4_t result1 = vmlaq_f32(vld1q_f32(out_ptr + idx + 4),
                                                  float_val1, scaling_factor1);
            vst1q_f32(out_ptr + idx, result0);
            vst1q_f32(out_ptr + idx + 4, result1);
          }
#endif  // USE_NEON
          for (; i < lhs_rows; ++i) {
            int idx = lhs_rows * j + i;
            accum_scratch[idx] -= woff_ptr2[i] * batch_offset;
            out_ptr[idx] += batch_scaling_factor * accum_scratch[idx];
          }
        }
      }
    }
  }
}

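// Fully quantized variant: int8 LHS and RHS with an int8 output. The zero
// points, output multiplier/shift and activation clamps are taken from
// `params` and handed to cpu_backend_gemm via MatrixParams/GemmParams, which
// performs the requantization.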
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const RuntimeShape& output_shape, int8_t* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;
  lhs_params.zero_point = -filter_offset;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;
  rhs_params.zero_point = -input_offset;

  MatrixParams<int8_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;
  dst_params.zero_point = output_offset;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                         b1 * batch_dim2 + b2) *
                                            lhs_rows * rhs_cols;

        GemmParams<int32_t, int8_t> gemm_params;
        gemm_params.clamp_min = output_activation_min;
        gemm_params.clamp_max = output_activation_max;
        gemm_params.multiplier_fixedpoint = output_multiplier;
        gemm_params.multiplier_exponent = output_shift;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

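// Same as the fully quantized variant above, except that the destination is
// int32; the quantization parameters from `params` are still populated into
// the GemmParams/MatrixParams passed to cpu_backend_gemm.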
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const RuntimeShape& output_shape, int32_t* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32 input_offset = params.input_offset;
  const int32 weights_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;
  lhs_params.zero_point = -weights_offset;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;
  rhs_params.zero_point = -input_offset;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;
  dst_params.zero_point = output_offset;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        int32_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                          b1 * batch_dim2 + b2) *
                                             lhs_rows * rhs_cols;

        GemmParams<int32_t, int32_t> gemm_params;
        gemm_params.clamp_min = output_activation_min;
        gemm_params.clamp_max = output_activation_max;
        gemm_params.multiplier_fixedpoint = output_multiplier;
        gemm_params.multiplier_exponent = output_shift;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_