1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
17
18 #include <algorithm>
19
20 #include "fixedpoint/fixedpoint.h"
21 #include "ruy/profiler/instrumentation.h" // from @ruy
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
25 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
26 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
27 #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
28 #include "tensorflow/lite/kernels/internal/types.h"
29
30 namespace tflite {
31 namespace optimized_integer_ops {
32
33 // Element-wise mul that can often be used for inner loop of broadcast Mul as
34 // well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
                           const int8* input1_data, const int8* input2_data,
                           int8* output_data) {
  ruy::profiler::ScopeLabel label("MulElementwiseInt8/8bit");
  int i = 0;
  // Zero-point offsets are added to int8 values after widening to int16;
  // restricting them to (-256, 256) keeps those sums within int16 range.
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdupq_n_s8(params.quantized_activation_max);
  // Split the signed output_shift into a left shift (applied before the
  // fixed-point multiplier) and a rounding right shift (applied after it).
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16) {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);

    // Widen the int8 lanes to int16 so the zero-point offsets can be added
    // without overflow.
    const int16x8_t input1_val_s16_high =
        vmovl_s8(vget_high_s8(input1_val_original));
    const int16x8_t input1_val_s16_low =
        vmovl_s8(vget_low_s8(input1_val_original));

    const int16x8_t input2_val_s16_high =
        vmovl_s8(vget_high_s8(input2_val_original));
    const int16x8_t input2_val_s16_low =
        vmovl_s8(vget_low_s8(input2_val_original));
    const int16x8_t input1_val_high =
        vaddq_s16(input1_val_s16_high, input1_offset_vector);
    const int16x8_t input2_val_high =
        vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const int16x8_t input1_val_low =
        vaddq_s16(input1_val_s16_low, input1_offset_vector);
    const int16x8_t input2_val_low =
        vaddq_s16(input2_val_s16_low, input2_offset_vector);
    // Split each int16x8 into 4-lane halves so the products below fit in
    // int32x4 vectors.
    const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
    const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
    const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
    const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
    const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
    const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
    const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
    const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);

    // 16-bit x 16-bit -> 32-bit widening products, one quarter at a time.
    auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
    auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
    auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
    auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);

    // Apply the positive part of output_shift (no-op when output_shift <= 0).
    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    // Saturating rounding doubling multiply-high: fixed-point scaling by
    // output_multiplier (a Q31 value).
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    // Rounding right shift for the negative part of output_shift.
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Saturating narrow to int16.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    // Add the output zero point, then recombine the quarters in their
    // original lane order (p4/p3 came from the low half, p2/p1 from the
    // high half) and saturating-narrow to int8.
    const int16x8_t p_part1 =
        vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
        vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    // Clamp to the quantized activation range.
    const auto clamped = vmaxq_s8(output_activation_min_vector,
                                  vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif  // NEON

  // Scalar loop: handles the tail after the vectorized loop, and the whole
  // range when NEON is unavailable.
  for (; i < size; ++i) {
    const int32 input1_val = params.input1_offset + input1_data[i];
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8>(clamped_output);
  }
}
139
140 // Broadcast mul that can often be used for inner loop of broadcast Mul.
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                               const int8 broadcast_value,
                               const int8* input2_data, int8* output_data) {
  ruy::profiler::ScopeLabel label("BroadMulSimpleBroadcastInt8/8bit");
  // The broadcast scalar operand, widened with its zero point pre-added;
  // it is reused for every output element.
  const int16 input1_val = params.input1_offset + broadcast_value;

  int i = 0;
  // Zero-point offsets are added to int8 values after widening to int16;
  // restricting them to (-256, 256) keeps those sums within int16 range.
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdupq_n_s8(params.quantized_activation_max);
  // Split the signed output_shift into a left shift (applied before the
  // fixed-point multiplier) and a rounding right shift (applied after it).
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16) {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const auto input2_val_original = vld1q_s8(input2_data + i);
    // Widen the int8 lanes to int16 so the zero-point offset can be added
    // without overflow.
    const auto input2_val_s16_high =
        vmovl_s8(vget_high_s8(input2_val_original));
    const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));

    const auto input2_val_high =
        vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const auto input2_val_low =
        vaddq_s16(input2_val_s16_low, input2_offset_vector);

    // Split each int16x8 into 4-lane halves so the products below fit in
    // int32x4 vectors.
    const auto input2_val_low_low = vget_low_s16(input2_val_low);
    const auto input2_val_low_high = vget_high_s16(input2_val_low);
    const auto input2_val_high_low = vget_low_s16(input2_val_high);
    const auto input2_val_high_high = vget_high_s16(input2_val_high);

    // 16-bit x 16-bit -> 32-bit widening products against the broadcast
    // scalar, one quarter at a time.
    auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
    auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
    auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
    auto p4 = vmull_n_s16(input2_val_low_low, input1_val);

    // Apply the positive part of output_shift (no-op when output_shift <= 0).
    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    // Saturating rounding doubling multiply-high: fixed-point scaling by
    // output_multiplier (a Q31 value).
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    // Rounding right shift for the negative part of output_shift.
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Saturating narrow to int16.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    // Add the output zero point, then recombine the quarters in their
    // original lane order (p4/p3 came from the low half, p2/p1 from the
    // high half) and saturating-narrow to int8.
    const int16x8_t p_part1 =
        vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
        vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    // Clamp to the quantized activation range.
    const auto clamped = vmaxq_s8(output_activation_min_vector,
                                  vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif  // NEON

  // Scalar loop: handles the tail after the vectorized loop, and the whole
  // range when NEON is unavailable.
  for (; i < size; ++i) {
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8>(clamped_output);
  }
}
231
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)232 inline void Mul(const ArithmeticParams& params,
233 const RuntimeShape& input1_shape, const int8* input1_data,
234 const RuntimeShape& input2_shape, const int8* input2_data,
235 const RuntimeShape& output_shape, int8* output_data) {
236 TFLITE_DCHECK_LE(params.quantized_activation_min,
237 params.quantized_activation_max);
238 ruy::profiler::ScopeLabel label("MulInt8/8bit");
239 const int flat_size =
240 MatchingElementsSize(input1_shape, input2_shape, output_shape);
241
242 MulElementwise(flat_size, params, input1_data, input2_data, output_data);
243 }
244
BroadcastMulDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)245 inline void BroadcastMulDispatch(const ArithmeticParams& params,
246 const RuntimeShape& input1_shape,
247 const int8* input1_data,
248 const RuntimeShape& input2_shape,
249 const int8* input2_data,
250 const RuntimeShape& output_shape,
251 int8* output_data) {
252 if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
253 return reference_integer_ops::BroadcastMul4DSlow(
254 params, input1_shape, input1_data, input2_shape, input2_data,
255 output_shape, output_data);
256 }
257
258 optimized_ops::BinaryBroadcastFiveFold(
259 params, input1_shape, input1_data, input2_shape, input2_data,
260 output_shape, output_data, MulElementwise, MulSimpleBroadcast);
261 }
262
263 } // namespace optimized_integer_ops
264 } // namespace tflite
265
266 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
267