// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-run.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pytorch_qnnpack.h>
2 #include <qnnpack_func.h>
3 #include <cstring>
4 
5 namespace qnnpack {
// Read-only per-call state for a quantized (Q8) GEMM dispatched through
// pthreadpool. Built once by qnnpackLinear() and shared by every
// compute_q8gemm() tile invocation.
struct q8gemm_context {
  size_t k;         // reduction dimension: input channels per group
  size_t k_stride;  // k rounded up to a multiple of the kernel's kr
  size_t n;         // output channels per group
  size_t n_stride;  // n rounded up to a multiple of the kernel's nr
  const uint8_t* a;        // input activation matrix
  size_t a_stride;         // row stride of a, in elements
  const uint8_t* packed_w; // packed weights interleaved with int32 biases
  uint8_t* c;              // output matrix
  size_t c_stride;         // row stride of c, in elements
  union pytorch_qnnp_conv_quantization_params quantization_params;
  const pytorch_q8gemm_ukernel_function ukernel; // micro-kernel to invoke per tile
};
19 
compute_q8gemm(const struct q8gemm_context context[1],size_t group_index,size_t pixel_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t pixel_range,size_t mr_block_size,size_t nr_block_size)20 static void compute_q8gemm(
21     const struct q8gemm_context context[1],
22     size_t group_index,
23     size_t pixel_index,
24     size_t mr_block_start,
25     size_t nr_block_start,
26     size_t group_range /* always 1 */,
27     size_t pixel_range,
28     size_t mr_block_size,
29     size_t nr_block_size)
30 {
31   const size_t k = context->k;
32   const size_t k_stride = context->k_stride;
33   const size_t n = context->n;
34   const size_t n_stride = context->n_stride;
35   const uint8_t* a = context->a;
36   const size_t a_stride = context->a_stride;
37   const void* packed_w = context->packed_w;
38   uint8_t* c = context->c;
39   const size_t c_stride = context->c_stride;
40 
41   size_t output_channel_index = nr_block_start;
42   context->ukernel(
43       mr_block_size,
44       nr_block_size,
45       k,
46       a + (pixel_index + mr_block_start) * a_stride + group_index * k,
47       a_stride,
48       (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
49       c + (pixel_index + mr_block_start) * c_stride + nr_block_start + group_index * n,
50       c_stride,
51       output_channel_index,
52       &context->quantization_params);
53 }
54 
qnnpackLinear(const size_t batch_size,const size_t input_channels,const size_t output_channels,const uint8_t input_zero_point,const uint8_t * kernel_zero_points,const float * requantization_scales,const uint8_t output_zero_point,const uint8_t output_min,const uint8_t output_max,const uint8_t * input,const size_t input_stride,void * packed_weights,uint8_t * output,const size_t output_stride,pthreadpool_t threadpool)55 enum pytorch_qnnp_status qnnpackLinear(
56     const size_t batch_size,
57     const size_t input_channels,
58     const size_t output_channels,
59     const uint8_t input_zero_point,
60     const uint8_t* kernel_zero_points,
61     const float* requantization_scales,
62     const uint8_t output_zero_point,
63     const uint8_t output_min,
64     const uint8_t output_max,
65     const uint8_t* input,
66     const size_t input_stride,
67     void* packed_weights,
68     uint8_t* output,
69     const size_t output_stride,
70     pthreadpool_t threadpool)
71 {
72   const size_t groups = 1;
73   const size_t group_input_channels = input_channels;
74   const size_t group_output_channels = output_channels;
75   const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
76   const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
77   const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
78   const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
79   const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
80 
81   const size_t output_size = batch_size * 1;
82   union pytorch_qnnp_conv_quantization_params conv_quantization_params =
83       pytorch_qnnp_compute_conv_quantization_params(
84           input_zero_point, kernel_zero_points,
85           requantization_scales, output_zero_point, output_min, output_max);
86 
87   struct q8gemm_context q8gemm_context = {
88       .k = group_input_channels,
89       .k_stride = k_stride,
90       .n = group_output_channels,
91       .n_stride = n_stride,
92       .a = input,
93       .a_stride = input_stride,
94       .packed_w = (uint8_t*) packed_weights,
95       .c = output,
96       .c_stride = output_stride,
97       .quantization_params = conv_quantization_params,
98       .ukernel = pytorch_qnnp_params.q8conv.gemm,
99   };
100 
101   if (output_size == 0) {
102       // pthreadpool can tolerate a range of 0, but not a tile of 0.
103       // We use output_size as a tile size, so bail here if it's 0.
104       return pytorch_qnnp_status_success;
105   }
106 
107   pthreadpool_compute_4d_tiled(
108       threadpool,
109       (pthreadpool_function_4d_tiled_t) compute_q8gemm,
110       &q8gemm_context,
111       groups,
112       1 * output_size,
113       output_size,
114       group_output_channels,
115       1,
116       output_size,
117       mr,
118       nr);
119 
120   return pytorch_qnnp_status_success;
121 }
122 } // namespace qnnpack
123