/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_
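
// Optimized BatchMatMul implementations. This header provides a float
// variant, a "hybrid" variant (int8 operands with float scaling factors and
// input offsets, producing float output), and fully quantized int8 variants
// producing int8 or int32 output. Each variant broadcasts the three leading
// batch dimensions of the 5-D extended shapes and issues one
// cpu_backend_gemm::Gemm call per batch entry.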

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_ops {

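// Float BatchMatMul. For each broadcast batch entry, multiplies a row-major
// [lhs_rows x accum_depth] LHS block by a column-major
// [accum_depth x rhs_cols] RHS block into a column-major
// [lhs_rows x rhs_cols] slice of the output.
//
// Illustrative call (variable names are placeholders, not part of this
// header):
//   optimized_ops::BatchMatMul(lhs_shape, lhs_data, rhs_shape, rhs_data,
//                              output_shape, output_data, cpu_backend_context);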
inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data,
                        const RuntimeShape& rhs_shape, const float* rhs_data,
                        const RuntimeShape& output_shape, float* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
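  // For example, with an extended shape of {2, 1, 3, 4, 5}:
  //   extent(shape, 0) == 1 * 3 * 4 * 5 == 60 elements per step of dim 0,
  //   extent(shape, 1) == 0 because dim 1 is broadcast (size 1).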
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  MatrixParams<float> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;

  MatrixParams<float> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;

  MatrixParams<float> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        GemmParams<float, float> gemm_params;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

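// Hybrid BatchMatMul: int8 LHS and RHS, float output. Each per-batch GEMM
// accumulates raw int32 products into accum_scratch; every accumulator is
// then corrected for the RHS zero point (accum - row_sum * input_offset)
// and scaled by the per-column scaling factor before being added into
// output_data (note the +=, so the output is presumably zero-initialized by
// the caller). Row sums of the int8 LHS are (re)computed into row_sums when
// compute_row_sums is null or points to true; the flag is then cleared so
// later calls can reuse them.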
inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const float* scaling_factors,
                        const int32_t* input_offset, int32_t* row_sums,
                        const RuntimeShape& output_shape,
                        int32_t* accum_scratch, float* output_data,
                        bool* compute_row_sums, CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols;
  const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols;
  const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols;
  const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows;
  const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows;
  const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows;

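  // (Re)compute the per-row sums of the int8 LHS when requested; they are
  // needed below to remove the RHS zero-point contribution from each int32
  // accumulator.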
  if (!compute_row_sums || *compute_row_sums) {
    int num_weights_matrices = 1;
    for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) {
      num_weights_matrices *= extended_lhs_shape.Dims(i);
    }
    tensor_utils::ReductionSumVector(
        lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth);
    if (compute_row_sums) {
      *compute_row_sums = false;
    }
  }

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0);
    const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0);
    const int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1);
      const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1);
      const int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1);
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2);
        const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2);
        const int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2);
        float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                        b1 * batch_dim2 + b2) *
                                           lhs_rows * rhs_cols;
        GemmParams<int32_t, int32_t> gemm_params;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, accum_scratch, gemm_params, context);
        for (int j = 0; j < rhs_cols; ++j) {
          const float batch_scaling_factor = scale_ptr2[j];
          const float batch_offset = static_cast<float>(ioff_ptr2[j]);
          int i = 0;
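          // On NEON, handle eight LHS rows per iteration: apply the
          // zero-point correction (accum - row_sum * input_offset), convert
          // to float, scale, and accumulate into the output. The scalar loop
          // after the #endif handles any remaining rows as well as the
          // non-NEON build.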
#ifdef USE_NEON
          const float32x4_t scaling_factor0 = vdupq_n_f32(batch_scaling_factor);
          const float32x4_t scaling_factor1 = vdupq_n_f32(batch_scaling_factor);
          const int32x4_t input_offset0 = vdupq_n_s32(-batch_offset);
          const int32x4_t input_offset1 = vdupq_n_s32(-batch_offset);
          for (; i < lhs_rows - 8; i += 8) {
            // Load the row sums.
            const int32x4_t row_sum0 = vld1q_s32(woff_ptr2 + i);
            const int32x4_t row_sum1 = vld1q_s32(woff_ptr2 + i + 4);
            // Load the accumulated values.
            int idx = lhs_rows * j + i;
            const int32x4_t scratch_val0 = vld1q_s32(accum_scratch + idx);
            const int32x4_t scratch_val1 = vld1q_s32(accum_scratch + idx + 4);
            const int32x4_t dotprod0 =
                vmlaq_s32(scratch_val0, row_sum0, input_offset0);
            const int32x4_t dotprod1 =
                vmlaq_s32(scratch_val1, row_sum1, input_offset1);
            const float32x4_t float_val0 = vcvtq_f32_s32(dotprod0);
            const float32x4_t float_val1 = vcvtq_f32_s32(dotprod1);
            const float32x4_t result0 = vmlaq_f32(vld1q_f32(out_ptr + idx),
                                                  float_val0, scaling_factor0);
            const float32x4_t result1 = vmlaq_f32(vld1q_f32(out_ptr + idx + 4),
                                                  float_val1, scaling_factor1);
            vst1q_f32(out_ptr + idx, result0);
            vst1q_f32(out_ptr + idx + 4, result1);
          }
#endif  // USE_NEON
          for (; i < lhs_rows; ++i) {
            int idx = lhs_rows * j + i;
            accum_scratch[idx] -= woff_ptr2[i] * batch_offset;
            out_ptr[idx] += batch_scaling_factor * accum_scratch[idx];
          }
        }
      }
    }
  }
}

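// Fully quantized BatchMatMul producing int8 output. The LHS/RHS zero points
// are set to the negated weights/input offsets from params, the destination
// zero point to the output offset, and cpu_backend_gemm requantizes with
// params.output_multiplier / params.output_shift and clamps to the quantized
// activation range.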
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const RuntimeShape& output_shape, int8_t* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;
  lhs_params.zero_point = -filter_offset;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;
  rhs_params.zero_point = -input_offset;

  MatrixParams<int8_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;
  dst_params.zero_point = output_offset;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                         b1 * batch_dim2 + b2) *
                                            lhs_rows * rhs_cols;

        GemmParams<int32_t, int8_t> gemm_params;
        gemm_params.clamp_min = output_activation_min;
        gemm_params.clamp_max = output_activation_max;
        gemm_params.multiplier_fixedpoint = output_multiplier;
        gemm_params.multiplier_exponent = output_shift;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

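// Variant of the quantized BatchMatMul above that writes int32 output; the
// requantization and clamping parameters from params are forwarded to
// cpu_backend_gemm in the same way.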
inline void BatchMatMul(const FullyConnectedParams& params,
                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
                        const RuntimeShape& output_shape, int32_t* output_data,
                        CpuBackendContext* context) {
  using ::tflite::cpu_backend_gemm::Gemm;
  using ::tflite::cpu_backend_gemm::GemmParams;
  using ::tflite::cpu_backend_gemm::MatrixParams;

  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(5, lhs_shape);
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(5, rhs_shape);

  // Determine which dimension is the broadcast dimension.
  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
    if (lhs_dim == rhs_dim) return lhs_dim;
    if (lhs_dim == 1) return rhs_dim;
    TFLITE_DCHECK_EQ(rhs_dim, 1);
    return lhs_dim;
  };

  // Compute the "extent" for iterating on this dimension.
  // If we are broadcasting, then don't advance (i.e. return 0).
  auto extent = [](const RuntimeShape& shape, int x) {
    if (shape.Dims(x) == 1) {
      return 0;
    }
    int prod = 1;
    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
      prod *= shape.Dims(i);
    }
    return prod;
  };

  const int batch_dim0 =
      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
  const int batch_dim1 =
      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
  const int batch_dim2 =
      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));

  const int lhs_ext0 = extent(extended_lhs_shape, 0);
  const int lhs_ext1 = extent(extended_lhs_shape, 1);
  const int lhs_ext2 = extent(extended_lhs_shape, 2);
  const int rhs_ext0 = extent(extended_rhs_shape, 0);
  const int rhs_ext1 = extent(extended_rhs_shape, 1);
  const int rhs_ext2 = extent(extended_rhs_shape, 2);

  // Set params for each matrix multiply.
  const int lhs_rows = extended_lhs_shape.Dims(3);
  const int rhs_cols = extended_rhs_shape.Dims(4);
  const int accum_depth = extended_lhs_shape.Dims(4);

  const int32 input_offset = params.input_offset;
  const int32 weights_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);

  MatrixParams<int8_t> lhs_params;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.rows = lhs_rows;
  lhs_params.cols = accum_depth;
  lhs_params.zero_point = -weights_offset;

  MatrixParams<int8_t> rhs_params;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.rows = accum_depth;
  rhs_params.cols = rhs_cols;
  rhs_params.zero_point = -input_offset;

  MatrixParams<int32_t> dst_params;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.rows = lhs_rows;
  dst_params.cols = rhs_cols;
  dst_params.zero_point = output_offset;

  for (int b0 = 0; b0 < batch_dim0; ++b0) {
    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
    for (int b1 = 0; b1 < batch_dim1; ++b1) {
      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
      for (int b2 = 0; b2 < batch_dim2; ++b2) {
        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
        int32_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
                                          b1 * batch_dim2 + b2) *
                                             lhs_rows * rhs_cols;

        GemmParams<int32_t, int32_t> gemm_params;
        gemm_params.clamp_min = output_activation_min;
        gemm_params.clamp_max = output_activation_max;
        gemm_params.multiplier_fixedpoint = output_multiplier;
        gemm_params.multiplier_exponent = output_shift;
        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
                               dst_params, out_ptr, gemm_params, context);
      }
    }
  }
}

}  // namespace optimized_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_BATCH_MATMUL_H_