xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/batch_matmul.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
17 
18 #include <stddef.h>
19 
20 #include <algorithm>
21 #include <cstdint>
22 #include <limits>
23 
24 #include "tensorflow/lite/c/builtin_op_data.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/kernels/cpu_backend_context.h"
27 #include "tensorflow/lite/kernels/internal/compatibility.h"
28 #include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
29 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
30 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
31 #include "tensorflow/lite/kernels/internal/tensor.h"
32 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
33 #include "tensorflow/lite/kernels/internal/tensor_utils.h"
34 #include "tensorflow/lite/kernels/internal/types.h"
35 #include "tensorflow/lite/kernels/kernel_util.h"
36 
37 namespace tflite {
38 namespace ops {
39 namespace builtin {
40 namespace batch_matmul {
41 
42 static const int kInputLHSTensor = 0;
43 static const int kInputRHSTensor = 1;
44 static const int kOutputTensor = 0;
45 
46 static const int kNumTempTensorsForAdjoints = 2;
47 static const int kNumTempTensorsForHybrid = 5;
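// Temporary tensor layout: index 0 holds the transposed LHS and index 1 the
// transposed RHS; for hybrid (float LHS, int8 RHS) execution, indices 2-6
// hold the quantized input, scaling factors, accumulator scratch, input
// offsets and row sums, in that order.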
48 
49 // This file has two implementations of BatchMatMul.
50 enum KernelType {
51   kReference,
52   kGenericOptimized,
53 };
54 
55 struct OpData {
56   // The scaling factor from input to output (aka the 'real multiplier') can
57   // be represented as a fixed point multiplier plus a left shift.
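  // For example, QuantizeMultiplier() turns a real multiplier of 0.5 into
  // output_multiplier = 1 << 30 with output_shift = 0, since the effective
  // value is output_multiplier * 2^(output_shift - 31).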
58   int32_t output_multiplier;
59   int output_shift;
60   // The range of the fused activation layer. For example for kNone and
61   // uint8_t these would be 0 and 255.
62   int32_t output_activation_min;
63   int32_t output_activation_max;
64   // The index of the first temporary tensor (transposed LHS/RHS storage).
65   int scratch_tensor_index;
66   bool rhs_transposed;
67   bool compute_row_sums = false;
68 };
69 
70 struct OpContext {
71   OpContext(TfLiteContext* context, TfLiteNode* node) {
72     params = reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
73     lhs = GetInput(context, node, kInputLHSTensor);
74     rhs = GetInput(context, node, kInputRHSTensor);
75     output = GetOutput(context, node, kOutputTensor);
76   }
77   TfLiteBatchMatMulParams* params;
78   const TfLiteTensor* lhs;
79   const TfLiteTensor* rhs;
80   TfLiteTensor* output;
81 };
82 
83 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
84   auto* op_data = new OpData();
85   // If the RHS is constant, we only transpose once.
86   op_data->rhs_transposed = false;
87   // Creates the temp tensors to store the transposed LHS and/or RHS, and
88   // extra buffers for the quantized case.
89   context->AddTensors(context,
90                       kNumTempTensorsForAdjoints + kNumTempTensorsForHybrid,
91                       &op_data->scratch_tensor_index);
92   return op_data;
93 }
94 
95 void Free(TfLiteContext* context, void* buffer) {
96   delete static_cast<OpData*>(buffer);
97 }
98 
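// Computes the output shape from the rank-extended input shapes. For
// example, with LHS [2, 1, 3, 4], RHS [1, 5, 4, 6] and adj_x == adj_y ==
// false, the broadcast batch dims give [2, 5, ...] and the matmul dims give
// [..., 3, 6], so the output is resized to [2, 5, 3, 6].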
99 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
100                                 const RuntimeShape& extended_lhs_shape,
101                                 const RuntimeShape& extended_rhs_shape,
102                                 bool adj_x, bool adj_y, int output_rank,
103                                 TfLiteTensor* output) {
104   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
105   // Fill in any broadcast dimensions.
106   for (int i = 0; i < output_rank - 2; ++i) {
107     const int lhs_dim = extended_lhs_shape.Dims(i);
108     const int rhs_dim = extended_rhs_shape.Dims(i);
109     int broadcast_dim = lhs_dim;
110     if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
111       broadcast_dim = rhs_dim;
112     }
113     output_shape->data[i] = broadcast_dim;
114   }
115   // Fill in the matmul dimensions.
116   int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
117   int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
118 
119   output_shape->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
120   output_shape->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
121   TfLiteStatus stat = context->ResizeTensor(context, output, output_shape);
122   return stat;
123 }
124 
125 // Initializes temp tensors to store transposed operands.
126 TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
127                                    OpContext* op_context) {
128   // Create temporary tensors to hold transposed LHS/RHS.
129   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
130   const TfLiteTensor* lhs = op_context->lhs;
131   const TfLiteTensor* rhs = op_context->rhs;
132   TfLiteIntArrayFree(node->temporaries);
133   // For "hybrid" quantization, we impose the constraint that the LHS
134   // is float (typically an activation from a prior layer) and the RHS
135   // is quantized int8.
136   bool is_hybrid =
137       (op_context->lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8);
138   if (is_hybrid) {
139     node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints +
140                                              kNumTempTensorsForHybrid);
141   } else {
142     node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints);
143   }
144 
145   const int lhs_rank = NumDimensions(lhs);
146   const int rhs_rank = NumDimensions(rhs);
147   const int batch_size = op_context->params->adj_x
148                              ? lhs->dims->data[lhs_rank - 1]
149                              : lhs->dims->data[lhs_rank - 2];
150   const int num_units = op_context->params->adj_y
151                             ? rhs->dims->data[rhs_rank - 2]
152                             : rhs->dims->data[rhs_rank - 1];
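  // E.g. for an LHS of shape [..., M, K] with adj_x == false, batch_size is
  // M; for an RHS of shape [..., K, N] with adj_y == false, num_units is N.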
153 
154   // Temp tensor for transposed LHS.
155   {
156     node->temporaries->data[0] = op_data->scratch_tensor_index;
157     TfLiteTensor* scratch_buffer;
158     TF_LITE_ENSURE_OK(
159         context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
160     TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank);
161     for (int i = 0; i < lhs_rank - 2; ++i) {
162       scratch_buffer_size->data[i] = lhs->dims->data[i];
163     }
164     // Swap last two dimensions.
165     scratch_buffer_size->data[lhs_rank - 2] = lhs->dims->data[lhs_rank - 1];
166     scratch_buffer_size->data[lhs_rank - 1] = lhs->dims->data[lhs_rank - 2];
167 
168     scratch_buffer->type = op_context->lhs->type;
169     scratch_buffer->allocation_type = kTfLiteArenaRw;
170     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
171                                                      scratch_buffer_size));
172   }
173 
174   // We need a temp buffer for the RHS if we need to transpose the RHS. We
175   // transpose by default, so that the two inputs (LHS and RHS) are in a proper
176   // layout for our fast matrix multiplication routines. If the transpose flag
177   // is set by the caller, the data is already in the desired layout.
178   {
179     node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
180     TfLiteTensor* scratch_buffer;
181     TF_LITE_ENSURE_OK(
182         context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer));
183     scratch_buffer->name = "BatchMatMul_scratch_buffer";
184     const TfLiteTensor* rhs = op_context->rhs;
185     int rhs_rank = NumDimensions(rhs);
186     TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank);
187     for (int i = 0; i < rhs_rank - 2; ++i) {
188       scratch_buffer_size->data[i] = rhs->dims->data[i];
189     }
190     // Swap last two dimensions.
191     scratch_buffer_size->data[rhs_rank - 2] = rhs->dims->data[rhs_rank - 1];
192     scratch_buffer_size->data[rhs_rank - 1] = rhs->dims->data[rhs_rank - 2];
193 
194     if (IsConstantTensor(op_context->rhs)) {
195       scratch_buffer->allocation_type = kTfLiteArenaRwPersistent;
196     } else {
197       scratch_buffer->allocation_type = kTfLiteArenaRw;
198     }
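    // A persistent allocation lets the transposed constant RHS be computed
    // once in Eval() and reused across invocations (see rhs_transposed).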
199     scratch_buffer->type = op_context->rhs->type;
200     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
201                                                      scratch_buffer_size));
202   }
203 
204   // If we have to perform on-the-fly quantization (with quantized weights and
205   // float inputs) first we need to quantize the inputs. Allocate temporary
206   // buffer to store the intermediate quantized values, the batch scaling
207   // factors, the accumulator buffer (optimized version), the input offsets,
208   // and the sums of the rows for each weights matrix.
209   // RHS = weights, LHS = inputs
210   if (is_hybrid) {
211     // Calculate the total number of LHS batches.
212     int num_batches = 1;
213     for (int i = 0; i < lhs_rank - 2; ++i) {
214       num_batches *= lhs->dims->data[i];
215     }
216     int num_weights_matrices = 1;
217     for (int i = 0; i < rhs_rank - 2; ++i) {
218       num_weights_matrices *= rhs->dims->data[i];
219     }
220     op_data->compute_row_sums = true;
221     node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
222     TfLiteTensor* input_quantized;
223     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
224                                                 &input_quantized));
225     input_quantized->type = op_context->rhs->type;
226     input_quantized->allocation_type = kTfLiteArenaRw;
227 
228     TfLiteIntArray* input_quantized_size =
229         TfLiteIntArrayCopy(op_context->lhs->dims);
230     TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
231                                                      input_quantized_size));
232 
233     node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
234     TfLiteTensor* scaling_factors;
235     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
236                                                 &scaling_factors));
237     scaling_factors->type = kTfLiteFloat32;
238     scaling_factors->allocation_type = kTfLiteArenaRw;
239     // The total number of scaling factors is batch_size * num_batches.
240     int scaling_dims[1] = {num_batches * batch_size};
241     if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
242       TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
243       scaling_factors_size->data[0] = scaling_dims[0];
244       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
245                                                        scaling_factors_size));
246     }
247 
248     node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
249     TfLiteTensor* accum_scratch;
250     TF_LITE_ENSURE_OK(
251         context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
252     accum_scratch->type = kTfLiteInt32;
253     accum_scratch->allocation_type = kTfLiteArenaRw;
254     int accum_scratch_dims[2] = {num_units, batch_size};
255     if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
256                                    accum_scratch_dims)) {
257       TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
258       accum_size->data[0] = num_units;
259       accum_size->data[1] = batch_size;
260       TF_LITE_ENSURE_OK(
261           context, context->ResizeTensor(context, accum_scratch, accum_size));
262     }
263 
264     node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
265     TfLiteTensor* input_offsets;
266     TF_LITE_ENSURE_OK(
267         context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
268     input_offsets->type = kTfLiteInt32;
269     input_offsets->allocation_type = kTfLiteArenaRw;
270     if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
271       TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
272       input_offsets_size->data[0] = num_batches * batch_size;
273       TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
274                                                        input_offsets_size));
275     }
276     node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
277     TfLiteTensor* row_sums;
278     TF_LITE_ENSURE_OK(context,
279                       GetTemporarySafe(context, node, /*index=*/6, &row_sums));
280     row_sums->type = kTfLiteInt32;
281     row_sums->allocation_type = kTfLiteArenaRwPersistent;
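    // Row sums of the quantized weights are used to correct for
    // asymmetrically quantized inputs; they are cached persistently and
    // computed lazily on the first invocation (see compute_row_sums).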
282     int row_sums_dims[1] = {num_weights_matrices * num_units};
283     if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
284       TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
285       row_sums_size->data[0] = row_sums_dims[0];
286       TF_LITE_ENSURE_OK(
287           context, context->ResizeTensor(context, row_sums, row_sums_size));
288     }
289   }
290 
291   return kTfLiteOk;
292 }
293 
294 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
295   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
296   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
297 
298   OpContext op_context(context, node);
299   TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
300   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
301 
302   bool adj_x = op_context.params->adj_x;
303   bool adj_y = op_context.params->adj_y;
304 
305   const TfLiteTensor* lhs_data;
306   TF_LITE_ENSURE_OK(context,
307                     GetInputSafe(context, node, kInputLHSTensor, &lhs_data));
308   const TfLiteTensor* rhs_data;
309   TF_LITE_ENSURE_OK(context,
310                     GetInputSafe(context, node, kInputRHSTensor, &rhs_data));
311   TfLiteTensor* output;
312   TF_LITE_ENSURE_OK(context,
313                     GetOutputSafe(context, node, kOutputTensor, &output));
314 
315   // Note that quantized inference requires that all tensors have their
316   // parameters set. This is usually done during quantized training.
317   if ((lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) &&
318       output->type != kTfLiteInt32) {
319     double real_multiplier = 0.0;
320     TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
321         context, lhs_data, rhs_data, output, &real_multiplier));
322     int exponent;
323     QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent);
324     op_data->output_shift = exponent;
325     // BatchMatMul has no fused activation functions. Therefore, set
326     // output activation min and max to min and max of int8_t or int16_t
327     // type.
328     if (lhs_data->type == kTfLiteInt8) {
329       op_data->output_activation_min = std::numeric_limits<int8_t>::min();
330       op_data->output_activation_max = std::numeric_limits<int8_t>::max();
331     } else {
332       op_data->output_activation_min = std::numeric_limits<int16_t>::min();
333       op_data->output_activation_max = std::numeric_limits<int16_t>::max();
334     }
335   }
336 
337   if (lhs_data->type == kTfLiteInt16) {
338     TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0);
339     TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0);
340     TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
341   }
342 
343   TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 ||
344                               lhs_data->type == kTfLiteInt8 ||
345                               lhs_data->type == kTfLiteInt16);
346   TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 ||
347                               rhs_data->type == kTfLiteInt8 ||
348                               rhs_data->type == kTfLiteInt16);
349   // Either we have hybrid quantization with a float32 LHS and an int8 RHS,
350   // or both inputs must be of the same type.
351   TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 &&
352                            rhs_data->type == kTfLiteInt8) ||
353                               lhs_data->type == rhs_data->type);
354   // Support dimensions between 2 and 5, inclusive.
355   TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
356   TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
357   TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
358   TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);
359 
360   const int lhs_rank = NumDimensions(lhs_data);
361   const int rhs_rank = NumDimensions(rhs_data);
362   const int output_rank = std::max(lhs_rank, rhs_rank);
363   const RuntimeShape extended_lhs_shape =
364       RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
365   const RuntimeShape extended_rhs_shape =
366       RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));
367 
368   // Ensure any batch dimensions obey broadcasting rules.
369   for (int i = 0; i < output_rank - 2; ++i) {
370     const int lhs_dim = extended_lhs_shape.Dims(i);
371     const int rhs_dim = extended_rhs_shape.Dims(i);
372     if (lhs_dim != rhs_dim) {
373       if (lhs_dim != 1) {
374         TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
375       }
376     }
377   }
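  // E.g. batch dims of 2 and 1 broadcast to 2, while 3 and 5 are rejected.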
378   // Ensure other dimensions work for matrix multiplication.
379   int accum_dim_lhs = adj_x ? extended_lhs_shape.Dims(output_rank - 2)
380                             : extended_lhs_shape.Dims(output_rank - 1);
381   int accum_dim_rhs = adj_y ? extended_rhs_shape.Dims(output_rank - 1)
382                             : extended_rhs_shape.Dims(output_rank - 2);
383 
384   TF_LITE_ENSURE_EQ(context, accum_dim_lhs, accum_dim_rhs);
385   TfLiteStatus status =
386       ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape, adj_x,
387                          adj_y, output_rank, output);
388   return status;
389 }
390 
391 template <typename scalar>
392 void TransposeRowsColumnsImpl(const TfLiteTensor* tensor_in,
393                               const scalar* input, TfLiteTensor* tensor_out,
394                               scalar* output) {
395   RuntimeShape transposed_shape(GetTensorShape(tensor_in));
396   RuntimeShape shape(GetTensorShape(tensor_in));
397   TransposeParams params;
398   int rank = NumDimensions(tensor_in);
399   params.perm_count = rank;
400   for (int i = 0; i < rank - 2; ++i) {
401     params.perm[i] = i;
402   }
403   // Transpose the last two dimensions.
404   params.perm[rank - 2] = rank - 1;
405   params.perm[rank - 1] = rank - 2;
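  // For a rank-4 tensor this yields perm = {0, 1, 3, 2}: batch dimensions
  // are left untouched and only the innermost matrix is transposed.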
406   transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
407   transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
408   optimized_ops::Transpose(params, shape, input, transposed_shape, output);
409 }
410 
411 TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
412                                   const TfLiteTensor* tensor_in,
413                                   TfLiteTensor* tensor_out) {
414   if (tensor_in->type == kTfLiteFloat32) {
415     TransposeRowsColumnsImpl<float>(tensor_in, GetTensorData<float>(tensor_in),
416                                     tensor_out,
417                                     GetTensorData<float>(tensor_out));
418     return kTfLiteOk;
419   } else if (tensor_in->type == kTfLiteInt8) {
420     TransposeRowsColumnsImpl<int8_t>(
421         tensor_in, GetTensorData<int8_t>(tensor_in), tensor_out,
422         GetTensorData<int8_t>(tensor_out));
423     return kTfLiteOk;
424   } else if (tensor_in->type == kTfLiteInt16) {
425     TransposeRowsColumnsImpl<int16_t>(
426         tensor_in, GetTensorData<int16_t>(tensor_in), tensor_out,
427         GetTensorData<int16_t>(tensor_out));
428     return kTfLiteOk;
429   } else {
430     TF_LITE_KERNEL_LOG(
431         context, "Can only transpose tensors with float, int8 or int16 type.");
432     return kTfLiteError;
433   }
434 }
435 
436 RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
437   RuntimeShape swapped_shape(shape);
438   const int32_t dims = shape.DimensionsCount();
439   swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
440   swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
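  // E.g. a shape of [B, M, N] becomes [B, N, M].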
441   return swapped_shape;
442 }
443 
444 template <KernelType kernel_type>
445 TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
446                         const RuntimeShape& input_shape,
447                         const TfLiteTensor* input,
448                         const RuntimeShape& filter_shape,
449                         const TfLiteTensor* filter,
450                         TfLiteTensor* input_quantized,
451                         TfLiteTensor* scaling_factors,
452                         TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
453                         TfLiteTensor* input_offsets, TfLiteTensor* output) {
454   const auto* params =
455       reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
456   const int32_t num_input_dims = input_shape.DimensionsCount();
457 
458   // The input rows/cols have been swapped at this point, so the last two
459   // dims are {input_size, batch_size}.
460   const int input_size = input_shape.Dims(num_input_dims - 2);
461   const int batch_size = input_shape.Dims(num_input_dims - 1);
462 
463   int num_batches_to_quantize = batch_size;
464   for (int i = 0; i < input_shape.DimensionsCount() - 2; ++i) {
465     num_batches_to_quantize *= input_shape.Dims(i);
466   }
467   // Quantize input from float to int8 + quantization params (scaling factor).
468   const int scaling_factor_size = GetTensorShape(scaling_factors).FlatSize();
469   TF_LITE_ENSURE(context, scaling_factor_size >= num_batches_to_quantize);
470   float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
471   int32_t* input_offset_ptr = nullptr;
472   int32_t* row_sums_ptr = nullptr;
473   input_offset_ptr = GetTensorData<int32_t>(input_offsets);
474   row_sums_ptr = GetTensorData<int32_t>(row_sums);
475   if (!params->asymmetric_quantize_inputs) {
476     memset(input_offset_ptr, 0, input_offsets->bytes);
477   }
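  // With symmetric quantization the input offsets stay zero, so the row-sum
  // correction term contributes nothing.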
478   int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
479   const int8_t* filter_data = GetTensorData<int8_t>(filter);
480   const float* input_ptr = GetTensorData<float>(input);
481   // Quantize each batch independently.
482   tensor_utils::BatchQuantizeFloats(input_ptr, num_batches_to_quantize,
483                                     input_size, quant_data, scaling_factors_ptr,
484                                     input_offset_ptr,
485                                     params->asymmetric_quantize_inputs);
486   for (int b = 0; b < num_batches_to_quantize; ++b) {
487     // Incorporate scaling of the filter.
488     scaling_factors_ptr[b] *= filter->params.scale;
489   }
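  // Each scaling factor now folds in the filter scale, so the matmul routines
  // can rescale the int32 accumulators back to float in a single step.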
490 
491   RuntimeShape output_shape = GetTensorShape(output);
492   int output_size = 1;
493   for (int i = 0; i < output_shape.DimensionsCount(); ++i) {
494     output_size *= output_shape.Dims(i);
495   }
496   std::fill_n(GetTensorData<float>(output), output_size, 0.0f);
497   if (kernel_type == kGenericOptimized) {
498     optimized_ops::BatchMatMul(
499         filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr,
500         input_offset_ptr, row_sums_ptr, GetTensorShape(output),
501         GetTensorData<int32_t>(accum_scratch), GetTensorData<float>(output),
502         &(data->compute_row_sums), CpuBackendContext::GetFromContext(context));
503   } else {
504     reference_ops::BatchMatMul(
505         filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr,
506         input_offset_ptr, row_sums_ptr, GetTensorShape(output),
507         GetTensorData<float>(output), &(data->compute_row_sums));
508   }
509 
510   return kTfLiteOk;
511 }
512 
513 template <KernelType kernel_type>
514 TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data,
515                           const RuntimeShape& lhs_shape,
516                           const TfLiteTensor* lhs,
517                           const RuntimeShape& rhs_shape,
518                           const TfLiteTensor* rhs,
519                           const RuntimeShape& output_shape,
520                           TfLiteTensor* output) {
521   // Reuse params struct from FullyConnected Op.
522   FullyConnectedParams op_params;
523   int32_t input_offset = -lhs->params.zero_point;
524   int32_t filter_offset = -rhs->params.zero_point;
525   int32_t output_offset = output->params.zero_point;
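  // Note the negated zero points: the kernels add input_offset and
  // weights_offset to each quantized value, which subtracts the zero point.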
526   op_params.input_offset = input_offset;
527   op_params.weights_offset = filter_offset;
528   op_params.output_offset = output_offset;
529   op_params.output_multiplier = data->output_multiplier;
530   op_params.output_shift = data->output_shift;
531   op_params.quantized_activation_min = data->output_activation_min;
532   op_params.quantized_activation_max = data->output_activation_max;
533   op_params.lhs_cacheable = IsConstantTensor(lhs);
534   op_params.rhs_cacheable = IsConstantTensor(rhs);
535 
536   if (kernel_type == kReference) {
537     reference_ops::BatchMatMul<int8_t, int32_t>(
538         op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
539         GetTensorData<int8_t>(lhs), GetTensorShape(output),
540         GetTensorData<int8_t>(output));
541   } else {
542     optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
543                                lhs_shape, GetTensorData<int8_t>(lhs),
544                                GetTensorShape(output),
545                                GetTensorData<int8_t>(output),
546                                CpuBackendContext::GetFromContext(context));
547   }
548   return kTfLiteOk;
549 }
550 
551 template <KernelType kernel_type>
552 TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data,
553                            const RuntimeShape& lhs_shape,
554                            const TfLiteTensor* lhs,
555                            const RuntimeShape& rhs_shape,
556                            const TfLiteTensor* rhs,
557                            const RuntimeShape& output_shape,
558                            TfLiteTensor* output) {
559   // Reuse params struct from FullyConnected Op.
560   FullyConnectedParams op_params;
561   int32_t input_offset = -lhs->params.zero_point;
562   int32_t weights_offset = -rhs->params.zero_point;
563   int32_t output_offset = output->params.zero_point;
564   op_params.input_offset = input_offset;
565   op_params.weights_offset = weights_offset;
566   op_params.output_offset = output_offset;
567   op_params.output_multiplier = data->output_multiplier;
568   op_params.output_shift = data->output_shift;
569   op_params.quantized_activation_min = data->output_activation_min;
570   op_params.quantized_activation_max = data->output_activation_max;
571   op_params.lhs_cacheable = IsConstantTensor(lhs);
572   op_params.rhs_cacheable = IsConstantTensor(rhs);
573 
574   // Pass the RHS (filter) as the BatchMatMul LHS argument and the LHS (input)
575   // as the RHS argument; for the rationale, see the comment above Eval().
576   if (kernel_type == kReference) {
577     reference_ops::BatchMatMul<int8_t, int8_t, int32_t>(
578         rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
579         GetTensorData<int8_t>(lhs), GetTensorShape(output),
580         GetTensorData<int32_t>(output));
581   } else {
582     optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
583                                lhs_shape, GetTensorData<int8_t>(lhs),
584                                GetTensorShape(output),
585                                GetTensorData<int32_t>(output),
586                                CpuBackendContext::GetFromContext(context));
587   }
588   return kTfLiteOk;
589 }
590 
591 template <KernelType kernel_type>
592 TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
593                        const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
594                        const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
595                        const RuntimeShape& output_shape, TfLiteTensor* output) {
596   // Reuse params struct from FullyConnected Op.
597   FullyConnectedParams op_params;
598   int32_t input_offset = -lhs->params.zero_point;
599   int32_t filter_offset = -rhs->params.zero_point;
600   int32_t output_offset = output->params.zero_point;
601   op_params.input_offset = input_offset;
602   op_params.weights_offset = filter_offset;
603   op_params.output_offset = output_offset;
604   op_params.output_multiplier = data->output_multiplier;
605   op_params.output_shift = data->output_shift;
606   op_params.quantized_activation_min = data->output_activation_min;
607   op_params.quantized_activation_max = data->output_activation_max;
608 
609   // An optimized kernel is not yet implemented for int16_t, so use
610   // reference_ops in all cases.
611   reference_ops::BatchMatMul<int16_t, int64_t>(
612       op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape,
613       GetTensorData<int16_t>(lhs), GetTensorShape(output),
614       GetTensorData<int16_t>(output));
615   return kTfLiteOk;
616 }
617 
618 template <KernelType kernel_type>
619 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
620                            OpData* data, const RuntimeShape& lhs_shape,
621                            const TfLiteTensor* lhs,
622                            const RuntimeShape& rhs_shape,
623                            const TfLiteTensor* rhs, TfLiteTensor* output) {
624   if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) {
625     TfLiteTensor* input_quantized;
626     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
627                                                 &input_quantized));
628     TfLiteTensor* scaling_factors;
629     TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
630                                                 &scaling_factors));
631     TfLiteTensor* accum_scratch;
632     TF_LITE_ENSURE_OK(
633         context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
634     TfLiteTensor* input_offsets;
635     TF_LITE_ENSURE_OK(
636         context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
637     TfLiteTensor* row_sums;
638     TF_LITE_ENSURE_OK(context,
639                       GetTemporarySafe(context, node, /*index=*/6, &row_sums));
640     return EvalHybrid<kernel_type>(
641         context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
642         scaling_factors, accum_scratch, row_sums, input_offsets, output);
643   } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) {
644     if (output->type == kTfLiteInt8) {
645       return EvalInt8Int8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape,
646                                        rhs, GetTensorShape(output), output);
647     } else {
648       return EvalInt8Int32<kernel_type>(context, data, lhs_shape, lhs,
649                                         rhs_shape, rhs, GetTensorShape(output),
650                                         output);
651     }
652   } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) {
653     return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs,
654                                   GetTensorShape(output), output);
655   } else {
656     TF_LITE_KERNEL_LOG(
657         context,
658         "Currently only hybrid, int8 and int16 quantization is supported.\n");
659     return kTfLiteError;
660   }
661   return kTfLiteOk;
662 }
663 
664 TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
665                          const TfLiteTensor* rhs) {
666   TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
667   if (transposed_rhs == nullptr) {
668     return nullptr;
669   }
670 
671   if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) {
672     // Get the quantization params from the RHS tensor.
673     transposed_rhs->params.scale = rhs->params.scale;
674     transposed_rhs->params.zero_point = rhs->params.zero_point;
675   }
676   return transposed_rhs;
677 }
678 
679 TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
680                          const TfLiteTensor* lhs) {
681   TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
682   if (transposed_lhs == nullptr) {
683     return nullptr;
684   }
685 
686   if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) {
687     // Get the quantization params from the LHS tensor.
688     transposed_lhs->params.scale = lhs->params.scale;
689     transposed_lhs->params.zero_point = lhs->params.zero_point;
690   }
691   return transposed_lhs;
692 }
693 
694 // Perform a batch matrix multiply on
695 // LHS <..., A, B>  X  RHS<..., B, C>
696 // where the leading dimensions of LHS and RHS obey broadcasting rules
697 // (this Op will apply broadcasting rules).
698 // We assume that LHS and RHS are both row oriented (adjacent values in memory
699 // are in the same row) and will output in the same memory layout. However,
700 // our fast GEMM libraries assume RCC layout (LHS row oriented,
701 // RHS column oriented, output column oriented). Therefore, we perform
702 // RHS <..., C, B> X LHS <..., B, A>
703 // where the output is a C X A column-oriented matrix, which is equivalent to
704 // an A X C row-oriented matrix.
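// (A C X A column-major buffer and an A X C row-major buffer have identical
// memory layouts, which is why no final transpose of the output is needed.)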
705 template <KernelType kernel_type>
706 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
707   OpContext op_context(context, node);
708   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
709   const TfLiteTensor* lhs;
710   TF_LITE_ENSURE_OK(context,
711                     GetInputSafe(context, node, kInputLHSTensor, &lhs));
712   const TfLiteTensor* rhs;
713   TF_LITE_ENSURE_OK(context,
714                     GetInputSafe(context, node, kInputRHSTensor, &rhs));
715   TfLiteTensor* output;
716   TF_LITE_ENSURE_OK(context,
717                     GetOutputSafe(context, node, kOutputTensor, &output));
718   RuntimeShape orig_lhs_shape = GetTensorShape(lhs);
719   RuntimeShape orig_rhs_shape = GetTensorShape(rhs);
720 
721   bool adj_y = op_context.params->adj_y;
722   bool adj_x = op_context.params->adj_x;
723 
724   const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs);
725   const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs;
726   if (!adj_y) {
727     // TODO(b/154760341) Constant tensors should already be transposed, but
728     // we transpose once if necessary for now.
729     if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) {
730       TransposeRowsColumns(context, rhs, GetTemporary(context, node, 1));
731       op_data->rhs_transposed = true;
732     }
733   }
734   if (adj_x) {
735     TransposeRowsColumns(context, lhs, GetTemporary(context, node, 0));
736   }
737   RuntimeShape rhs_shape =
738       adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
739   RuntimeShape lhs_shape =
740       adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);
741 
742   switch (rhs->type) {
743     case kTfLiteFloat32:
744       // Note we pass RHS args first, LHS args second. See note above.
745       if (kernel_type == kGenericOptimized) {
746         optimized_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor),
747                                    lhs_shape, GetTensorData<float>(lhs_tensor),
748                                    GetTensorShape(output),
749                                    GetTensorData<float>(output),
750                                    CpuBackendContext::GetFromContext(context));
751       } else {
752         reference_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor),
753                                    lhs_shape, GetTensorData<float>(lhs_tensor),
754                                    GetTensorShape(output),
755                                    GetTensorData<float>(output));
756       }
757       break;
758     case kTfLiteInt8:
759     case kTfLiteInt16:
760       EvalQuantized<kernel_type>(context, node, op_data, lhs_shape, lhs_tensor,
761                                  rhs_shape, rhs_tensor, output);
762       break;
763     default:
764       TF_LITE_KERNEL_LOG(context,
765                          "Currently BatchMatMul doesn't support type: %s",
766                          TfLiteTypeGetName(lhs->type));
767       return kTfLiteError;
768   }
769   return kTfLiteOk;
770 }
771 
772 }  // namespace batch_matmul
773 
774 TfLiteRegistration* Register_BATCH_MATMUL_REF() {
775   static TfLiteRegistration r = {batch_matmul::Init, batch_matmul::Free,
776                                  batch_matmul::Prepare,
777                                  batch_matmul::Eval<batch_matmul::kReference>};
778   return &r;
779 }
780 
781 TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
782   static TfLiteRegistration r = {
783       batch_matmul::Init, batch_matmul::Free, batch_matmul::Prepare,
784       batch_matmul::Eval<batch_matmul::kGenericOptimized>};
785   return &r;
786 }
787 
788 TfLiteRegistration* Register_BATCH_MATMUL() {
789   return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
790 }
791 
792 }  // namespace builtin
793 }  // namespace ops
794 }  // namespace tflite
795