/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_

#include <string.h>

#include <algorithm>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {
namespace depthwise_conv {

// Implementation of quantized DepthwiseConv

template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
struct QuantizedDepthwiseConvKernel {};
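
// A rough scalar sketch (not part of the build) of the contract every Run()
// specialization below implements; it is illustrative only, with names
// mirroring the Run() parameters. For each output pixel, each input channel
// is multiplied by its depth_multiplier filter taps and accumulated into
// acc_buffer. Kernels with kAllowStrided advance input_ptr by
// input_ptr_increment per pixel; non-strided kernels assume consecutive
// pixels, input_depth apart.
//
//   for (int outp = 0; outp < num_output_pixels; outp++) {
//     for (int ic = 0; ic < input_depth; ic++) {
//       const int16 input_val = input_ptr[ic] + input_offset;
//       for (int m = 0; m < depth_multiplier; m++) {
//         const int oc = ic * depth_multiplier + m;
//         acc_buffer_ptr[oc] += static_cast<int32>(filter_ptr[oc]) * input_val;
//       }
//     }
//     acc_buffer_ptr += input_depth * depth_multiplier;
//     input_ptr += input_ptr_increment;  // input_depth for non-strided
//                                        // kernels.
//   }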

#ifdef USE_NEON
template <>
struct QuantizedDepthwiseConvKernel<true, 8, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8x2_t filter_s8;
    filter_s8.val[0] = vld1_s8(filter_ptr);
    filter_s8.val[1] = vld1_s8(filter_ptr + 8);
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vmovl_s8(filter_s8.val[i]);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4x2_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
        acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
      }
      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += input_ptr_increment;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
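      // (Zipping a vector with itself repeats each lane twice, lining the 8
      // input channels up with their two filter taps each. Illustrative
      // lanes:
      //   input             = [i0 i1 i2 i3 i4 i5 i6 i7]
      //   input_dup2.val[0] = [i0 i0 i1 i1 i2 i2 i3 i3]
      //   input_dup2.val[1] = [i4 i4 i5 i5 i6 i6 i7 i7])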
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]),
                                  vget_low_s16(input_dup2.val[i]));
        acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]),
                                  vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
        vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    const int8x8_t filter_s8 = vld1_s8(filter_ptr);
    const int16x8_t filter = vmovl_s8(filter_s8);

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int8x8_t input_s8[2];
      for (int i = 0; i < 2; i++) {
        input_s8[i] = vld1_s8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vmovl_s8(input_s8[i]);
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0]));
      acc[1] =
          vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1]));
      acc[3] =
          vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      acc[0] = vld1q_s32(acc_buffer_ptr);
      acc[1] = vld1q_s32(acc_buffer_ptr + 4);

      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input));
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc[0]);
      vst1q_s32(acc_buffer_ptr + 4, acc[1]);
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    const int8x8_t filter_s8 = vld1_s8(filter_ptr);
    const int16x8_t filter = vmovl_s8(filter_s8);

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter),
                                   vget_low_s16(input_dup2.val[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter),
                                   vget_high_s16(input_dup2.val[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4x2_t input_dup2 = vzip_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]);
      acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 8> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const int8x8_t filter_s8 = vld1_s8(filter_ptr + 8 * i);
      filter[i] = vmovl_s8(filter_s8);
    }
    int outp = 0;
    // Handle two output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate.
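      // Lanes 0 and 1 of `input` hold the two channels of the first pixel,
      // lanes 2 and 3 those of the second pixel; each vmlal_lane_s16 below
      // broadcasts one input channel against its 8 filter taps, so acc[0..3]
      // cover the first pixel and acc[4..7] the second.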
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1);

      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8 = vdup_n_s8(0);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
    filter_s8 = vset_lane_s8(filter_ptr[2], filter_s8, 2);
    filter_s8 = vset_lane_s8(filter_ptr[3], filter_s8, 3);
    const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x4_t input_dup2 = vzip_s16(input, input).val[0];
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input_dup2);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 2, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8 = vdup_n_s8(0);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 2);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 3);
    const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int8x8_t input_s8[2];
      for (int i = 0; i < 2; i++) {
        input_s8[i] = vld1_s8(input_ptr + 8 * i);
      }
      input_ptr += 16;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vmovl_s8(input_s8[i]);
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1]));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate.
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input));
      // Store the accumulators back to acc_buffer.
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
    }
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer.
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer.
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
    // Handle 1 output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer.
      int32x2_t acc = vld1_s32(acc_buffer_ptr);
      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_ptr += 2;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate.
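      // NEON has no 2-lane widening multiply-accumulate of this shape, so
      // duplicate the two accumulators into both halves of a 4-lane vector,
      // run the full vmlal, and keep only the low half; the products in the
      // two upper lanes are discarded.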
      acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer.
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8 = vdup_n_s8(0);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 2);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 3);
    const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Duplicate the input values, 2-fold
      const int16x8x2_t input_dup2 = vzipq_s16(input, input);
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0]));
      acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0]));
      acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1]));
      acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1]));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x2_t acc = vld1_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const int16 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input));
      // Store the accumulators back to acc_buffer
      vst1_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 2;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 1, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8 = vdup_n_s8(0);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
    filter_s8 = vset_lane_s8(filter_ptr[2], filter_s8, 2);
    filter_s8 = vset_lane_s8(filter_ptr[3], filter_s8, 3);
    const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));

    int outp = 0;
    // Handle 8 output pixels at a time.
    for (; outp <= num_output_pixels - 8; outp += 8) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], filter, input, 0);
      acc[1] = vmlal_lane_s16(acc[1], filter, input, 1);
      acc[2] = vmlal_lane_s16(acc[2], filter, input, 2);
      acc[3] = vmlal_lane_s16(acc[3], filter, input, 3);

      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      const int16 input = *input_ptr++ + input_offset;

      // Multiply-accumulate
      acc = vmlal_n_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8 = vdup_n_s8(0);
    filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
    filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
    filter_s8 = vset_lane_s8(filter_ptr[2], filter_s8, 2);
    filter_s8 = vset_lane_s8(filter_ptr[3], filter_s8, 3);
    const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));

    int outp = 0;
    // Handle 4 output pixels at a time.
    for (; outp <= num_output_pixels - 4; outp += 4) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Load the inputs, add input_offset.
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        const int8x8_t input_s8 = vld1_s8(input_ptr + 8 * i);
        const int16x8_t input_s16 = vmovl_s8(input_s8);
        input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      }
      input_ptr += 16;
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i]));
        acc[2 * i + 1] =
            vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc;
      acc = vld1q_s32(acc_buffer_ptr);

      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
      // Multiply-accumulate
      acc = vmlal_s16(acc, filter, input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr, acc);
      acc_buffer_ptr += 4;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<false, 4, 4> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      const int8x8_t filter_s8 = vld1_s8(filter_ptr + 8 * i);
      filter[i] = vmovl_s8(filter_s8);
    }

    int outp = 0;
    // Handle 2 output pixels at a time.
    for (; outp <= num_output_pixels - 2; outp += 2) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[8];
      for (int i = 0; i < 8; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vld1_s8(input_ptr);
      input_ptr += 8;
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]),
                              vget_low_s16(input), 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]),
                              vget_low_s16(input), 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]),
                              vget_low_s16(input), 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]),
                              vget_low_s16(input), 3);
      acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]),
                              vget_high_s16(input), 0);
      acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]),
                              vget_high_s16(input), 1);
      acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]),
                              vget_high_s16(input), 2);
      acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]),
                              vget_high_s16(input), 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 8; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 32;
    }
    // Handle one output pixel at a time.
    for (; outp < num_output_pixels; outp++) {
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }

      // Load the inputs, add input_offset.
      int8x8_t input_s8 = vdup_n_s8(0);
      input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
      input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
      input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
      input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
      input_ptr += 4;
      const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
      const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));

      // Multiply-accumulate
      acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0);
      acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1);
      acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2);
      acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3);
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

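// The next three specializations (kFixedInputDepth == 0) handle a runtime
// input_depth: they loop over the channels of each output pixel, processing
// 8 (or 16) channels per NEON iteration and falling back to scalar code for
// the remainder.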
template <>
struct QuantizedDepthwiseConvKernel<true, 0, 3> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // We will have to duplicate bytes in a NEON register, 3-fold.
    // We will do that by register-level table-look-up using VTBL instructions.
    // Here we prepare the registers containing the table-lookup indices.
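    // Applied to input bytes [b0 b1 b2 b3 b4 b5 b6 b7], the three index
    // vectors below produce (illustrative):
    //   [b0 b0 b0 b1 b1 b1 b2 b2]
    //   [b2 b3 b3 b3 b4 b4 b4 b5]
    //   [b5 b5 b6 b6 b6 b7 b7 b7]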
    static const int8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2},
                                                  {2, 3, 3, 3, 4, 4, 4, 5},
                                                  {5, 5, 6, 6, 6, 7, 7, 7}};
    int8x8_t dup3_indices[3];
    for (int i = 0; i < 3; i++) {
      dup3_indices[i] = vld1_s8(dup3_indices_array[i]);
    }

    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const int8* local_filter_ptr = filter_ptr;
      const int8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters.
        int16x8_t filter[3];
        int8x8x3_t filter_s8;
        filter_s8.val[0] = vld1_s8(local_filter_ptr);
        filter_s8.val[1] = vld1_s8(local_filter_ptr + 8);
        filter_s8.val[2] = vld1_s8(local_filter_ptr + 16);
        local_filter_ptr += 24;
        for (int i = 0; i < 3; i++) {
          filter[i] = vmovl_s8(filter_s8.val[i]);
        }
        // Load the inputs, duplicate 3-fold, add input_offset.
        const int8x8_t input_s8 = vld1_s8(local_input_ptr);
        local_input_ptr += 8;

        int8x8_t input_s8_dup3[3];
        for (int i = 0; i < 3; i++) {
          input_s8_dup3[i] = vtbl1_s8(input_s8, dup3_indices[i]);
        }
        int16x8_t input_dup3[3];
        for (int i = 0; i < 3; i++) {
          const int16x8_t input_s16_dup3 = vmovl_s8(input_s8_dup3[i]);
          input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset));
        }
        // Load the accumulators from acc_buffer
        int32x4x3_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
          acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16);
        }
        // Multiply-accumulate
        for (int j = 0; j < 3; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]),
                                    vget_low_s16(filter[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]),
                                    vget_high_s16(filter[j]));
        }
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]);
        }
        acc_buffer_ptr += 24;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 3; i++) {
          *acc_buffer_ptr++ +=
              static_cast<int32>(local_filter_ptr[i]) * input_val;
        }
        local_filter_ptr += 3;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 2> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const int8* local_filter_ptr = filter_ptr;
      const int8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters.
        int16x8_t filter[2];
        int8x8x2_t filter_s8;
        filter_s8.val[0] = vld1_s8(local_filter_ptr);
        filter_s8.val[1] = vld1_s8(local_filter_ptr + 8);
        local_filter_ptr += 16;
        for (int i = 0; i < 2; i++) {
          filter[i] = vmovl_s8(filter_s8.val[i]);
        }
        // Load the inputs, add input_offset, duplicate 2-fold.
        const int8x8_t input_s8 = vld1_s8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vmovl_s8(input_s8);
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        const int16x8x2_t input_dup2 = vzipq_s16(input, input);
        // Load the accumulators from acc_buffer.
        int32x4x2_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i);
          acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8);
        }
        // Multiply-accumulate.
        for (int j = 0; j < 2; j++) {
          acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]),
                                    vget_low_s16(input_dup2.val[j]));
          acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]),
                                    vget_high_s16(input_dup2.val[j]));
        }
        // Store the accumulators back to acc_buffer.
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]);
          vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]);
        }
        acc_buffer_ptr += 16;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        // Load the inputs.
        const int16 input_val = *local_input_ptr++ + input_offset;
        for (int i = 0; i < 2; i++) {
          *acc_buffer_ptr++ +=
              static_cast<int32>(local_filter_ptr[i]) * input_val;
        }
        local_filter_ptr += 2;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 0, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      const int8* local_filter_ptr = filter_ptr;
      const int8* local_input_ptr = input_ptr;
      int ic = 0;
      // Handle 16 input channels at a time.
      for (; ic <= input_depth - 16; ic += 16) {
        // Load the filters.
        int8x8_t filter_s8_0 = vld1_s8(local_filter_ptr + 8 * 0);
        int8x8_t filter_s8_1 = vld1_s8(local_filter_ptr + 8 * 1);
        local_filter_ptr += 16;
        int16x8_t filter_0 = vmovl_s8(filter_s8_0);
        int16x8_t filter_1 = vmovl_s8(filter_s8_1);
        // Load the inputs, add input_offset.
        int8x8_t input_s8_0 = vld1_s8(local_input_ptr + 8 * 0);
        int8x8_t input_s8_1 = vld1_s8(local_input_ptr + 8 * 1);
        local_input_ptr += 16;
        int16x8_t input_0 = vmovl_s8(input_s8_0);
        int16x8_t input_1 = vmovl_s8(input_s8_1);
        input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
        input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
        int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
        int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
        int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
        acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0));
        acc_1 =
            vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0));
        acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1));
        acc_3 =
            vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1));
        // Store the accumulators back to acc_buffer
        vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
        vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
        vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
        vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
        acc_buffer_ptr += 16;
      }
      // Handle 8 input channels at a time.
      for (; ic <= input_depth - 8; ic += 8) {
        // Load the filters.
        const int8x8_t filter_s8 = vld1_s8(local_filter_ptr);
        local_filter_ptr += 8;
        const int16x8_t filter = vmovl_s8(filter_s8);
        // Load the inputs, add input_offset.
        const int8x8_t input_s8 = vld1_s8(local_input_ptr);
        local_input_ptr += 8;
        const int16x8_t input_s16 = vmovl_s8(input_s8);
        const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
        // Load the accumulators from acc_buffer
        int32x4_t acc[2];
        for (int i = 0; i < 2; i++) {
          acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
        }
        // Multiply-accumulate
        acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
        acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
        // Store the accumulators back to acc_buffer
        for (int i = 0; i < 2; i++) {
          vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
        }
        acc_buffer_ptr += 8;
      }
      // Handle one input channel at a time.
      for (; ic < input_depth; ic++) {
        const int16 input_val = *local_input_ptr++ + input_offset;
        const int16 filter_val = *local_filter_ptr++;
        *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
      }
      input_ptr += input_ptr_increment;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 16, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8[2];
    for (int i = 0; i < 2; i++) {
      filter_s8[i] = vld1_s8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vmovl_s8(filter_s8[i]);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      int8x8_t input_s8[2];
      for (int i = 0; i < 2; i++) {
        input_s8[i] = vld1_s8(input_ptr + 8 * i);
      }
      input_ptr += input_ptr_increment;
      int16x8_t input[2];
      for (int i = 0; i < 2; i++) {
        input[i] = vmovl_s8(input_s8[i]);
      }
      for (int i = 0; i < 2; i++) {
        input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset));
      }
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]),
                                   vget_low_s16(filter[i]));
        acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]),
                                   vget_high_s16(filter[i]));
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 8, 1> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    const int8x8_t filter_s8 = vld1_s8(filter_ptr);
    const int16x8_t filter = vmovl_s8(filter_s8);
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      // Load the inputs, add input_offset.
      const int8x8_t input_s8 = vld1_s8(input_ptr);
      const int16x8_t input_s16 = vmovl_s8(input_s8);
      const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset));
      // Load the accumulators from acc_buffer
      int32x4_t acc[2];
      for (int i = 0; i < 2; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter));
      acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter));
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 2; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 8;
      input_ptr += input_ptr_increment;
    }
  }
};

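// The next several specializations fix kFixedInputDepth == 1: the single
// input value of each output pixel is broadcast against depth_multiplier
// filter taps using the scalar-operand vmlal_n_s16, so no input vector load
// is needed.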
template <>
struct QuantizedDepthwiseConvKernel<true, 1, 16> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8[2];
    for (int i = 0; i < 2; i++) {
      filter_s8[i] = vld1_s8(filter_ptr + 8 * i);
    }
    int16x8_t filter[2];
    for (int i = 0; i < 2; i++) {
      filter[i] = vmovl_s8(filter_s8[i]);
    }
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      int8 input_s8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_s8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc[4];
      for (int i = 0; i < 4; i++) {
        acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
      }
      // Multiply-accumulate
      for (int i = 0; i < 2; i++) {
        acc[2 * i + 0] =
            vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input);
        acc[2 * i + 1] =
            vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input);
      }
      // Store the accumulators back to acc_buffer
      for (int i = 0; i < 4; i++) {
        vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
      }
      acc_buffer_ptr += 16;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 32> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    int8x8_t filter_s8_0 = vld1_s8(filter_ptr + 8 * 0);
    int8x8_t filter_s8_1 = vld1_s8(filter_ptr + 8 * 1);
    int8x8_t filter_s8_2 = vld1_s8(filter_ptr + 8 * 2);
    int8x8_t filter_s8_3 = vld1_s8(filter_ptr + 8 * 3);
    int16x8_t filter_0 = vmovl_s8(filter_s8_0);
    int16x8_t filter_1 = vmovl_s8(filter_s8_1);
    int16x8_t filter_2 = vmovl_s8(filter_s8_2);
    int16x8_t filter_3 = vmovl_s8(filter_s8_3);
    // Handle one output pixel at a time.
    for (int outp = 0; outp < num_output_pixels; outp++) {
      int8 input_s8 = *input_ptr;
      input_ptr += input_ptr_increment;
      int16 input = static_cast<int16>(input_s8 + input_offset);
      // Load the accumulators from acc_buffer
      int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
      int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
      int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
      int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
      int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
      int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5);
      int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6);
      int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7);
      // Multiply-accumulate
      acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
      acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
      acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
      acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
      acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input);
      acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input);
      acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input);
      acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input);
      // Store the accumulators back to acc_buffer
      vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
      vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
      vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
      vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
      vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
      vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5);
      vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6);
      vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7);
      acc_buffer_ptr += 32;
    }
  }
};

template <>
struct QuantizedDepthwiseConvKernel<true, 1, 20> {
  static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
                  const int8* input_ptr, int16 input_offset,
                  int input_ptr_increment, const int8* filter_ptr,
                  int32* acc_buffer_ptr) {
    // Load the filters.
    // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8.
    // We load the first 16 bytes into filter_s8_{0,1} as usual.
    // Then we load the last 8 bytes into filter_s8_x (x for 'extra').
    // This is partly redundant: the first 4 bytes of filter_s8_x are the
    // same as the last 4 bytes of filter_s8_1.
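    // Illustrative byte coverage of the three loads: filter_s8_0 holds taps
    // [0..7], filter_s8_1 taps [8..15], and filter_s8_x taps [12..19], of
    // which only the high half (taps [16..19]) is used below, via
    // vget_high_s16(filter_x).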
1190     int8x8_t filter_s8_0 = vld1_s8(filter_ptr + 8 * 0);
1191     int8x8_t filter_s8_1 = vld1_s8(filter_ptr + 8 * 1);
1192     int8x8_t filter_s8_x = vld1_s8(filter_ptr + 8 * 1 + 4);
1193     int16x8_t filter_0 = vmovl_s8(filter_s8_0);
1194     int16x8_t filter_1 = vmovl_s8(filter_s8_1);
1195     int16x8_t filter_x = vmovl_s8(filter_s8_x);
1196     // Handle one output pixel at a time.
1197     for (int outp = 0; outp < num_output_pixels; outp++) {
1198       int8 input_s8 = *input_ptr;
1199       input_ptr += input_ptr_increment;
1200       int16 input = static_cast<int16>(input_s8 + input_offset);
1201       // Load the accumulators from acc_buffer
1202       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1203       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1204       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1205       int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3);
1206       int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4);
1207       // Multiply-accumulate
1208       acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input);
1209       acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input);
1210       acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input);
1211       acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input);
1212       acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input);
1213       // Store the accumulators back to acc_buffer
1214       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1215       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1216       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1217       vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3);
1218       vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4);
1219       acc_buffer_ptr += 20;
1220     }
1221   }
1222 };
1223 
1224 template <>
1225 struct QuantizedDepthwiseConvKernel<true, 1, 8> {
1226   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1227                   const int8* input_ptr, int16 input_offset,
1228                   int input_ptr_increment, const int8* filter_ptr,
1229                   int32* acc_buffer_ptr) {
1230     // Load the filters.
1231     const int8x8_t filter_s8 = vld1_s8(filter_ptr);
1232     const int16x8_t filter = vmovl_s8(filter_s8);
1233     // Handle one output pixel at a time.
1234     for (int outp = 0; outp < num_output_pixels; outp++) {
1235       int8 input_s8 = *input_ptr;
1236       input_ptr += input_ptr_increment;
1237       int16 input = static_cast<int16>(input_s8 + input_offset);
1238       // Load the accumulators from acc_buffer
1239       int32x4_t acc[2];
1240       for (int i = 0; i < 2; i++) {
1241         acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i);
1242       }
1243       // Multiply-accumulate
1244       acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input);
1245       acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input);
1246       // Store the accumulators back to acc_buffer
1247       for (int i = 0; i < 2; i++) {
1248         vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]);
1249       }
1250       acc_buffer_ptr += 8;
1251     }
1252   }
1253 };
1254 
1255 template <>
1256 struct QuantizedDepthwiseConvKernel<true, 2, 1> {
1257   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1258                   const int8* input_ptr, int16 input_offset,
1259                   int input_ptr_increment, const int8* filter_ptr,
1260                   int32* acc_buffer_ptr) {
1261     // Load the filters.
1262     int8x8_t filter_s8 = vdup_n_s8(0);
1263     filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
1264     filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
1265     filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 2);
1266     filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 3);
1267     const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));
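    // The two filter values are duplicated as [f0, f1, f0, f1] so that a
    // single vmlal_s16 in the main loop below can update the accumulators of
    // two output pixels (two channels each) at once.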
1268 
1269     int outp = 0;
1270 
1271     // Handle 2 output pixels at a time.
1272     for (; outp <= num_output_pixels - 2; outp += 2) {
1273       // Load the accumulators from acc_buffer.
1274       int32x4_t acc = vld1q_s32(acc_buffer_ptr);
1275       // Load the inputs, add input_offset.
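      // Each pixel's pair of int8 channel values is fetched as one 16-bit
      // lane load, then reinterpreted back to int8 and sign-extended below;
      // this assumes the two adjacent bytes may be read as a single int16.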
1276       int16x4_t input_s16 = vdup_n_s16(0);
1277       input_s16 = vset_lane_s16((reinterpret_cast<const int16*>(input_ptr))[0],
1278                                 input_s16, 0);
1279       input_ptr += input_ptr_increment;
1280       input_s16 = vset_lane_s16((reinterpret_cast<const int16*>(input_ptr))[0],
1281                                 input_s16, 1);
1282       input_ptr += input_ptr_increment;
1283       input_s16 = vget_low_s16(vmovl_s8(vreinterpret_s8_s16(input_s16)));
1284       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1285 
1286       // Multiply-accumulate.
1287       acc = vmlal_s16(acc, filter, input);
1288       // Store the accumulators back to acc_buffer.
1289       vst1q_s32(acc_buffer_ptr, acc);
1290       acc_buffer_ptr += 4;
1291     }
1292 
1293     // Handle 1 output pixel at a time.
1294     for (; outp < num_output_pixels; outp++) {
1295       // Load the accumulators from acc_buffer.
1296       int32x2_t acc = vld1_s32(acc_buffer_ptr);
1297       // Load the inputs, add input_offset.
1298       int8x8_t input_s8 = vdup_n_s8(0);
1299       input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
1300       input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
1301       input_ptr += input_ptr_increment;
1302       const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
1303       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1304 
1305       // Multiply-accumulate.
1306       acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input));
1307       // Store the accumulators back to acc_buffer.
1308       vst1_s32(acc_buffer_ptr, acc);
1309       acc_buffer_ptr += 2;
1310     }
1311   }
1312 };
1313 
1314 template <>
1315 struct QuantizedDepthwiseConvKernel<true, 4, 1> {
1316   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1317                   const int8* input_ptr, int16 input_offset,
1318                   int input_ptr_increment, const int8* filter_ptr,
1319                   int32* acc_buffer_ptr) {
1320     if (num_output_pixels <= 0) {
1321       return;
1322     }
1323 
1324     // Load the filters.
1325     int8x8_t filter_s8 = vdup_n_s8(0);
1326     filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0);
1327     filter_s8 = vset_lane_s8(filter_ptr[1], filter_s8, 1);
1328     filter_s8 = vset_lane_s8(filter_ptr[2], filter_s8, 2);
1329     filter_s8 = vset_lane_s8(filter_ptr[3], filter_s8, 3);
1330     const int16x4_t filter = vget_low_s16(vmovl_s8(filter_s8));
1331 
1332     int outp = 0;
1333 
1334     // Handle one output pixel at a time until the second-to-last pixel:
1335     // vld1_s8 reads eight input values while we only consume four, so on
1336     // the last pixel that full load could read past the end of the row.
1337     for (; outp < num_output_pixels - 1; outp++) {
1338       // Load the accumulators from acc_buffer
1339       int32x4_t acc;
1340       acc = vld1q_s32(acc_buffer_ptr);
1341 
1342       // Load the inputs, add input_offset.
1343       int8x8_t input_s8 = vld1_s8(input_ptr);
1344       input_ptr += input_ptr_increment;
1345       const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
1346       const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1347       // Multiply-accumulate
1348       acc = vmlal_s16(acc, filter, input);
1349       // Store the accumulators back to acc_buffer
1350       vst1q_s32(acc_buffer_ptr, acc);
1351       acc_buffer_ptr += 4;
1352     }
1353 
1354     // Handle the last output pixel.
1355     // Load the accumulators from acc_buffer
1356     int32x4_t acc;
1357     acc = vld1q_s32(acc_buffer_ptr);
1358 
1359     // Load the inputs, add input_offset.
1360     int8x8_t input_s8 = vdup_n_s8(0);
1361     input_s8 = vset_lane_s8(input_ptr[0], input_s8, 0);
1362     input_s8 = vset_lane_s8(input_ptr[1], input_s8, 1);
1363     input_s8 = vset_lane_s8(input_ptr[2], input_s8, 2);
1364     input_s8 = vset_lane_s8(input_ptr[3], input_s8, 3);
1365     const int16x4_t input_s16 = vget_low_s16(vmovl_s8(input_s8));
1366     const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset));
1367     // Multiply-accumulate
1368     acc = vmlal_s16(acc, filter, input);
1369     // Store the accumulators back to acc_buffer
1370     vst1q_s32(acc_buffer_ptr, acc);
1371   }
1372 };
1373 
1374 template <>
1375 struct QuantizedDepthwiseConvKernel<false, 12, 1> {
1376   static void Run(int num_output_pixels, int input_depth, int depth_multiplier,
1377                   const int8* input_ptr, int16 input_offset,
1378                   int input_ptr_increment, const int8* filter_ptr,
1379                   int32* acc_buffer_ptr) {
1380     // Load the filters.
1381     int8x8_t filter_s8_0 = vld1_s8(filter_ptr);
1382     int8x8_t filter_s8_1 = vld1_s8(filter_ptr + 4);
1383     int16x8_t filter_s16_0 = vmovl_s8(filter_s8_0);
1384     int16x8_t filter_s16_1 = vmovl_s8(filter_s8_1);
1385     int16x4_t filter_0 = vget_low_s16(filter_s16_0);
1386     int16x4_t filter_1 = vget_high_s16(filter_s16_0);
1387     int16x4_t filter_2 = vget_high_s16(filter_s16_1);
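    // Note that the two 8-byte loads above overlap (bytes 0..7 and 4..11), so
    // all 12 filter values are read without touching byte 12 or beyond:
    // filter_0 holds channels 0..3, filter_1 channels 4..7, and filter_2 (the
    // high half of the second load) channels 8..11. The input loads in the
    // loop below use the same overlapping trick.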
1388 
1389     // Handle one output pixel at a time.
1390     for (int outp = 0; outp < num_output_pixels; outp++) {
1391       // Load the inputs, add input_offset.
1392       int8x8_t input_s8_0 = vld1_s8(input_ptr);
1393       int8x8_t input_s8_1 = vld1_s8(input_ptr + 4);
1394       input_ptr += input_ptr_increment;
1395       int16x8_t input_0 = vmovl_s8(input_s8_0);
1396       int16x8_t input_1 = vmovl_s8(input_s8_1);
1397       input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset));
1398       input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset));
1399 
1400       // Load the accumulators from acc_buffer
1401       int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0);
1402       int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1);
1403       int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2);
1404 
1405       // Multiply-accumulate
1406       acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0);
1407       acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1);
1408       acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2);
1409 
1410       // Store the accumulators back to acc_buffer
1411       vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0);
1412       vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1);
1413       vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2);
1414 
1415       acc_buffer_ptr += 12;
1416     }
1417   }
1418 };
1419 #endif
1420 
1421 // Accumulates the effect of one row of the filter on a segment of one row
1422 // of the output, reading the corresponding row of the input.
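// For a given filter_x, the valid out_x values are those for which
//   in_x = out_x * stride - pad_width + dilation_factor * filter_x
// satisfies 0 <= in_x < input_width. The *_unclamped bounds computed below
// are that [start, end) range, rounded up to whole strides, before clamping
// to the [out_x_buffer_start, out_x_buffer_end) window.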
1423 template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
1424 void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
1425                                     int input_depth, int input_width,
1426                                     const int8* input_data, int16 input_offset,
1427                                     int pad_width, int depth_multiplier,
1428                                     int filter_width, const int8* filter_data,
1429                                     int out_x_buffer_start,
1430                                     int out_x_buffer_end, int output_depth,
1431                                     int32* acc_buffer) {
1432   ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION);
1433   // Consistency-check the parameters. This is important in particular to
1434   // ensure that we keep the number of template instantiations minimal, so we
1435   // don't increase binary size unnecessarily.
1436   static_assert(kFixedDepthMultiplier || !kFixedInputDepth, "");
1437   static_assert(kFixedInputDepth || kAllowStrided, "");
1438   TFLITE_DCHECK(stride == 1 || kAllowStrided);
1439   if (kFixedInputDepth) {
1440     TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth);
1441   }
1442   if (kFixedDepthMultiplier) {
1443     TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier);
1444   }
1445   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
1446   const int input_ptr_increment = stride * input_depth;
1447   const int8* filter_base_ptr = filter_data;
1448   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1449     // For the current (filter_x, filter_y) point in the filter,
1450     // compute the boundaries of the corresponding output row segment.
1451     int out_x_loop_start_unclamped = 0;
1452     int out_x_loop_end_unclamped = 0;
1453     if (kAllowStrided) {
1454       if (stride == 2) {
1455         out_x_loop_start_unclamped =
1456             (pad_width - dilation_factor * filter_x + 1) / 2;
1457         out_x_loop_end_unclamped =
1458             (pad_width + input_width - dilation_factor * filter_x + 1) / 2;
1459       } else if (stride == 4) {
1460         out_x_loop_start_unclamped =
1461             (pad_width - dilation_factor * filter_x + 3) / 4;
1462         out_x_loop_end_unclamped =
1463             (pad_width + input_width - dilation_factor * filter_x + 3) / 4;
1464       } else {
1465         out_x_loop_start_unclamped =
1466             (pad_width - dilation_factor * filter_x + stride - 1) / stride;
1467         out_x_loop_end_unclamped = (pad_width + input_width -
1468                                     dilation_factor * filter_x + stride - 1) /
1469                                    stride;
1470       }
1471     } else {
1472       out_x_loop_start_unclamped = pad_width - dilation_factor * filter_x;
1473       out_x_loop_end_unclamped =
1474           pad_width + input_width - dilation_factor * filter_x;
1475     }
1476     // The kernel will have to iterate on the segment of the output row
1477     // that starts at out_x_loop_start and ends at out_x_loop_end.
1478     const int out_x_loop_start =
1479         std::max(out_x_buffer_start, out_x_loop_start_unclamped);
1480     const int out_x_loop_end =
1481         std::min(out_x_buffer_end, out_x_loop_end_unclamped);
1482 
1483     int32* acc_buffer_ptr =
1484         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1485     const int in_x_origin =
1486         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1487     const int8* input_ptr = input_data + in_x_origin * input_depth;
1488     const int num_output_pixels = out_x_loop_end - out_x_loop_start;
1489     QuantizedDepthwiseConvKernel<
1490         kAllowStrided, kFixedInputDepth,
1491         kFixedDepthMultiplier>::Run(num_output_pixels, input_depth,
1492                                     depth_multiplier, input_ptr, input_offset,
1493                                     input_ptr_increment, filter_base_ptr,
1494                                     acc_buffer_ptr);
1495     filter_base_ptr += output_depth;
1496   }
1497 }
1498 
1499 // Generic fallback of QuantizedDepthwiseConvAccumRow: portable, non-templatized.
1500 inline void QuantizedDepthwiseConvAccumRowGeneric(
1501     int stride, int dilation_factor, int input_depth, int input_width,
1502     const int8* input_data, int16 input_offset, int pad_width,
1503     int depth_multiplier, int filter_width, const int8* filter_data,
1504     int out_x_buffer_start, int out_x_buffer_end, int output_depth,
1505     int32* acc_buffer) {
1506   ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)");
1507   const int8* filter_base_ptr = filter_data;
1508   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
1509     const int out_x_loop_start = std::max(
1510         out_x_buffer_start,
1511         (pad_width - dilation_factor * filter_x + stride - 1) / stride);
1512     const int out_x_loop_end = std::min(
1513         out_x_buffer_end,
1514         (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
1515             stride);
1516 
1517     int32* acc_buffer_ptr =
1518         acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
1519     const int in_x_origin =
1520         (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
1521     const int8* input_ptr = input_data + in_x_origin * input_depth;
1522     const int input_ptr_increment = (stride - 1) * input_depth;
1523     for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
1524       const int8* filter_ptr = filter_base_ptr;
1525       for (int ic = 0; ic < input_depth; ++ic) {
1526         const int16 input_val = *input_ptr++ + input_offset;
1527         for (int m = 0; m < depth_multiplier; m++) {
1528           const int16 filter_val = *filter_ptr++;
1529           *acc_buffer_ptr++ += static_cast<int32>(filter_val) * input_val;
1530         }
1531       }
1532       input_ptr += input_ptr_increment;
1533     }
1534     filter_base_ptr += output_depth;
1535   }
1536 }
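// The scalar triple loop above spells out the semantics that every NEON
// kernel in this file must reproduce: for each output pixel,
//   acc[ic * depth_multiplier + m] +=
//       filter[ic * depth_multiplier + m] * (input[ic] + input_offset).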
1537 
1538 // Initializes the accumulator buffer with bias values.
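// Conceptually this computes, for each output pixel p and channel c,
//   acc_buffer[p * output_depth + c] = bias_data[c];
// the NEON fast paths below just replicate the bias vector with wide stores
// for common small depths, and the trailing memcpy loop handles the rest.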
1539 inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
1540                                        const int32* bias_data,
1541                                        int32* acc_buffer) {
1542   int i = 0;
1543 #ifdef USE_NEON
1544   if (output_depth == 1) {
1545     const int32x4_t b = vdupq_n_s32(bias_data[0]);
1546     for (; i <= num_output_pixels - 16; i += 16) {
1547       vst1q_s32(acc_buffer + i + 0, b);
1548       vst1q_s32(acc_buffer + i + 4, b);
1549       vst1q_s32(acc_buffer + i + 8, b);
1550       vst1q_s32(acc_buffer + i + 12, b);
1551     }
1552     for (; i <= num_output_pixels - 4; i += 4) {
1553       vst1q_s32(acc_buffer + i, b);
1554     }
1555   } else if (output_depth == 2) {
1556     int32x4_t b = vdupq_n_s32(bias_data[0]);
1557     b = vsetq_lane_s32(bias_data[1], b, 1);
1558     b = vsetq_lane_s32(bias_data[1], b, 3);
1559     for (; i <= num_output_pixels - 8; i += 8) {
1560       vst1q_s32(acc_buffer + 2 * i + 0, b);
1561       vst1q_s32(acc_buffer + 2 * i + 4, b);
1562       vst1q_s32(acc_buffer + 2 * i + 8, b);
1563       vst1q_s32(acc_buffer + 2 * i + 12, b);
1564     }
1565     for (; i <= num_output_pixels - 2; i += 2) {
1566       vst1q_s32(acc_buffer + 2 * i, b);
1567     }
1568   } else if (output_depth == 4) {
1569     const int32x4_t b = vld1q_s32(bias_data);
1570     for (; i <= num_output_pixels - 4; i += 4) {
1571       vst1q_s32(acc_buffer + 4 * i + 0, b);
1572       vst1q_s32(acc_buffer + 4 * i + 4, b);
1573       vst1q_s32(acc_buffer + 4 * i + 8, b);
1574       vst1q_s32(acc_buffer + 4 * i + 12, b);
1575     }
1576     for (; i < num_output_pixels; i++) {
1577       vst1q_s32(acc_buffer + 4 * i, b);
1578     }
1579   } else if (output_depth == 8) {
1580     const int32x4_t b0 = vld1q_s32(bias_data);
1581     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1582     for (; i <= num_output_pixels - 2; i += 2) {
1583       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1584       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1585       vst1q_s32(acc_buffer + 8 * i + 8, b0);
1586       vst1q_s32(acc_buffer + 8 * i + 12, b1);
1587     }
1588     for (; i < num_output_pixels; i++) {
1589       vst1q_s32(acc_buffer + 8 * i + 0, b0);
1590       vst1q_s32(acc_buffer + 8 * i + 4, b1);
1591     }
1592   } else if (output_depth == 16) {
1593     const int32x4_t b0 = vld1q_s32(bias_data);
1594     const int32x4_t b1 = vld1q_s32(bias_data + 4);
1595     const int32x4_t b2 = vld1q_s32(bias_data + 8);
1596     const int32x4_t b3 = vld1q_s32(bias_data + 12);
1597     for (; i < num_output_pixels; i++) {
1598       vst1q_s32(acc_buffer + 16 * i + 0, b0);
1599       vst1q_s32(acc_buffer + 16 * i + 4, b1);
1600       vst1q_s32(acc_buffer + 16 * i + 8, b2);
1601       vst1q_s32(acc_buffer + 16 * i + 12, b3);
1602     }
1603   }
1604 #endif
1605   for (; i < num_output_pixels; i++) {
1606     memcpy(acc_buffer + i * output_depth, bias_data,
1607            sizeof(acc_buffer[0]) * output_depth);
1608   }
1609 }
1610 
1611 inline void DepthwiseConvGeneral(
1612     const DepthwiseParams& params, const int32* output_multiplier,
1613     const int32* output_shift, const RuntimeShape& input_shape,
1614     const int8* input_data, const RuntimeShape& filter_shape,
1615     const int8* filter_data, const RuntimeShape& bias_shape,
1616     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
1617     int thread_start, int thread_end, int thread_dim) {
1618   const int stride_width = params.stride_width;
1619   const int stride_height = params.stride_height;
1620   const int pad_width = params.padding_values.width;
1621   const int pad_height = params.padding_values.height;
1622   const int depth_multiplier = params.depth_multiplier;
1623   const int32 output_activation_min = params.quantized_activation_min;
1624   const int32 output_activation_max = params.quantized_activation_max;
1625   const int32 input_offset = params.input_offset;
1626   const int32 output_offset = params.output_offset;
1627   const int dilation_width_factor = params.dilation_width_factor;
1628   const int dilation_height_factor = params.dilation_height_factor;
1629   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
1630   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
1631   const int input_height = input_shape.Dims(1);
1632   const int input_width = input_shape.Dims(2);
1633   const int input_depth = input_shape.Dims(3);
1634   const int filter_height = filter_shape.Dims(1);
1635   const int filter_width = filter_shape.Dims(2);
1636   const int output_rows = output_shape.Dims(1);
1637   const int output_width = output_shape.Dims(2);
1638 
1639   static const int kAccBufferMaxSize = 2048;
1640   int acc_buffer_size = kAccBufferMaxSize;
1641   int32 stack_acc_buffer[kAccBufferMaxSize];
1642   int32* acc_buffer = stack_acc_buffer;
1643 #ifndef TF_LITE_STATIC_MEMORY
1644   std::unique_ptr<int32[]> heap_acc_buffer;
1645   if (kAccBufferMaxSize < output_depth) {
1646     heap_acc_buffer.reset(new int32[output_depth]);
1647     acc_buffer = heap_acc_buffer.get();
1648     acc_buffer_size = output_depth;
1649   }
1650 #endif
1651   TFLITE_DCHECK_GE(acc_buffer_size, output_depth);
1652   const int kOutputPixelsInAccBuffer = acc_buffer_size / output_depth;
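  // E.g. (hypothetical depth) with output_depth == 24, the 2048-element
  // buffer holds 2048 / 24 = 85 output pixels per accumulation pass.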
1653   const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth;
1654   TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth,
1655                    kAccBufferActualSize);
1656   TFLITE_DCHECK_LE(kAccBufferActualSize, acc_buffer_size);
1657   TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1);
1658   TFLITE_DCHECK(thread_dim == 0 || thread_dim == 1);
1659 
1660   // row_accum_func will point to the core accumulation function to be used
1661   // for this DepthwiseConv op.
1662   using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric);
1663   row_accum_func_t row_accum_func = nullptr;
1664 
1665 #define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
1666                                         FIXED_DEPTH_MULTIPLIER)           \
1667   if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) &&          \
1668       (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) &&     \
1669       depth_multiplier == FIXED_DEPTH_MULTIPLIER) {                       \
1670     row_accum_func =                                                      \
1671         QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH,  \
1672                                        FIXED_DEPTH_MULTIPLIER>;           \
1673   }
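// For example, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) selects
// QuantizedDepthwiseConvAccumRow<true, 8, 2> when no kernel has been picked
// yet, input_depth == 8 and depth_multiplier == 2, for any stride (since
// ALLOW_STRIDED is true); a FIXED_INPUT_DEPTH of 0 matches any input depth.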
1674 
1675 #ifdef USE_NEON
1676   // We go over our list of kernels in decreasing order of preference
1677   // for the cases where multiple kernels could apply.
1678 
1679   // Start with the fastest kernels: AllowStrided=false, fixed input depth.
1680 
1681   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2)
1682   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2)
1683   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2)
1684   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4)
1685   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1)
1686   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4)
1687   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1)
1688   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8)
1689   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1)
1690   TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1)
1691 
1692   // Next come the strided kernels: AllowStrided=true, fixed input depth.
1693   // They are a bit less efficient, but allow stride!=1.
1694 
1695   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2)
1696   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1)
1697   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16)
1698   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20)
1699   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32)
1700   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8)
1701   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1)
1702   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1)
1703   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1)
1704 
1705   // Finally, the kernels that allow a variable input depth;
1706   // these are the least efficient but most general kernels.
1707 
1708   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1)
1709   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2)
1710   TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3)
1711 #endif  // USE_NEON
1712 
1713   // No matching fast kernel found; fall back to the slow generic one.
1714   if (!row_accum_func) {
1715     row_accum_func = QuantizedDepthwiseConvAccumRowGeneric;
1716   }
1717 
1718 #undef TFMINI_USE_DEPTHWISECONV_KERNEL
1719 
1720   const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
1721   const int input_batch_stride = input_height_stride * input_shape.Dims(1);
1722   const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
1723 
1724   // Now that we have determined row_accum_func, we can start work.
1725   int batch_start = 0;
1726   int batch_end = batches;
1727   int row_start = 0;
1728   int row_end = output_rows;
1729   int output_ptr_offset = 0;
1730 
1731   switch (thread_dim) {
1732     case 0:
1733       TFLITE_DCHECK_GE(thread_start, 0);
1734       TFLITE_DCHECK_LE(thread_end, batches);
1735       batch_start = thread_start;
1736       batch_end = thread_end;
1737       output_ptr_offset = batch_start * FlatSizeSkipDim(output_shape, 0);
1738       break;
1739     case 1:
1740       TFLITE_DCHECK_GE(thread_start, 0);
1741       TFLITE_DCHECK_LE(thread_end, output_rows);
1742       row_start = thread_start;
1743       row_end = thread_end;
1744       output_ptr_offset = row_start * output_width * output_depth;
1745       break;
1746   }
1747 
1748   int8* output_ptr = output_data + output_ptr_offset;
1749   int batch_step =
1750       (output_rows + row_start - row_end) * output_width * output_depth;
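  // Per batch, this thread writes (row_end - row_start) rows, so batch_step
  // skips the remaining (output_rows - (row_end - row_start)) rows' worth of
  // output, which belong to other threads, before the next batch begins.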
1751   for (int b = batch_start; b < batch_end; ++b) {
1752     for (int out_y = row_start; out_y < row_end; ++out_y) {
1753       const int in_y_origin = (out_y * stride_height) - pad_height;
1754       const int filter_y_start =
1755           std::max(0, (-in_y_origin + dilation_height_factor - 1) /
1756                           dilation_height_factor);
1757       const int filter_y_end =
1758           std::min(filter_height,
1759                    (input_height - in_y_origin + dilation_height_factor - 1) /
1760                        dilation_height_factor);
1761       for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
1762            out_x_buffer_start += kOutputPixelsInAccBuffer) {
1763         const int out_x_buffer_end = std::min(
1764             output_width, out_x_buffer_start + kOutputPixelsInAccBuffer);
1765         // We call a 'pixel' a group of activations that share all but the
1766         // 'depth'/'channel' coordinate. num_output_pixels is the number of
1767         // output pixels that we will accumulate in this loop iteration.
1768         const int num_output_pixels = out_x_buffer_end - out_x_buffer_start;
1769         // Initialize our local accumulator with the bias values, so we don't
1770         // have to add them later.
1771         DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data,
1772                                    acc_buffer);
1773         // Accumulation loop. Most of the time should be spent in here.
1774         for (int filter_y = filter_y_start; filter_y < filter_y_end;
1775              ++filter_y) {
1776           const int in_y = in_y_origin + dilation_height_factor * filter_y;
1777           row_accum_func(
1778               stride_width, dilation_width_factor, input_depth, input_width,
1779               input_data + in_y * input_height_stride + b * input_batch_stride,
1780               input_offset, pad_width, depth_multiplier, filter_width,
1781               filter_data + filter_y * filter_height_stride, out_x_buffer_start,
1782               out_x_buffer_end, output_depth, acc_buffer);
1783         }
1784         // Finished accumulating int32 values. Now we need to convert them
1785         // to the final 8-bit form and store them.
1786         ruy::profiler::ScopeLabel label("downquantize+store");
1787         const int num_output_values = output_depth * num_output_pixels;
1788 
1789         optimized_ops::Quantize(output_multiplier, output_shift, output_depth,
1790                                 num_output_values, output_offset,
1791                                 output_activation_min, output_activation_max,
1792                                 acc_buffer, output_ptr);
1793 
1794         output_ptr += num_output_values;
1795       }
1796     }
1797     output_ptr += batch_step;
1798   }
1799 }
1800 
1801 }  // namespace depthwise_conv
1802 
1803 template <DepthwiseConvOutputRounding kOutputRounding>
1804 inline void DepthwiseConvWithRounding(
1805     const DepthwiseParams& params, const int32* output_multiplier,
1806     const int32* output_shift, const RuntimeShape& input_shape,
1807     const int8* input_data, const RuntimeShape& filter_shape,
1808     const int8* filter_data, const RuntimeShape& bias_shape,
1809     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
1810     int thread_start, int thread_end, int thread_dim,
1811     const CpuBackendContext& cpu_backend_context) {
1812   ruy::profiler::ScopeLabel label("DepthwiseConvInt8/8bit");
1813   const int depth_multiplier = params.depth_multiplier;
1814   const int dilation_width_factor = params.dilation_width_factor;
1815   const int dilation_height_factor = params.dilation_height_factor;
1816   TFLITE_DCHECK_GE(dilation_width_factor, 1);
1817   TFLITE_DCHECK_GE(dilation_height_factor, 1);
1818   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1819   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1820   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1821   const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
1822   const int input_depth = input_shape.Dims(3);
1823   TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
1824   TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
1825 
1826 // Enable for arm64, except under Nvidia Linux for Tegra (L4T) as used on the
1827 // Jetson TX-2, whose compiler does not support the offsetof() macro.
1828 #if defined(__aarch64__) && !defined(GOOGLE_L4T)
1829 #if defined(__ANDROID__) && defined(__clang__)
1830   CpuFlags cpu_flags;
1831   GetCpuFlags(&cpu_flags);
1832   const bool has_dot_product_instructions = cpu_flags.neon_dotprod;
1833 
1834   // Dispatch to dot-product 3x3 kernels when supported.
1835   if (has_dot_product_instructions) {
1836     using optimized_ops::depthwise_conv::DotProduct3x3KernelType;
1837     DotProduct3x3KernelType kernel_type =
1838         optimized_ops::depthwise_conv::CategorizeDotProductKernel<
1839             optimized_ops::depthwise_conv::QuantizationType::kPerChannelInt8>(
1840             input_shape, filter_shape, output_shape, params, output_shift);
1841     if (kernel_type != DotProduct3x3KernelType::kNone) {
1842       ruy::profiler::ScopeLabel specialized_label(
1843           "DepthwiseConvInt8/8bit/3x3XDotProduct");
1844       DepthwiseParams params_copy = params;
1845       params_copy.output_shift_per_channel = output_shift;
1846       params_copy.output_multiplier_per_channel = output_multiplier;
1847       optimized_ops::depthwise_conv::DepthwiseConvDotProduct3x3PerChannel<
1848           DepthwiseConvImplementation::kUseNeon3x3DotProduct>(
1849           params_copy, input_shape, input_data, filter_shape, filter_data,
1850           bias_shape, bias_data, output_shape, output_data, thread_start,
1851           thread_end, thread_dim);
1852       return;
1853     }
1854   }
1855 
1856 #endif
1857   // Dispatch to non-dot-product 3x3 kernels when supported.
1858 
1859   const int stride_width = params.stride_width;
1860   const int stride_height = params.stride_height;
1861   const int pad_width = params.padding_values.width;
1862   const int pad_height = params.padding_values.height;
1863 
1864   // Call the kernel optimized for depthwise convolutions with 3x3 filters
1865   // when the parameters allow it.
1866   if (optimized_ops::depthwise_conv::Fast3x3FilterKernelSupported<
1867           optimized_ops::depthwise_conv::QuantizationType::kPerChannelInt8>(
1868           input_shape, filter_shape, stride_width, stride_height,
1869           dilation_width_factor, dilation_height_factor, pad_width, pad_height,
1870           depth_multiplier, output_shape, 0, output_shift)) {
1871     ruy::profiler::ScopeLabel specialized_label("DepthwiseConvInt8/8bit/3x3");
1872     optimized_ops::depthwise_conv::DepthwiseConv3x3FilterPerChannel<
1873         DepthwiseConvOutputRounding::kUpward>(
1874         params, output_multiplier, output_shift, input_shape, input_data,
1875         filter_shape, filter_data, bias_shape, bias_data, output_shape,
1876         output_data, thread_start, thread_end, thread_dim);
1877     return;
1878   }
1879 #endif
1880 
1881   ruy::profiler::ScopeLabel specialized_label("DepthwiseConvInt8/8bit/General");
1882   depthwise_conv::DepthwiseConvGeneral(
1883       params, output_multiplier, output_shift, input_shape, input_data,
1884       filter_shape, filter_data, bias_shape, bias_data, output_shape,
1885       output_data, thread_start, thread_end, thread_dim);
1886 }
1887 
1888 inline void DepthwiseConvImpl(
1889     const DepthwiseParams& params, const int32* output_multiplier,
1890     const int32* output_shift, const RuntimeShape& input_shape,
1891     const int8* input_data, const RuntimeShape& filter_shape,
1892     const int8* filter_data, const RuntimeShape& bias_shape,
1893     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
1894     int thread_start, int thread_end, int thread_dim,
1895     const CpuBackendContext& cpu_backend_context) {
1896   return DepthwiseConvWithRounding<DepthwiseConvOutputRounding::kAwayFromZero>(
1897       params, output_multiplier, output_shift, input_shape, input_data,
1898       filter_shape, filter_data, bias_shape, bias_data, output_shape,
1899       output_data, thread_start, thread_end, thread_dim, cpu_backend_context);
1900 }
1901 
1902 template <typename T, typename TS>
1903 struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task {
1904   DepthwiseConvWorkerTask(const DepthwiseParams& params,
1905                           const int32* output_multiplier,
1906                           const int32* output_shift,
1907                           const RuntimeShape& input_shape, const T* input_data,
1908                           const RuntimeShape& filter_shape,
1909                           const T* filter_data, const RuntimeShape& bias_shape,
1910                           const TS* bias_data, const RuntimeShape& output_shape,
1911                           T* output_data, int thread_start, int thread_end,
1912                           int thread_dim,
1913                           const CpuBackendContext& cpu_backend_context_x)
1914       : params_(params),
1915         output_multiplier_(output_multiplier),
1916         output_shift_(output_shift),
1917         input_shape_(input_shape),
1918         input_data_(input_data),
1919         filter_shape_(filter_shape),
1920         filter_data_(filter_data),
1921         bias_shape_(bias_shape),
1922         bias_data_(bias_data),
1923         output_shape_(output_shape),
1924         output_data_(output_data),
1925         thread_start_(thread_start),
1926         thread_end_(thread_end),
1927         thread_dim_(thread_dim),
1928         cpu_backend_context(cpu_backend_context_x) {}
1929 
1930   void Run() override {
1931     DepthwiseConvImpl(params_, output_multiplier_, output_shift_, input_shape_,
1932                       input_data_, filter_shape_, filter_data_, bias_shape_,
1933                       bias_data_, output_shape_, output_data_, thread_start_,
1934                       thread_end_, thread_dim_, cpu_backend_context);
1935   }
1936 
1937  private:
1938   const DepthwiseParams& params_;
1939   const int32* output_multiplier_;
1940   const int32* output_shift_;
1941   const RuntimeShape& input_shape_;
1942   const T* input_data_;
1943   const RuntimeShape& filter_shape_;
1944   const T* filter_data_;
1945   const RuntimeShape& bias_shape_;
1946   const TS* bias_data_;
1947   const RuntimeShape& output_shape_;
1948   T* output_data_;
1949   int thread_start_;
1950   int thread_end_;
1951   int thread_dim_;
1952   const CpuBackendContext& cpu_backend_context;
1953 };
1954 
1955 inline int HowManyConvThreads(const RuntimeShape& output_shape,
1956                               const RuntimeShape& filter_shape,
1957                               int thread_dim) {
1958   constexpr int kMinMulPerThread = 8;
1959   const int output_units = output_shape.Dims(thread_dim);
1960   const int filter_height = filter_shape.Dims(1);
1961   const int filter_width = filter_shape.Dims(2);
1962   const int num_mul_per_unit =
1963       FlatSizeSkipDim(output_shape, thread_dim) * filter_height * filter_width;
1964   const int min_units_per_thread = kMinMulPerThread / num_mul_per_unit + 1;
1965   int thread_count = output_units / min_units_per_thread;
1966   return thread_count;
1967 }
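// Worked example (hypothetical shapes): for a 3x3 filter and an output of
// shape [1, 8, 8, 4] threaded along rows (thread_dim == 1),
// num_mul_per_unit = (1 * 8 * 4) * 3 * 3 = 288 >= kMinMulPerThread, so
// min_units_per_thread == 1 and up to 8 threads (one per output row) are
// proposed; the caller still clamps this against max_num_threads().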
1968 
1969 inline void DepthwiseConvPerChannel(
1970     const DepthwiseParams& params, const int32* output_multiplier,
1971     const int32* output_shift, const RuntimeShape& input_shape,
1972     const int8* input_data, const RuntimeShape& filter_shape,
1973     const int8* filter_data, const RuntimeShape& bias_shape,
1974     const int32* bias_data, const RuntimeShape& output_shape, int8* output_data,
1975     CpuBackendContext* cpu_backend_context) {
1976   ruy::profiler::ScopeLabel label("DepthwiseConvInt8");
1977   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
1978   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
1979   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
1980 
1981   const int output_batches = output_shape.Dims(0);
1982   const int output_rows = output_shape.Dims(1);
1983   int thread_count_batch = HowManyConvThreads(output_shape, filter_shape, 0);
1984   int thread_count_row = HowManyConvThreads(output_shape, filter_shape, 1);
1985   int thread_dim, thread_count, thread_dim_size;
1986   if (thread_count_batch > thread_count_row) {
1987     thread_dim = 0;
1988     thread_dim_size = output_batches;
1989     thread_count = thread_count_batch;
1990   } else {
1991     thread_dim = 1;
1992     thread_dim_size = output_rows;
1993     thread_count = thread_count_row;
1994   }
1995 
1996   const int max_threads = cpu_backend_context->max_num_threads();
1997   thread_count = std::max(1, std::min(thread_count, max_threads));
1998 
1999   if (thread_count == 1) {
2000     DepthwiseConvImpl(params, output_multiplier, output_shift, input_shape,
2001                       input_data, filter_shape, filter_data, bias_shape,
2002                       bias_data, output_shape, output_data, /*thread_start=*/0,
2003                       /*thread_end=*/output_rows, /*thread_dim=*/1,
2004                       *cpu_backend_context);
2005   } else {
2006     std::vector<DepthwiseConvWorkerTask<int8, int32>> tasks;
2007     // TODO(b/131746020) don't create new heap allocations every time.
2008     // At least we make it a single heap allocation by using reserve().
2009     tasks.reserve(thread_count);
2010     int thread_start = 0;
2011     for (int i = 0; i < thread_count; ++i) {
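      // Split the remaining units evenly across the remaining threads, so
      // consecutive thread ranges differ in size by at most one unit.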
2012       int thread_end =
2013           thread_start + (thread_dim_size - thread_start) / (thread_count - i);
2014       tasks.emplace_back(params, output_multiplier, output_shift, input_shape,
2015                          input_data, filter_shape, filter_data, bias_shape,
2016                          bias_data, output_shape, output_data, thread_start,
2017                          thread_end, thread_dim, *cpu_backend_context);
2018       thread_start = thread_end;
2019     }
2020     cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
2021                                     cpu_backend_context);
2022   }
2023 }
2024 
2025 }  // namespace optimized_integer_ops
2026 }  // namespace tflite
2027 
2028 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_DEPTHWISE_CONV_H_
2029