/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_

#include <algorithm>
#include <type_traits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"

#ifdef __AVX2__
#include <immintrin.h>
#endif

namespace tflite {
namespace optimized_ops {
namespace depthwise_conv {

// Implementation of quantized DepthwiseConv
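//
// Each QuantizedDepthwiseConvKernel specialization below handles one fixed
// (input depth, depth multiplier) combination:
//   kFixedInputDepth: input depth handled by the kernel, or 0 for a kernel
//     that loops over an arbitrary input depth at runtime.
//   kFixedDepthMultiplier: depth multiplier handled by the kernel.
//   kAllowStrided: whether the kernel can be used when input_ptr_increment
//     differs from the input depth; non-strided kernels simply advance
//     input_ptr by their fixed depth.
// Run() accumulates, for num_output_pixels output positions,
// (input + input_offset) * (filter + filter_offset) products into the int32
// accumulation buffer at acc_buffer_ptr.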

template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct QuantizedDepthwiseConvKernel {};

#ifdef USE_NEON
template <>
struct QuantizedDepthwiseConvKernel<true, 8, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8x2_t filter_u8;
    filter_u8.val[0] = vld1_u8(filter_ptr);
    filter_u8.val[1] = vld1_u8(filter_ptr + 8);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])),
                            vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4x2_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
        acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += input_ptr_increment;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]),
                                  vget_low_s16(input_dup2.val[i]));
        acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]),
                                  vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
        vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
      acc[1] =
          vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
      acc[3] =
          vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      acc[0] = vld1q_s32(acc_buffer_ptr);
      acc[1] = vld1q_s32(acc_buffer_ptr + 4);

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc[0]);
      vst1q_s32(acc_buffer_ptr + 4, acc[1]);
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter),
                                   vget_low_s16(input_dup2.val[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter),
                                   vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4x2_t input_dup2 = vzip_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }
    int outp = 0;
    // Handle two output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate.
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);

      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer.
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x2_t acc = vld1_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer.
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x2_t acc = vld1_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const uint32 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const uint32 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vmlal_n_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8 = vdup_n_u8(0);
    filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
    filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
    filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
    filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
    const int16x4_t filter_s16 =
        vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
    const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i);
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      }
      input_ptr += 16;
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
        acc[2 * i + 1] =
            vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc;
      acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i);
      const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
      filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    }

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vld1_u8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]),
                              vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]),
                              vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]),
                              vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]),
                              vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]),
                              vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]),
                              vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]),
                              vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]),
                              vget_high_s16(input), 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      uint8x8_t input_u8 = vdup_n_u8(0);
      input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
      input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
      input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
      input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 =
          vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 3> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // We will have to duplicate bytes in a NEON register, 3-fold.
    // We will do that by register-level table-look-up using VTBL instructions.
    // Here we prepare the registers containing the table-lookup indices.
    static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
                                                   {2, 3, 3, 3, 4, 4, 4, 5},
                                                   {5, 5, 6, 6, 6, 7, 7, 7}};
    uint8x8_t dup3_indices[3];
    for (int i = 0; i < 3; i++) {
      dup3_indices[i] = vld1_u8(dup3_indices_array[i]);
    }
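    // For example, looking up bytes [b0..b7] with indices {0, 0, 0, 1, 1, 1,
    // 2, 2} via vtbl1_u8 yields {b0, b0, b0, b1, b1, b1, b2, b2}, so the
    // three lookups together expand 8 input channels into 24 interleaved
    // values, one per (channel, depth-multiplier) pair.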

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        int16x8_t filter[3];
        uint8x8x3_t filter_u8;
        filter_u8.val[0] = vld1_u8(local_filter_ptr);
        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
        filter_u8.val[2] = vld1_u8(local_filter_ptr + 16);
        local_filter_ptr += 24;
        for (int i = 0; i < 3; i++) {
          const int16x8_t filter_s16 =
              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        }
        // Load the inputs, duplicate 3-fold, add input_offset.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;

        uint8x8_t input_u8_dup3[3];
        for (int i = 0; i < 3; i++) {
          input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]);
        }
        int16x8_t input_dup3[3];
        for (int i = 0; i < 3; i++) {
          const int16x8_t input_s16_dup3 =
              vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i]));
          input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
        }
        // Load the accumulators from acc_buffer
        int32x4x3_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
          acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
        }
        // Multiply-accumulate
        for (int j = 0; j < 3; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
                                    vget_low_s16(filter[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
                                    vget_high_s16(filter[j]));
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
        }
        acc_buffer_ptr += 24;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 3; i++) {
          const int16 filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
        }
        local_filter_ptr += 3;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        int16x8_t filter[2];
        uint8x8x2_t filter_u8;
        filter_u8.val[0] = vld1_u8(local_filter_ptr);
        filter_u8.val[1] = vld1_u8(local_filter_ptr + 8);
        local_filter_ptr += 16;
        for (int i = 0; i < 2; i++) {
          const int16x8_t filter_s16 =
              vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i]));
          filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        }
        // Load the inputs, add input_offset, duplicate 2-fold.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        const int16x8x2_t input_dup2 = vzipq_s16(input, input);
        // Load the accumulators from acc_buffer.
        int32x4x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
        }
        // Multiply-accumulate.
        for (int j = 0; j < 2; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
                                    vget_low_s16(input_dup2.val[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
                                    vget_high_s16(input_dup2.val[j]));
        }
        // Store the accumulators back to acc_buffer.
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the inputs.
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 2; i++) {
          const int16 filter_val = local_filter_ptr[i] + filter_offset;
          *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
        }
        local_filter_ptr += 2;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const uint8* local_filter_ptr = filter_ptr;
      const uint8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
#ifdef __AVX2__
        // Load the filters, add filter_offset.
        __m128i filter_u8_0 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(local_filter_ptr + 8 * 0));
        __m128i filter_u8_1 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(local_filter_ptr + 8 * 1));
        local_filter_ptr += 16;
        __m256i filter_0 = _mm256_cvtepu8_epi32(filter_u8_0);
        __m256i filter_1 = _mm256_cvtepu8_epi32(filter_u8_1);
        __m256i filter_offset_vec = _mm256_set1_epi32(filter_offset);
        filter_0 = _mm256_add_epi32(filter_0, filter_offset_vec);
        filter_1 = _mm256_add_epi32(filter_1, filter_offset_vec);
        // Load the inputs, add input_offset.
        __m128i input_u8_0 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(local_input_ptr + 8 * 0));
        __m128i input_u8_1 = _mm_loadl_epi64(
            reinterpret_cast<const __m128i*>(local_input_ptr + 8 * 1));
        local_input_ptr += 16;
        __m256i input_0 = _mm256_cvtepu8_epi32(input_u8_0);
        __m256i input_1 = _mm256_cvtepu8_epi32(input_u8_1);
        __m256i input_offset_vec = _mm256_set1_epi32(input_offset);
        input_0 = _mm256_add_epi32(input_0, input_offset_vec);
        input_1 = _mm256_add_epi32(input_1, input_offset_vec);
        // Load the accumulators from acc_buffer
        __m256i acc_0 = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(acc_buffer_ptr + 8 * 0));
        __m256i acc_1 = _mm256_loadu_si256(
            reinterpret_cast<const __m256i*>(acc_buffer_ptr + 8 * 1));
        acc_0 = _mm256_add_epi32(acc_0, _mm256_mullo_epi32(input_0, filter_0));
        acc_1 = _mm256_add_epi32(acc_1, _mm256_mullo_epi32(input_1, filter_1));
        // Store the accumulators back to acc_buffer
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(acc_buffer_ptr + 8 * 0),
                            acc_0);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(acc_buffer_ptr + 8 * 1),
                            acc_1);
        acc_buffer_ptr += 16;
#else
        // Load the filters, add filter_offset.
        uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0);
        uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1);
        local_filter_ptr += 16;
        int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
        int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
        filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
        filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
        // Load the inputs, add input_offset.
        uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0);
        uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1);
        local_input_ptr += 16;
        int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
        int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
        input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
        input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
        int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
        int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
        int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
        acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
        acc_1 =
            vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
        acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
        acc_3 =
            vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
        // Store the accumulators back to acc_buffer
        vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
        vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
        vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
        vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
        acc_buffer_ptr += 16;
#endif
      }
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters, add filter_offset.
        const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr);
        local_filter_ptr += 8;
        const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
        const int16x8_t filter =
            vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
        // Load the inputs, add input_offset.
        const uint8x8_t input_u8 = vld1_u8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
        acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        const int16 filter_val = *local_filter_ptr++ + filter_offset;
        *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 16, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++) {
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    }
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      uint8x8_t input_u8[2];
      for (int i = 0; i < 2; i++) {
        input_u8[i] = vld1_u8(input_ptr + 8 * i);
      }
      input_ptr += input_ptr_increment;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i]));
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]),
                                   vget_low_s16(filter[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]),
                                   vget_high_s16(filter[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
    const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8));
    const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      const uint8x8_t input_u8 = vld1_u8(input_ptr);
      const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8));
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 16> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8[2];
    for (int i = 0; i < 2; i++) {
      filter_u8[i] = vld1_u8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i]));
    }
    for (int i = 0; i < 2; i++) {
      filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset));
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      uint8 input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_u8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
        acc[2 * i + 1] =
            vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 32> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const uint8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const uint8* filter_ptr,
                  int16 filter_offset, int32* acc_buffer_ptr) {
    // Load the filters, add filter_offset.
    uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
    uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
    uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2);
    uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3);
    int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
    int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
    int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2));
    int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3));
    filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
    filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
    filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset));
    filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset));
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      uint8 input_u8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_u8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
      int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
      int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
      int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
      int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
      int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
      int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
      int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
      // Multiply-accumulate
      acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
      acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
      acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
      acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
      acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
      acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
      acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
      acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
      vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
      vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
      vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
      vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
      vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
      vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
      vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
      acc_buffer_ptr += 32;
    }
  }
};

1255 template <>
1256 struct QuantizedDepthwiseConvKernel<true, 1, 20> {
1257   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1258                   const uint8* input_ptr, int16 input_offset,
1259                   int input_ptr_increment, const uint8* filter_ptr,
1260                   int16 filter_offset, int32* acc_buffer_ptr) {
1261     // Load the filters, add filter_offset.
1262     // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8.
1263     // We load the first 16 bytes into filter_u8_{0,1} as usual.
1264     // Then we load the last 8 bytes into filter_u8_x (x for 'extra').
1265     // This is redundant: the first 4 bytes of filter_u8_x are the same
1266     // as the last 4 bytes of filter_u8_1.
1267     uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0);
1268     uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1);
1269     uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4);
1270     int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1271     int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1272     int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x));
1273     filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset));
1274     filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset));
1275     filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset));
1276     // Handle one output pixel at a time.
1277     for (int outp = 0; outp < num_output_pixels; outp++) {
1278       uint8 input_u8 = *input_ptr;
1279       input_ptr += input_ptr_increment;
1280       int16 input = static_cast<int16>(input_u8 + input_offset);
1281       // Load the accumulators from acc_buffer
1282       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1283       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1284       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1285       int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1286       int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
1287       // Multiply-accumulate
1288       acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
1289       acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
1290       acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
1291       acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
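           // The high half of filter_x holds filter bytes 16..19, i.e. output
           // channels 16..19; its low half duplicates bytes 12..15 (already
           // covered by filter_1) and is unused.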
1292       acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input);
1293       // Store the accumulators back to acc_buffer
1294       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1295       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1296       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1297       vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1298       vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
1299       acc_buffer_ptr += 20;
1300     }
1301   }
1302 };
1303 
1304 template <>
1305 struct QuantizedDepthwiseConvKernel<true, 1, 8> {
1306   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1307                   const uint8* input_ptr, int16 input_offset,
1308                   int input_ptr_increment, const uint8* filter_ptr,
1309                   int16 filter_offset, int32* acc_buffer_ptr) {
1310     // Load the filters, add filter_offset.
1311     const uint8x8_t filter_u8 = vld1_u8(filter_ptr);
1312     const int16x8_t filter = vaddq_s16(
1313         vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset));
1314     // Handle one output pixel at a time.
1315     for (int outp = 0; outp < num_output_pixels; outp++) {
1316       uint8 input_u8 = *input_ptr;
1317       input_ptr += input_ptr_increment;
1318       int16 input = static_cast<int16>(input_u8 + input_offset);
1319       // Load the accumulators from acc_buffer
1320       int32x4_t acc[2];
1321       for (int i = 0; i < 2; i++) {
1322         acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1323       }
1324       // Multiply-accumulate
1325       acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
1326       acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
1327       // Store the accumulators back to acc_buffer
1328       for (int i = 0; i < 2; i++) {
1329         vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1330       }
1331       acc_buffer_ptr += 8;
1332     }
1333   }
1334 };
1335 
1336 template <>
1337 struct QuantizedDepthwiseConvKernel<true, 2, 1> {
1338   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1339                   const uint8* input_ptr, int16 input_offset,
1340                   int input_ptr_increment, const uint8* filter_ptr,
1341                   int16 filter_offset, int32* acc_buffer_ptr) {
1342     // Load the filters, add filter_offset.
1343     uint8x8_t filter_u8 = vdup_n_u8(0);
1344     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1345     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1346     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2);
1347     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3);
1348     const int16x4_t filter_s16 =
1349         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
1350     const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
1351 
1352     int outp = 0;
1353 
1354     // Handle 2 output pixels at a time.
1355     for (; outp <= num_output_pixels - 2; outp += 2) {
1356       // Load the accumulators from acc_buffer.
1357       int32x4_t acc = vld1q_s32(acc_buffer_ptr);
1358       // Load the inputs, add input_offset.
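           // Each input pixel has two uint8 channels; read them together as one
           // 16-bit lane per pixel.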
1359       uint16x4_t input_u16 = vdup_n_u16(0);
1360       input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
1361                                 input_u16, 0);
1362       input_ptr += input_ptr_increment;
1363       input_u16 = vset_lane_u16((reinterpret_cast<const uint16*>(input_ptr))[0],
1364                                 input_u16, 1);
1365       input_ptr += input_ptr_increment;
1366       const int16x4_t input_s16 = vreinterpret_s16_u16(
1367           vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16))));
1368       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1369 
1370       // Multiply-accumulate.
1371       acc = vmlal_s16(acc, filter, input);
1372       // Store the accumulators back to acc_buffer.
1373       vst1q_s32(acc_buffer_ptr, acc);
1374       acc_buffer_ptr += 4;
1375     }
1376 
1377     // Handle 1 output pixel at a time.
1378     for (; outp < num_output_pixels; outp++) {
1379       // Load the accumulators from acc_buffer.
1380       int32x2_t acc = vld1_s32(acc_buffer_ptr);
1381       // Load the inputs, add input_offset.
1382       uint8x8_t input_u8 = vdup_n_u8(0);
1383       input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
1384       input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
1385       input_ptr += input_ptr_increment;
1386       const int16x4_t input_s16 =
1387           vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1388       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1389 
1390       // Multiply-accumulate.
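           // vmlal_s16 needs a 128-bit accumulator, so duplicate the 2-lane acc,
           // multiply-accumulate, and keep only the low 2 lanes of the result.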
1391       acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
1392       // Store the accumulators back to acc_buffer.
1393       vst1_s32(acc_buffer_ptr, acc);
1394       acc_buffer_ptr += 2;
1395     }
1396   }
1397 };
1398 
1399 template <>
1400 struct QuantizedDepthwiseConvKernel<true, 4, 1> {
1401   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1402                   const uint8* input_ptr, int16 input_offset,
1403                   int input_ptr_increment, const uint8* filter_ptr,
1404                   int16 filter_offset, int32* acc_buffer_ptr) {
1405     if (num_output_pixels <= 0) {
1406       return;
1407     }
1408 
1409     // Load the filters, add filter_offset.
1410     uint8x8_t filter_u8 = vdup_n_u8(0);
1411     filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0);
1412     filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1);
1413     filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2);
1414     filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3);
1415     const int16x4_t filter_s16 =
1416         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8)));
1417     const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset));
1418 
1419     int outp = 0;
1420 
1421     // Handle one output pixel at a time until the second-to-last pixel. We stop
1422     // there because the vector load below reads eight input values while only
1423     // four are processed.
1424     for (; outp < num_output_pixels - 1; outp++) {
1425       // Load the accumulators from acc_buffer
1426       int32x4_t acc;
1427       acc = vld1q_s32(acc_buffer_ptr);
1428 
1429       // Load the inputs, add input_offset.
1430       uint8x8_t input_u8 = vld1_u8(input_ptr);
1431       input_ptr += input_ptr_increment;
1432       const int16x4_t input_s16 =
1433           vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1434       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1435       // Multiply-accumulate
1436       acc = vmlal_s16(acc, filter, input);
1437       // Store the accumulators back to acc_buffer
1438       vst1q_s32(acc_buffer_ptr, acc);
1439       acc_buffer_ptr += 4;
1440     }
1441 
1442     // Handle the last output pixel.
1443     // Load the accumulators from acc_buffer
1444     int32x4_t acc;
1445     acc = vld1q_s32(acc_buffer_ptr);
1446 
1447     // Load the inputs, add input_offset.
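         // Build the final 4-channel pixel one lane at a time so no bytes past
         // the end of the input are read.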
1448     uint8x8_t input_u8 = vdup_n_u8(0);
1449     input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0);
1450     input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1);
1451     input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2);
1452     input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3);
1453     const int16x4_t input_s16 =
1454         vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8)));
1455     const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1456     // Multiply-accumulate
1457     acc = vmlal_s16(acc, filter, input);
1458     // Store the accumulators back to acc_buffer
1459     vst1q_s32(acc_buffer_ptr, acc);
1460   }
1461 };
1462 
1463 template <>
1464 struct QuantizedDepthwiseConvKernel<false, 12, 1> {
1465   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1466                   const uint8* input_ptr, int16 input_offset,
1467                   int input_ptr_increment, const uint8* filter_ptr,
1468                   int16 filter_offset, int32* acc_buffer_ptr) {
1469     // Load the filters, add filter_offset.
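         // The two 8-byte loads overlap by 4 bytes: filter_u8_0 covers bytes 0..7
         // and filter_u8_1 covers bytes 4..11; only the high half of filter_s16_1
         // (bytes 8..11) is used below.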
1470     uint8x8_t filter_u8_0 = vld1_u8(filter_ptr);
1471     uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4);
1472     int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0));
1473     int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1));
1474     filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset));
1475     filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset));
1476     int16x4_t filter_0 = vget_low_s16(filter_s16_0);
1477     int16x4_t filter_1 = vget_high_s16(filter_s16_0);
1478     int16x4_t filter_2 = vget_high_s16(filter_s16_1);
1479 
1480     // Handle one output pixel at a time.
1481     for (int outp = 0; outp < num_output_pixels; outp++) {
1482       // Load the inputs, add input_offset.
1483       uint8x8_t input_u8_0 = vld1_u8(input_ptr);
1484       uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4);
1485       input_ptr += input_ptr_increment;
1486       int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0));
1487       int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1));
1488       input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
1489       input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
1490 
1491       // Load the accumulators from acc_buffer
1492       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1493       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1494       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1495 
1496       // Multiply-accumulate
1497       acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
1498       acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
1499       acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
1500 
1501       // Store the accumulators back to acc_buffer
1502       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1503       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1504       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1505 
1506       acc_buffer_ptr += 12;
1507     }
1508   }
1509 };
1510 #endif
1511 
1512 // Accumulates the effect of one row of the filter, on a segment of one row
1513 // of the output, accessing the corresponding one row of the input.
1514 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
1515 void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
1516                                     int input_depth, int input_width,
1517                                     const uint8* input_data, int16 input_offset,
1518                                     int pad_width, int depth_multiplier,
1519                                     int filter_width, const uint8* filter_data,
1520                                     int16 filter_offset, int out_x_buffer_start,
1521                                     int out_x_buffer_end, int output_depth,
1522                                     int32* acc_buffer) {
1523   ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION);
1524   // Consistency check parameters. This is important in particular to ensure
1525   // that we keep the number of template instantiations minimal, so we don't
1526   // increase binary size unnecessarily.
1527   static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
1528   static_assert(kFixedInputDepth || kAllowStrided, "");
1529   TFLITE_DCHECK(stride == 1 || kAllowStrided);
1530   if (kFixedInputDepth) {
1531     TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
1532   }
1533   if (kFixedDepthMultiplier) {
1534     TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
1535   }
1536   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
1537   const int input_ptr_increment = stride * input_depth;
1538   const uint8* filter_base_ptr = filter_data;
1539   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1540     // For the current (filter_x, filter_y) point in the filter,
1541     // compute the boundaries of the corresponding output row segment.
1542     int out_x_loop_start_unclamped = 0;
1543     int out_x_loop_end_unclamped = 0;
1544     if (kAllowStrided) {
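           // Strides 2 and 4 are special-cased with literal divisors so the
           // compiler can use cheap shift-based code instead of a general division.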
1545       if (stride == 2) {
1546         out_x_loop_start_unclamped =
1547             (pad_width - dilation_factor * filter_x + 1) / 2;
1548         out_x_loop_end_unclamped =
1549             (pad_width + input_width - dilation_factor * filter_x + 1) / 2;
1550       } else if (stride == 4) {
1551         out_x_loop_start_unclamped =
1552             (pad_width - dilation_factor * filter_x + 3) / 4;
1553         out_x_loop_end_unclamped =
1554             (pad_width + input_width - dilation_factor * filter_x + 3) / 4;
1555       } else {
1556         out_x_loop_start_unclamped =
1557             (pad_width - dilation_factor * filter_x + stride - 1) / stride;
1558         out_x_loop_end_unclamped = (pad_width + input_width -
1559                                     dilation_factor * filter_x + stride - 1) /
1560                                    stride;
1561       }
1562     } else {
1563       out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x;
1564       out_x_loop_end_unclamped =
1565           pad_width + input_width - dilation_factor * filter_x;
1566     }
1567     // The kernel will have to iterate on the segment of the
1568     // output row that starts at out_x_loop_start and ends at out_x_loop_end.
1569     const int out_x_loop_start =
1570         std::max(out_x_buffer_start, out_x_loop_start_unclamped);
1571     const int out_x_loop_end =
1572         std::min(out_x_buffer_end, out_x_loop_end_unclamped);
1573 
1574     int32* acc_buffer_ptr =
1575         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1576     const int in_x_origin =
1577         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1578     const uint8* input_ptr = input_data + in_x_origin * input_depth;
1579     const int num_output_pixels = out_x_loop_end - out_x_loop_start;
1580     QuantizedDepthwiseConvKernel<
1581         kAllowStrided, kFixedInputDepth,
1582         kFixedDepthMultiplier>::Run(num_output_pixels, input_depth,
1583                                     depth_multiplier, input_ptr, input_offset,
1584                                     input_ptr_increment, filter_base_ptr,
1585                                     filter_offset, acc_buffer_ptr);
1586     filter_base_ptr += output_depth;
1587   }
1588 }
1589 
1590 // Generic fallback of QuantizedDepthwiseConvAccumRow: portable, non-templatized.
1591 inline void QuantizedDepthwiseConvAccumRowGeneric(
1592     int stride, int dilation_factor, int input_depth, int input_width,
1593     const uint8* input_data, int16 input_offset, int pad_width,
1594     int depth_multiplier, int filter_width, const uint8* filter_data,
1595     int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
1596     int output_depth, int32* acc_buffer) {
1597   ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)");
1598   const uint8* filter_base_ptr = filter_data;
1599   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1600     const int out_x_loop_start = std::max(
1601         out_x_buffer_start,
1602         (pad_width - dilation_factor * filter_x + stride - 1) / stride);
1603     const int out_x_loop_end = std::min(
1604         out_x_buffer_end,
1605         (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
1606             stride);
1607 
1608     int32* acc_buffer_ptr =
1609         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1610     const int in_x_origin =
1611         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1612     const uint8* input_ptr = input_data + in_x_origin * input_depth;
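         // The inner channel loop below already advances input_ptr by input_depth
         // per output pixel, so only the remaining (stride - 1) * input_depth is
         // added after each pixel.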
1613     const int input_ptr_increment = (stride - 1) * input_depth;
1614     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
1615       const uint8* filter_ptr = filter_base_ptr;
1616       for (int ic = 0; ic < input_depth; ++ic) {
1617         const int16 input_val = *input_ptr++ + input_offset;
1618         for (int m = 0; m < depth_multiplier; m++) {
1619           const int16 filter_val = *filter_ptr++ + filter_offset;
1620           *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
1621         }
1622       }
1623       input_ptr += input_ptr_increment;
1624     }
1625     filter_base_ptr += output_depth;
1626   }
1627 }
1628 
1629 // Initializes the accumulator buffer with bias values.
1630 inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
1631                                        const int32* bias_data,
1632                                        int32* acc_buffer) {
1633   int i = 0;
1634 #ifdef USE_NEON
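       // Fast paths for common output depths replicate the bias pattern with
       // vector stores; any remaining pixels fall through to the scalar memcpy
       // loop at the end.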
1635   if (output_depth == 1) {
1636     const int32x4_t b = vdupq_n_s32(bias_data[0]);
1637     for (; i <= num_output_pixels - 16; i += 16) {
1638       vst1q_s32(acc_buffer + i + 0, b);
1639       vst1q_s32(acc_buffer + i + 4, b);
1640       vst1q_s32(acc_buffer + i + 8, b);
1641       vst1q_s32(acc_buffer + i + 12, b);
1642     }
1643     for (; i <= num_output_pixels - 4; i += 4) {
1644       vst1q_s32(acc_buffer + i, b);
1645     }
1646   } else if (output_depth == 2) {
1647     int32x4_t b = vdupq_n_s32(bias_data[0]);
1648     b = vsetq_lane_s32(bias_data[1], b, 1);
1649     b = vsetq_lane_s32(bias_data[1], b, 3);
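         // b now holds {bias[0], bias[1], bias[0], bias[1]}: the bias pattern for
         // two consecutive 2-channel pixels.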
1650     for (; i <= num_output_pixels - 8; i += 8) {
1651       vst1q_s32(acc_buffer + 2 * i + 0, b);
1652       vst1q_s32(acc_buffer + 2 * i + 4, b);
1653       vst1q_s32(acc_buffer + 2 * i + 8, b);
1654       vst1q_s32(acc_buffer + 2 * i + 12, b);
1655     }
1656     for (; i <= num_output_pixels - 2; i += 2) {
1657       vst1q_s32(acc_buffer + 2 * i, b);
1658     }
1659   } else if (output_depth == 4) {
1660     const int32x4_t b = vld1q_s32(bias_data);
1661     for (; i <= num_output_pixels - 4; i += 4) {
1662       vst1q_s32(acc_buffer + 4 * i + 0, b);
1663       vst1q_s32(acc_buffer + 4 * i + 4, b);
1664       vst1q_s32(acc_buffer + 4 * i + 8, b);
1665       vst1q_s32(acc_buffer + 4 * i + 12, b);
1666     }
1667     for (; i < num_output_pixels; i++) {
1668       vst1q_s32(acc_buffer + 4 * i, b);
1669     }
1670   } else if (output_depth == 8) {
1671     const int32x4_t b0 = vld1q_s32(bias_data);
1672     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1673     for (; i <= num_output_pixels - 2; i += 2) {
1674       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1675       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1676       vst1q_s32(acc_buffer + 8 * i + 8, b0);
1677       vst1q_s32(acc_buffer + 8 * i + 12, b1);
1678     }
1679     for (; i < num_output_pixels; i++) {
1680       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1681       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1682     }
1683   } else if (output_depth == 16) {
1684     const int32x4_t b0 = vld1q_s32(bias_data);
1685     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1686     const int32x4_t b2 = vld1q_s32(bias_data + 8);
1687     const int32x4_t b3 = vld1q_s32(bias_data + 12);
1688     for (; i < num_output_pixels; i++) {
1689       vst1q_s32(acc_buffer + 16 * i + 0, b0);
1690       vst1q_s32(acc_buffer + 16 * i + 4, b1);
1691       vst1q_s32(acc_buffer + 16 * i + 8, b2);
1692       vst1q_s32(acc_buffer + 16 * i + 12, b3);
1693     }
1694   }
1695 #endif
1696   for (; i < num_output_pixels; i++) {
1697     memcpy(acc_buffer + i * output_depth, bias_data,
1698            sizeof(acc_buffer[0]) * output_depth);
1699   }
1700 }
1701 
1702 inline void DepthwiseConvGeneral(
1703     const DepthwiseParams& params, const RuntimeShape& input_shape,
1704     const uint8* input_data, const RuntimeShape& filter_shape,
1705     const uint8* filter_data, const RuntimeShape& bias_shape,
1706     const int32* bias_data, const RuntimeShape& output_shape,
1707     uint8* output_data, int thread_start, int thread_end, int thread_dim) {
1708   const int stride_width = params.stride_width;
1709   const int stride_height = params.stride_height;
1710   const int pad_width = params.padding_values.width;
1711   const int pad_height = params.padding_values.height;
1712   const int depth_multiplier = params.depth_multiplier;
1713   const int32 output_activation_min = params.quantized_activation_min;
1714   const int32 output_activation_max = params.quantized_activation_max;
1715   const int32 input_offset = params.input_offset;
1716   const int32 filter_offset = params.weights_offset;
1717   const int32 output_offset = params.output_offset;
1718   const int32 output_multiplier = params.output_multiplier;
1719   const int output_shift = params.output_shift;
1720   const int dilation_width_factor = params.dilation_width_factor;
1721   const int dilation_height_factor = params.dilation_height_factor;
1722   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
1723   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
1724   const int input_height = input_shape.Dims(1);
1725   const int input_width = input_shape.Dims(2);
1726   const int input_depth = input_shape.Dims(3);
1727   const int filter_height = filter_shape.Dims(1);
1728   const int filter_width = filter_shape.Dims(2);
1729   const int output_height = output_shape.Dims(1);
1730   const int output_width = output_shape.Dims(2);
1731 #ifdef USE_NEON
1732   const bool shift_left = (output_shift > 0);
1733   const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
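       // A positive output_shift is applied as a plain power-of-two multiply before
       // the fixed-point multiply below; a non-positive shift becomes a rounding
       // right shift afterwards.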
1734 #endif
1735 
1736   // The default stack-allocated acc_buffer holds 2048 values; a larger buffer
1737   // is allocated on the heap if that is not enough.
1738   // TODO(b/136089667): If output_depth > 2048 happens a lot, we should just use
1739   // a scratch tensor.
1740   static const int kStackAccBufferSize = 2048;
1741   int acc_buffer_size = kStackAccBufferSize;
1742   int32 stack_acc_buffer[kStackAccBufferSize];
1743   int32* acc_buffer = stack_acc_buffer;
1744   std::unique_ptr<int32[]> heap_acc_buffer;
1745   if (kStackAccBufferSize < output_depth) {
1746     heap_acc_buffer.reset(new int32[output_depth]);
1747     acc_buffer = heap_acc_buffer.get();
1748     acc_buffer_size = output_depth;
1749   }
1750   const int kOutputPixelsInAccBuffer = acc_buffer_size / output_depth;
1751   const int acc_buffer_size_actually_used =
1752       kOutputPixelsInAccBuffer * output_depth;
1753   TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
1754                    acc_buffer_size_actually_used);
1755   TFLITE_DCHECK_LE(acc_buffer_size_actually_used, acc_buffer_size);
1756   TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
1757   TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1);
1758 
1759   // row_accum_func will point to the core accumulation function to be used
1760   // for this DepthwiseConv op.
1761   using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
1762   row_accum_func_t row_accum_func = nullptr;
1763 
1764 #define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
1765                                         FIXED_DEPTH_MULTIPLIER)           \
1766   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
1767       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
1768       depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
1769     row_accum_func =                                                      \
1770         QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,  \
1771                                        FIXED_DEPTH_MULTIPLIER>;           \
1772   }
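       // The first matching kernel wins (the !row_accum_func guard); a
       // FIXED_INPUT_DEPTH of 0 selects the variable-input-depth kernels.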
1773 
1774 #ifdef USE_NEON
1775   // We go over our list of kernels in decreasing order of preference
1776   // for the cases where multiple kernels could apply.
1777 
1778   // Start with the fastest kernels: AllowStrided=false, fixed input depth.
1779 
1780   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
1781   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
1782   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
1783   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
1784   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
1785   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
1786   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
1787   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
1788   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
1789   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
1790 
1791   // Next come the strided kernels: AllowStrided=true, fixed input depth.
1792   // They are a bit less efficient, but allow stride!=1.
1793 
1794   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
1795   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
1796   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
1797   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20)
1798   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
1799   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
1800   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
1801   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
1802   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
1803 
1804   // Finally, the kernels allowing a variable input depth;
1805   // these are the least efficient but most general kernels.
1806 
1807   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
1808   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
1809   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
1810 #endif  // USE_NEON
1811 
1812   // No matching fast kernel found, use slow fallback.
1813   if (!row_accum_func) {
1814     row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
1815   }
1816 
1817 #undef TFMINI_USE_DEPTHWISECONV_KERNEL
1818 
1819   const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
1820   const int input_batch_stride = input_height_stride * input_shape.Dims(1);
1821   const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
1822 
1823   // Now that we have determined row_accum_func, we can start work.
1824   int batch_start = 0;
1825   int batch_end = batches;
1826   int row_start = 0;
1827   int row_end = output_height;
1828   int output_ptr_offset = 0;
1829 
1830   switch (thread_dim) {
1831     case 0:
1832       // Multithread along the batch axis
1833       TFLITE_DCHECK_GE(thread_start, 0);
1834       TFLITE_DCHECK_LE(thread_end, batches);
1835       batch_start = thread_start;
1836       batch_end = thread_end;
1837       output_ptr_offset = batch_start * FlatSizeSkipDim(output_shape, 0);
1838       break;
1839     case 1:
1840       // Multithread along the row axis
1841       TFLITE_DCHECK_GE(thread_start, 0);
1842       TFLITE_DCHECK_LE(thread_end, output_height);
1843       row_start = thread_start;
1844       row_end = thread_end;
1845       output_ptr_offset = row_start * output_width * output_depth;
1846       break;
1847   }
1848 
1849   uint8* output_ptr = output_data + output_ptr_offset;
1850   int batch_step =
1851       (output_height + row_start - row_end) * output_width * output_depth;
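       // batch_step skips the rows owned by other threads so that, after this
       // thread's rows of one batch, output_ptr lands on row_start of the next
       // batch.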
1852   for (int b = batch_start; b < batch_end; ++b) {
1853     for (int out_y = row_start; out_y < row_end; ++out_y) {
1854       const int in_y_origin = (out_y * stride_height) - pad_height;
1855       const int filter_y_start =
1856           std::max(0, (-in_y_origin + dilation_height_factor - 1) /
1857                           dilation_height_factor);
1858       const int filter_y_end =
1859           std::min(filter_height,
1860                    (input_height - in_y_origin + dilation_height_factor - 1) /
1861                        dilation_height_factor);
1862       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
1863            out_x_buffer_start += kOutputPixelsInAccBuffer) {
1864         const int out_x_buffer_end = std::min(
1865             output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
1866         // We call a 'pixel' a group of activation values that share all but the
1867         // 'depth'/'channel' coordinate. num_output_pixels is the number of
1868         // output pixels that we will accumulate in this loop iteration.
1869         const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
1870         // Initialize our local accumulator with the bias values, so we don't
1871         // have to add them later.
1872         DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
1873                                    acc_buffer);
1874         // Accumulation loop. Most of the time should be spent in here.
1875         for (int filter_y = filter_y_start; filter_y < filter_y_end;
1876              ++filter_y) {
1877           const int in_y = in_y_origin + dilation_height_factor * filter_y;
1878           row_accum_func(
1879               stride_width, dilation_width_factor, input_depth, input_width,
1880               input_data + in_y * input_height_stride + b * input_batch_stride,
1881               input_offset, pad_width, depth_multiplier, filter_width,
1882               filter_data + filter_y * filter_height_stride, filter_offset,
1883               out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
1884         }
1885         // Finished accumulating int32 values. Now we need to convert them to
1886         // the final 8-bit form and store them.
1887         ruy::profiler::ScopeLabel label("downquantize+store");
1888         const int num_output_values = output_depth * num_output_pixels;
1889         int i = 0;
1890 #ifdef USE_NEON
1891         using gemmlowp::RoundingDivideByPOT;
1892         const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
1893         const int32x4_t output_activation_min_vec =
1894             vdupq_n_s32(output_activation_min);
1895         const int32x4_t output_activation_max_vec =
1896             vdupq_n_s32(output_activation_max);
1897         // Handle 16 values at once.
1898         // This allows us to issue 4 mutually independent int32
1899         // multiplications (vqrdmulh), which should alleviate most of their
1900         // high latency.
1901         for (; i <= num_output_values - 16; i += 16) {
1902           int32x4_t acc[4];
1903           for (int j = 0; j < 4; j++) {
1904             acc[j] = vld1q_s32(acc_buffer + i + 4 * j);
1905           }
1906 
1907           if (!shift_left) {
1908             // Fixed-point multiplication.
1909             for (int j = 0; j < 4; j++) {
1910               acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
1911             }
1912             for (int j = 0; j < 4; j++) {
1913               acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
1914             }
1915           } else {
1916             // Fixed-point multiplication.
1917             for (int j = 0; j < 4; j++) {
1918               acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two);
1919               acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
1920             }
1921           }
1922           // Add the output offset.
1923           for (int j = 0; j < 4; j++) {
1924             acc[j] = vaddq_s32(acc[j], output_offset_vec);
1925           }
1926           // Apply the activation function.
1927           for (int j = 0; j < 4; j++) {
1928             acc[j] = vmaxq_s32(acc[j], output_activation_min_vec);
1929           }
1930           for (int j = 0; j < 4; j++) {
1931             acc[j] = vminq_s32(acc[j], output_activation_max_vec);
1932           }
1933           // Saturating cast to uint8 and store to destination.
1934           int16x4_t acc_s16[4];
1935           for (int j = 0; j < 4; j++) {
1936             acc_s16[j] = vqmovn_s32(acc[j]);
1937           }
1938           const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]);
1939           const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]);
1940           const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0);
1941           const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1);
1942           vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1));
1943           output_ptr += 16;
1944         }
1945         // Handle 8 values at once.
1946         // Not as good as 16 (now we're only issuing 2 mutually independent
1947         // vqrdmulh instructions, so we're probably paying for their high
1948         // latency).
1949         for (; i <= num_output_values - 8; i += 8) {
1950           int32x4_t acc0 = vld1q_s32(acc_buffer + i);
1951           int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4);
1952           if (!shift_left) {
1953             // Fixed-point multiplication.
1954             acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
1955             acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
1956             // Rounding right shift.
1957             acc0 = RoundingDivideByPOT(acc0, -output_shift);
1958             acc1 = RoundingDivideByPOT(acc1, -output_shift);
1959           } else {
1960             // Fixed-point multiplication.
1961             acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
1962             acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
1963 
1964             acc1 = vmulq_n_s32(acc1, multiplier_power_of_two);
1965             acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
1966           }
1967           // Add the output offset.
1968           acc0 = vaddq_s32(acc0, output_offset_vec);
1969           acc1 = vaddq_s32(acc1, output_offset_vec);
1970           // Apply the activation function.
1971           acc0 = vmaxq_s32(acc0, output_activation_min_vec);
1972           acc1 = vmaxq_s32(acc1, output_activation_min_vec);
1973           acc0 = vminq_s32(acc0, output_activation_max_vec);
1974           acc1 = vminq_s32(acc1, output_activation_max_vec);
1975           // Saturating cast to uint8 and store to destination.
1976           const int16x4_t acc0_s16 = vqmovn_s32(acc0);
1977           const int16x4_t acc1_s16 = vqmovn_s32(acc1);
1978           const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16);
1979           const uint8x8_t res_u8 = vqmovun_s16(res_s16);
1980           vst1_u8(output_ptr, res_u8);
1981           output_ptr += 8;
1982         }
1983         // Handle 4 values at once. Now we're paying the full price of the
1984         // high latency of vqrdmulh. Also, storing only 4 bytes at the end
1985         // (without any alignment) can only be done 1 byte at a time.
1986         // Yet, that is still worth doing to minimize the amount of leftover
1987         // that will have to go through the very slow scalar code.
1988         for (; i <= num_output_values - 4; i += 4) {
1989           int32x4_t acc = vld1q_s32(acc_buffer + i);
1990           if (!shift_left) {
1991             // Fixed-point multiplication.
1992             acc = vqrdmulhq_n_s32(acc, output_multiplier);
1993             // Rounding right shift.
1994             acc = RoundingDivideByPOT(acc, -output_shift);
1995           } else {
1996             // Fixed-point multiplication.
1997             acc = vmulq_n_s32(acc, multiplier_power_of_two);
1998             acc = vqrdmulhq_n_s32(acc, output_multiplier);
1999           }
2000           // Add the output offset.
2001           acc = vaddq_s32(acc, output_offset_vec);
2002           // Apply the activation function.
2003           acc = vmaxq_s32(acc, output_activation_min_vec);
2004           acc = vminq_s32(acc, output_activation_max_vec);
2005           // Saturating cast to uint8 and store to destination.
2006           const int16x4_t acc_s16 = vqmovn_s32(acc);
2007           const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16);
2008           const uint8x8_t res_u8 = vqmovun_s16(res_s16);
2009           vst1_lane_u8(output_ptr + 0, res_u8, 0);
2010           vst1_lane_u8(output_ptr + 1, res_u8, 1);
2011           vst1_lane_u8(output_ptr + 2, res_u8, 2);
2012           vst1_lane_u8(output_ptr + 3, res_u8, 3);
2013           output_ptr += 4;
2014         }
2015 #endif  // USE_NEON
2016 
2017         // Handle leftover values, one by one. This is very slow.
2018         for (; i < num_output_values; i++) {
2019           int32 acc = acc_buffer[i];
2020           acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
2021                                               output_shift);
2022           acc += output_offset;
2023           acc = std::max(acc, output_activation_min);
2024           acc = std::min(acc, output_activation_max);
2025           *output_ptr++ = static_cast<uint8>(acc);
2026         }
2027       }
2028     }
2029     output_ptr += batch_step;
2030   }
2031 }
2032 
2033 }  // namespace depthwise_conv
2034 
2035 template <DepthwiseConvOutputRounding kOutputRounding>
2036 inline void DepthwiseConvWithRounding(
2037     const DepthwiseParams& params, const RuntimeShape& input_shape,
2038     const uint8* input_data, const RuntimeShape& filter_shape,
2039     const uint8* filter_data, const RuntimeShape& bias_shape,
2040     const int32* bias_data, const RuntimeShape& output_shape,
2041     uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
2042     int thread_end, int thread_dim) {
2043   ruy::profiler::ScopeLabel label("DepthwiseConv/8bit");
2044   const int depth_multiplier = params.depth_multiplier;
2045   const int32 output_activation_min = params.quantized_activation_min;
2046   const int32 output_activation_max = params.quantized_activation_max;
2047   const int dilation_width_factor = params.dilation_width_factor;
2048   const int dilation_height_factor = params.dilation_height_factor;
2049   TFLITE_DCHECK_GE(dilation_width_factor, 1);
2050   TFLITE_DCHECK_GE(dilation_height_factor, 1);
2051   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
2052   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
2053   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
2054   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
2055   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
2056   const int input_depth = input_shape.Dims(3);
2057   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
2058   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
2059 
2060 // Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
2061 // Jetson TX-2. The L4T compiler does not support the offsetof() macro.
2062 #if defined(__aarch64__) && !defined(GOOGLE_L4T)
2063 #if defined(__ANDROID__) && defined(__clang__)
2064   // Dispatch to dot-product 3x3 kernels when supported.
2065   if (cpu_flags.neon_dotprod) {
2066     using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
2067     DotProduct3x3KernelType kernel_type =
2068         optimized_ops::depthwise_conv::CategorizeDotProductKernel(
2069             input_shape, filter_shape, output_shape, params);
2070     if (kernel_type != DotProduct3x3KernelType::kNone) {
2071       ruy::profiler::ScopeLabel specialized_label(
2072           "DepthwiseConv/8bit/3x3XDotProduct");
2073       optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3<
2074           DepthwiseConvImplementation::kUseNeon3x3DotProduct>(
2075           params, input_shape, input_data, filter_shape, filter_data,
2076           bias_shape, bias_data, output_shape, output_data, thread_start,
2077           thread_end, thread_dim);
2078       return;
2079     }
2080   }
2081 
2082 #endif
2083   // Dispatch to non-dot-product 3x3 kernels when supported.
2084 
2085   const int stride_width = params.stride_width;
2086   const int stride_height = params.stride_height;
2087   const int pad_width = params.padding_values.width;
2088   const int pad_height = params.padding_values.height;
2089   const int output_shift = params.output_shift;
2090 
2091   // Call the kernel optimized for depthwise convolutions using 3x3 filters if
2092   // the parameters are supported.
2093   if (depthwise_conv::Fast3x3FilterKernelSupported(
2094           input_shape, filter_shape, stride_width, stride_height,
2095           dilation_width_factor, dilation_height_factor, pad_width, pad_height,
2096           depth_multiplier, output_shape, output_shift)) {
2097     ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/3x3");
2098     depthwise_conv::DepthwiseConv3x3Filter<kOutputRounding>(
2099         params, input_shape, input_data, filter_shape, filter_data, bias_shape,
2100         bias_data, output_shape, output_data, thread_start, thread_end,
2101         thread_dim);
2102     return;
2103   }
2104 #endif
2105 
2106   ruy::profiler::ScopeLabel specialized_label("DepthwiseConv/8bit/General");
2107   depthwise_conv::DepthwiseConvGeneral(params, input_shape, input_data,
2108                                        filter_shape, filter_data, bias_shape,
2109                                        bias_data, output_shape, output_data,
2110                                        thread_start, thread_end, thread_dim);
2111 }
2112 
2113 inline void DepthwiseConvImpl(
2114     const DepthwiseParams& params, const RuntimeShape& input_shape,
2115     const uint8* input_data, const RuntimeShape& filter_shape,
2116     const uint8* filter_data, const RuntimeShape& bias_shape,
2117     const int32* bias_data, const RuntimeShape& output_shape,
2118     uint8* output_data, const CpuFlags& cpu_flags, int thread_start,
2119     int thread_end, int thread_dim) {
2120   return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kUpward>(
2121       params, input_shape, input_data, filter_shape, filter_data, bias_shape,
2122       bias_data, output_shape, output_data, cpu_flags, thread_start, thread_end,
2123       thread_dim);
2124 }
2125 
2126 }  // namespace optimized_ops
2127 }  // namespace tflite
2128 
2129 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_
2130