/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

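// Per-channel quantized fully-connected layer: for each batch, computes
// output = Requantize(filter * input + bias), where each output channel has
// its own requantization multiplier and shift (output_multiplier[i],
// output_shift[i]). The computation is delegated to cpu_backend_gemm.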
template <typename InputScalar, typename DstScalar>
inline void FullyConnectedPerChannel(
    const FullyConnectedParams& params, const int32* output_multiplier,
    const int* output_shift, const RuntimeShape& input_shape,
    const InputScalar* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    DstScalar* output_data, CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");

  const int32 input_offset = params.input_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists in overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
  }
  const bool use_caching =
      (cpu_backend_context != nullptr) && cpu_backend_context->use_caching();

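  // Express the fully-connected op as a GEMM: the int8 filter is the
  // row-major LHS (filter_rows x filter_cols), the input activations are the
  // column-major RHS with one column per batch, and the output is the
  // column-major destination (output_rows x batches). Per-channel int8
  // weights are symmetrically quantized, so the LHS zero point is 0;
  // params.input_offset stores the negated input zero point, so it is negated
  // again below to recover the RHS zero point.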
  cpu_backend_gemm::MatrixParams<int8> lhs_params;
  lhs_params.rows = filter_rows;
  lhs_params.cols = filter_cols;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.zero_point = 0;
  lhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<InputScalar> rhs_params;
  rhs_params.rows = filter_cols;
  rhs_params.cols = batches;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.zero_point = -input_offset;
  rhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<DstScalar> dst_params;
  dst_params.rows = filter_rows;
  dst_params.cols = batches;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.zero_point = output_offset;
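  // Per-channel requantization: each output row gets its own fixed-point
  // multiplier and exponent, selected via the kIntegerWithPerRowMultiplier
  // quantization flavor.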
  cpu_backend_gemm::GemmParams<
      int32, DstScalar,
      cpu_backend_gemm::QuantizationFlavor::kIntegerWithPerRowMultiplier>
      gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  gemm_params.multiplier_fixedpoint_perchannel = output_multiplier;
  gemm_params.multiplier_exponent_perchannel = output_shift;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
}

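// Per-tensor quantized fully-connected layer. Same GEMM mapping as above,
// but with a single output_multiplier/output_shift pair for the whole tensor
// and an explicit weights (filter) offset taken from params.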
template <typename InputScalar, typename DstScalar>
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const InputScalar* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    DstScalar* output_data, CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");

  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists in overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
  }
  const bool use_caching =
      (cpu_backend_context != nullptr) && cpu_backend_context->use_caching();

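  // Same GEMM mapping as in FullyConnectedPerChannel above, except that the
  // weights may have a non-zero zero point: params.weights_offset stores the
  // negated filter zero point, recovered below via lhs_params.zero_point.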
  cpu_backend_gemm::MatrixParams<int8> lhs_params;
  lhs_params.rows = filter_rows;
  lhs_params.cols = filter_cols;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.zero_point = -filter_offset;
  lhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<InputScalar> rhs_params;
  rhs_params.rows = filter_cols;
  rhs_params.cols = batches;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.zero_point = -input_offset;
  rhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<DstScalar> dst_params;
  dst_params.rows = filter_rows;
  dst_params.cols = batches;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.zero_point = output_offset;
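  // Per-tensor requantization: a single fixed-point multiplier and exponent
  // apply to every output channel (default quantization flavor).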
  cpu_backend_gemm::GemmParams<int32, DstScalar> gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  gemm_params.multiplier_fixedpoint = output_multiplier;
  gemm_params.multiplier_exponent = output_shift;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
}
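
// Usage sketch (illustrative only, not part of this header's API): a caller
// such as an int8 fully-connected kernel would fill FullyConnectedParams from
// the tensors' quantization data and then invoke FullyConnected. The names
// input_zero_point, filter_zero_point, output_zero_point, quantized_multiplier
// and shift below are hypothetical placeholders.
//
//   FullyConnectedParams op_params;
//   op_params.input_offset = -input_zero_point;     // negated zero point
//   op_params.weights_offset = -filter_zero_point;  // 0 for symmetric int8
//   op_params.output_offset = output_zero_point;
//   op_params.output_multiplier = quantized_multiplier;  // fixed-point scale
//   op_params.output_shift = shift;
//   op_params.quantized_activation_min = std::numeric_limits<int8_t>::min();
//   op_params.quantized_activation_max = std::numeric_limits<int8_t>::max();
//   optimized_integer_ops::FullyConnected(
//       op_params, input_shape, input_data, filter_shape, filter_data,
//       bias_shape, bias_data, output_shape, output_data,
//       cpu_backend_context);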

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_