/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm.h"
#include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

template <typename InputScalar, typename DstScalar>
inline void FullyConnectedPerChannel(
    const FullyConnectedParams& params, const int32* output_multiplier,
    const int* output_shift, const RuntimeShape& input_shape,
    const InputScalar* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    DstScalar* output_data, CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");

  const int32 input_offset = params.input_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists in overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
  }
  const bool use_caching =
      (cpu_backend_context != nullptr) && cpu_backend_context->use_caching();

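  // Map the fully connected op onto a GEMM: the int8 filter matrix is the
  // row-major LHS (filter_rows x filter_cols), the input activations are the
  // column-major RHS (filter_cols x batches), and the output is the
  // column-major destination (filter_rows x batches). The filter zero point
  // is fixed at 0 here because per-channel quantization uses symmetric
  // weights.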
  cpu_backend_gemm::MatrixParams<int8> lhs_params;
  lhs_params.rows = filter_rows;
  lhs_params.cols = filter_cols;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.zero_point = 0;
  lhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<InputScalar> rhs_params;
  rhs_params.rows = filter_cols;
  rhs_params.cols = batches;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.zero_point = -input_offset;
  rhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<DstScalar> dst_params;
  dst_params.rows = filter_rows;
  dst_params.cols = batches;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.zero_point = output_offset;
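  // Per-channel requantization: the kIntegerWithPerRowMultiplier flavor
  // applies a separate fixed-point multiplier and exponent to each output row
  // (i.e. each output channel), taken from output_multiplier / output_shift.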
  cpu_backend_gemm::GemmParams<
      int32, DstScalar,
      cpu_backend_gemm::QuantizationFlavor::kIntegerWithPerRowMultiplier>
      gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  gemm_params.multiplier_fixedpoint_perchannel = output_multiplier;
  gemm_params.multiplier_exponent_perchannel = output_shift;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
}

template <typename InputScalar, typename DstScalar>
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const InputScalar* input_data, const RuntimeShape& filter_shape,
    const int8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    DstScalar* output_data, CpuBackendContext* cpu_backend_context) {
  ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit");

  const int32 input_offset = params.input_offset;
  const int32 filter_offset = params.weights_offset;
  const int32 output_offset = params.output_offset;
  const int32 output_multiplier = params.output_multiplier;
  const int output_shift = params.output_shift;
  const int32 output_activation_min = params.quantized_activation_min;
  const int32 output_activation_max = params.quantized_activation_max;
  TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
  TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
  // TODO(b/62193649): This really should be:
  //     const int batches = ArraySize(output_dims, 1);
  // but the current --variable_batch hack consists in overwriting the 3rd
  // dimension with the runtime batch size, as we don't keep track for each
  // array of which dimension is the batch dimension in it.
  const int output_dim_count = output_shape.DimensionsCount();
  const int filter_dim_count = filter_shape.DimensionsCount();
  const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
  const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
  const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
  TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
  const int output_rows = output_shape.Dims(output_dim_count - 1);
  TFLITE_DCHECK_EQ(output_rows, filter_rows);
  if (bias_data) {
    TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
  }
  const bool use_caching =
      (cpu_backend_context != nullptr) && cpu_backend_context->use_caching();

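  // Same GEMM mapping as in FullyConnectedPerChannel above: row-major int8
  // filter as LHS, column-major input activations as RHS, column-major output
  // as destination. Here the filter may have a non-zero offset
  // (params.weights_offset), which is negated into lhs_params.zero_point.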
  cpu_backend_gemm::MatrixParams<int8> lhs_params;
  lhs_params.rows = filter_rows;
  lhs_params.cols = filter_cols;
  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
  lhs_params.zero_point = -filter_offset;
  lhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.lhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<InputScalar> rhs_params;
  rhs_params.rows = filter_cols;
  rhs_params.cols = batches;
  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
  rhs_params.zero_point = -input_offset;
  rhs_params.cache_policy =
      use_caching ? cpu_backend_gemm::DefaultCachePolicy(params.rhs_cacheable)
                  : cpu_backend_gemm::CachePolicy::kNeverCache;
  cpu_backend_gemm::MatrixParams<DstScalar> dst_params;
  dst_params.rows = filter_rows;
  dst_params.cols = batches;
  dst_params.order = cpu_backend_gemm::Order::kColMajor;
  dst_params.zero_point = output_offset;
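  // Per-tensor requantization: a single fixed-point multiplier and exponent
  // apply to every output value, unlike the per-row multipliers used in
  // FullyConnectedPerChannel.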
  cpu_backend_gemm::GemmParams<int32, DstScalar> gemm_params;
  gemm_params.bias = bias_data;
  gemm_params.clamp_min = output_activation_min;
  gemm_params.clamp_max = output_activation_max;
  gemm_params.multiplier_fixedpoint = output_multiplier;
  gemm_params.multiplier_exponent = output_shift;
  cpu_backend_gemm::Gemm(lhs_params, filter_data, rhs_params, input_data,
                         dst_params, output_data, gemm_params,
                         cpu_backend_context);
}
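
// Illustrative usage sketch (hypothetical shapes and variable names): for two
// batches of depth-16 int8 activations against a filter with 8 output
// channels, the shapes relate as
//
//   RuntimeShape input_shape({2, 16});   // batches x accum_depth
//   RuntimeShape filter_shape({8, 16});  // output_depth x accum_depth
//   RuntimeShape bias_shape({8});
//   RuntimeShape output_shape({2, 8});   // batches x output_depth
//   FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
//                  bias_shape, bias_data, output_shape, output_data,
//                  cpu_backend_context);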

}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_FULLY_CONNECTED_H_