xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
17 
18 #include <algorithm>
19 
20 #include "fixedpoint/fixedpoint.h"
21 #include "ruy/profiler/instrumentation.h"  // from @ruy
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/compatibility.h"
24 #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
25 #include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
26 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
27 #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
28 #include "tensorflow/lite/kernels/internal/types.h"
29 
30 namespace tflite {
31 namespace optimized_integer_ops {
32 
33 // Element-wise mul that can often be used for inner loop of broadcast Mul as
34 // well as the non-broadcast Mul.
// Element-wise quantized int8 multiply over `size` flat elements:
//   out[i] = clamp(output_offset + MultiplyByQuantizedMultiplier(
//                (in1[i] + input1_offset) * (in2[i] + input2_offset),
//                output_multiplier, output_shift))
// Used both for non-broadcast Mul and as the inner loop of broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
                           const int8* input1_data, const int8* input2_data,
                           int8* output_data) {
  ruy::profiler::ScopeLabel label("MulElementwiseInt8/8bit");
  int i = 0;
  // Each offset must lie in (-256, 256) so that (int8 value + offset) fits
  // comfortably in an int16 lane in the NEON path below.
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const int16x8_t input1_offset_vector = vdupq_n_s16(params.input1_offset);
  const int16x8_t input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const int16x8_t output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdupq_n_s8(params.quantized_activation_max);
  // Split output_shift into a pre-multiply left shift (positive shifts) and
  // a post-multiply rounding right shift (negative shifts); at most one of
  // the two is non-zero.
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16) {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const int8x16_t input1_val_original = vld1q_s8(input1_data + i);
    const int8x16_t input2_val_original = vld1q_s8(input2_data + i);

    // Sign-extend each int8 half to int16.
    const int16x8_t input1_val_s16_high =
        vmovl_s8(vget_high_s8(input1_val_original));
    const int16x8_t input1_val_s16_low =
        vmovl_s8(vget_low_s8(input1_val_original));

    const int16x8_t input2_val_s16_high =
        vmovl_s8(vget_high_s8(input2_val_original));
    const int16x8_t input2_val_s16_low =
        vmovl_s8(vget_low_s8(input2_val_original));
    // Apply the input zero-point offsets in int16.
    const int16x8_t input1_val_high =
        vaddq_s16(input1_val_s16_high, input1_offset_vector);
    const int16x8_t input2_val_high =
        vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const int16x8_t input1_val_low =
        vaddq_s16(input1_val_s16_low, input1_offset_vector);
    const int16x8_t input2_val_low =
        vaddq_s16(input2_val_s16_low, input2_offset_vector);
    // Break the int16x8 vectors into int16x4 quarters for widening multiply.
    const int16x4_t input1_val_high_high = vget_high_s16(input1_val_high);
    const int16x4_t input1_val_high_low = vget_low_s16(input1_val_high);
    const int16x4_t input1_val_low_high = vget_high_s16(input1_val_low);
    const int16x4_t input1_val_low_low = vget_low_s16(input1_val_low);
    const int16x4_t input2_val_high_high = vget_high_s16(input2_val_high);
    const int16x4_t input2_val_high_low = vget_low_s16(input2_val_high);
    const int16x4_t input2_val_low_high = vget_high_s16(input2_val_low);
    const int16x4_t input2_val_low_low = vget_low_s16(input2_val_low);

    // Widening multiply: int16x4 * int16x4 -> int32x4.
    auto p1 = vmull_s16(input2_val_high_high, input1_val_high_high);
    auto p2 = vmull_s16(input2_val_high_low, input1_val_high_low);
    auto p3 = vmull_s16(input2_val_low_high, input1_val_low_high);
    auto p4 = vmull_s16(input2_val_low_low, input1_val_low_low);

    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    // Saturating rounding doubling high multiply by the fixed-point
    // output multiplier.
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Saturating narrow int32 -> int16.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    // Add the output zero-point in int16, then saturating-narrow to int8.
    const int16x8_t p_part1 =
        vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
        vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    // Clamp to the quantized activation range.
    const auto clamped = vmaxq_s8(output_activation_min_vector,
                                  vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif  // NEON

  // Scalar tail (and the whole loop when NEON is unavailable); mirrors the
  // reference implementation.
  for (; i < size; ++i) {
    const int32 input1_val = params.input1_offset + input1_data[i];
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8>(clamped_output);
  }
}
139 
140 // Broadcast mul that can often be used for inner loop of broadcast Mul.
// Broadcast variant of the element-wise int8 mul: multiplies every element
// of `input2_data` by the single scalar `broadcast_value`. Often used as the
// inner loop of broadcast Mul.
inline void MulSimpleBroadcast(int size, const ArithmeticParams& params,
                               const int8 broadcast_value,
                               const int8* input2_data, int8* output_data) {
  ruy::profiler::ScopeLabel label("BroadMulSimpleBroadcastInt8/8bit");
  // The broadcast operand is offset once up front; int16 is wide enough
  // given the offset-range DCHECKs below.
  const int16 input1_val = params.input1_offset + broadcast_value;

  int i = 0;
  // Each offset must lie in (-256, 256) so that (int8 value + offset) fits
  // comfortably in an int16 lane in the NEON path below.
  TFLITE_DCHECK_GT(params.input1_offset, -256);
  TFLITE_DCHECK_LT(params.input1_offset, 256);
  TFLITE_DCHECK_GT(params.input2_offset, -256);
  TFLITE_DCHECK_LT(params.input2_offset, 256);
  TFLITE_DCHECK_GT(params.output_offset, -256);
  TFLITE_DCHECK_LT(params.output_offset, 256);
#ifdef USE_NEON
  const auto input2_offset_vector = vdupq_n_s16(params.input2_offset);
  const auto output_offset_vector = vdupq_n_s16(params.output_offset);
  const auto output_activation_min_vector =
      vdupq_n_s8(params.quantized_activation_min);
  const auto output_activation_max_vector =
      vdupq_n_s8(params.quantized_activation_max);
  // Split output_shift into a pre-multiply left shift (positive shifts) and
  // a post-multiply rounding right shift (negative shifts); at most one of
  // the two is non-zero.
  const int left_shift = std::max(0, params.output_shift);
  const int right_shift = std::max(0, -params.output_shift);
  const int32x4_t left_shift_vec = vdupq_n_s32(left_shift);
  for (; i <= size - 16; i += 16) {
    // We load / store 16 at a time, multiplying as four sets of 4 int32s.
    const auto input2_val_original = vld1q_s8(input2_data + i);
    // Sign-extend each int8 half to int16 and apply the input2 offset.
    const auto input2_val_s16_high =
        vmovl_s8(vget_high_s8(input2_val_original));
    const auto input2_val_s16_low = vmovl_s8(vget_low_s8(input2_val_original));

    const auto input2_val_high =
        vaddq_s16(input2_val_s16_high, input2_offset_vector);
    const auto input2_val_low =
        vaddq_s16(input2_val_s16_low, input2_offset_vector);

    const auto input2_val_low_low = vget_low_s16(input2_val_low);
    const auto input2_val_low_high = vget_high_s16(input2_val_low);
    const auto input2_val_high_low = vget_low_s16(input2_val_high);
    const auto input2_val_high_high = vget_high_s16(input2_val_high);

    // Widening multiply by the broadcast scalar: int16x4 -> int32x4.
    auto p1 = vmull_n_s16(input2_val_high_high, input1_val);
    auto p2 = vmull_n_s16(input2_val_high_low, input1_val);
    auto p3 = vmull_n_s16(input2_val_low_high, input1_val);
    auto p4 = vmull_n_s16(input2_val_low_low, input1_val);

    p1 = vshlq_s32(p1, left_shift_vec);
    p2 = vshlq_s32(p2, left_shift_vec);
    p3 = vshlq_s32(p3, left_shift_vec);
    p4 = vshlq_s32(p4, left_shift_vec);

    // Saturating rounding doubling high multiply by the fixed-point
    // output multiplier.
    p1 = vqrdmulhq_n_s32(p1, params.output_multiplier);
    p2 = vqrdmulhq_n_s32(p2, params.output_multiplier);
    p3 = vqrdmulhq_n_s32(p3, params.output_multiplier);
    p4 = vqrdmulhq_n_s32(p4, params.output_multiplier);
    using gemmlowp::RoundingDivideByPOT;
    p1 = RoundingDivideByPOT(p1, right_shift);
    p2 = RoundingDivideByPOT(p2, right_shift);
    p3 = RoundingDivideByPOT(p3, right_shift);
    p4 = RoundingDivideByPOT(p4, right_shift);

    // Saturating narrow int32 -> int16.
    const auto p1_narrowed = vqmovn_s32(p1);
    const auto p2_narrowed = vqmovn_s32(p2);
    const auto p3_narrowed = vqmovn_s32(p3);
    const auto p4_narrowed = vqmovn_s32(p4);

    // Add the output zero-point in int16, then saturating-narrow to int8.
    const int16x8_t p_part1 =
        vaddq_s16(vcombine_s16(p2_narrowed, p1_narrowed), output_offset_vector);
    const int16x8_t p_part2 =
        vaddq_s16(vcombine_s16(p4_narrowed, p3_narrowed), output_offset_vector);
    const int8x16_t p = vcombine_s8(vqmovn_s16(p_part2), vqmovn_s16(p_part1));

    // Clamp to the quantized activation range.
    const auto clamped = vmaxq_s8(output_activation_min_vector,
                                  vminq_s8(output_activation_max_vector, p));
    vst1q_s8(output_data + i, clamped);
  }
#endif  // NEON

  // Scalar tail (and the whole loop when NEON is unavailable); mirrors the
  // reference implementation.
  for (; i < size; ++i) {
    const int32 input2_val = params.input2_offset + input2_data[i];
    const int32 unclamped_result =
        params.output_offset +
        MultiplyByQuantizedMultiplier(input1_val * input2_val,
                                      params.output_multiplier,
                                      params.output_shift);
    const int32 clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, unclamped_result));
    output_data[i] = static_cast<int8>(clamped_output);
  }
}
231 
Mul(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)232 inline void Mul(const ArithmeticParams& params,
233                 const RuntimeShape& input1_shape, const int8* input1_data,
234                 const RuntimeShape& input2_shape, const int8* input2_data,
235                 const RuntimeShape& output_shape, int8* output_data) {
236   TFLITE_DCHECK_LE(params.quantized_activation_min,
237                    params.quantized_activation_max);
238   ruy::profiler::ScopeLabel label("MulInt8/8bit");
239   const int flat_size =
240       MatchingElementsSize(input1_shape, input2_shape, output_shape);
241 
242   MulElementwise(flat_size, params, input1_data, input2_data, output_data);
243 }
244 
BroadcastMulDispatch(const ArithmeticParams & params,const RuntimeShape & input1_shape,const int8 * input1_data,const RuntimeShape & input2_shape,const int8 * input2_data,const RuntimeShape & output_shape,int8 * output_data)245 inline void BroadcastMulDispatch(const ArithmeticParams& params,
246                                  const RuntimeShape& input1_shape,
247                                  const int8* input1_data,
248                                  const RuntimeShape& input2_shape,
249                                  const int8* input2_data,
250                                  const RuntimeShape& output_shape,
251                                  int8* output_data) {
252   if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
253     return reference_integer_ops::BroadcastMul4DSlow(
254         params, input1_shape, input1_data, input2_shape, input2_data,
255         output_shape, output_data);
256   }
257 
258   optimized_ops::BinaryBroadcastFiveFold(
259       params, input1_shape, input1_data, input2_shape, input2_data,
260       output_shape, output_data, MulElementwise, MulSimpleBroadcast);
261 }
262 
263 }  // namespace optimized_integer_ops
264 }  // namespace tflite
265 
266 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_MUL_H_
267