1 #include <pytorch_qnnpack.h>
2 #include <qnnpack_func.h>
3 #include <cstring>
4
5 namespace qnnpack {
// Read-only work descriptor shared by all tiles of one qnnpackLinear call;
// consumed by compute_q8gemm via pthreadpool.
struct q8gemm_context {
  size_t k;        // input channels per group (GEMM reduction dim); also the
                   // per-group column offset into `a`
  size_t k_stride; // k rounded up to a multiple of kr; row stride (in weight
                   // bytes) inside one packed-weight panel
  size_t n;        // output channels per group; per-group column offset into `c`
  size_t n_stride; // n rounded up to a multiple of nr; per-group row count in
                   // the packed-weight buffer
  const uint8_t* a;       // input matrix, quantized uint8
  size_t a_stride;        // row stride of `a`, in elements
  const uint8_t* packed_w; // packed weights: per output channel, an int32 bias
                           // followed by k_stride uint8 weights (layout implied
                           // by the offset math in compute_q8gemm)
  uint8_t* c;             // output matrix, quantized uint8
  size_t c_stride;        // row stride of `c`, in elements
  // Requantization parameters (zero points, scales, output clamp bounds)
  // forwarded verbatim to the ukernel.
  union pytorch_qnnp_conv_quantization_params quantization_params;
  // Microkernel that computes one mr x nr output tile.
  const pytorch_q8gemm_ukernel_function ukernel;
};
19
compute_q8gemm(const struct q8gemm_context context[1],size_t group_index,size_t pixel_index,size_t mr_block_start,size_t nr_block_start,size_t group_range,size_t pixel_range,size_t mr_block_size,size_t nr_block_size)20 static void compute_q8gemm(
21 const struct q8gemm_context context[1],
22 size_t group_index,
23 size_t pixel_index,
24 size_t mr_block_start,
25 size_t nr_block_start,
26 size_t group_range /* always 1 */,
27 size_t pixel_range,
28 size_t mr_block_size,
29 size_t nr_block_size)
30 {
31 const size_t k = context->k;
32 const size_t k_stride = context->k_stride;
33 const size_t n = context->n;
34 const size_t n_stride = context->n_stride;
35 const uint8_t* a = context->a;
36 const size_t a_stride = context->a_stride;
37 const void* packed_w = context->packed_w;
38 uint8_t* c = context->c;
39 const size_t c_stride = context->c_stride;
40
41 size_t output_channel_index = nr_block_start;
42 context->ukernel(
43 mr_block_size,
44 nr_block_size,
45 k,
46 a + (pixel_index + mr_block_start) * a_stride + group_index * k,
47 a_stride,
48 (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))),
49 c + (pixel_index + mr_block_start) * c_stride + nr_block_start + group_index * n,
50 c_stride,
51 output_channel_index,
52 &context->quantization_params);
53 }
54
qnnpackLinear(const size_t batch_size,const size_t input_channels,const size_t output_channels,const uint8_t input_zero_point,const uint8_t * kernel_zero_points,const float * requantization_scales,const uint8_t output_zero_point,const uint8_t output_min,const uint8_t output_max,const uint8_t * input,const size_t input_stride,void * packed_weights,uint8_t * output,const size_t output_stride,pthreadpool_t threadpool)55 enum pytorch_qnnp_status qnnpackLinear(
56 const size_t batch_size,
57 const size_t input_channels,
58 const size_t output_channels,
59 const uint8_t input_zero_point,
60 const uint8_t* kernel_zero_points,
61 const float* requantization_scales,
62 const uint8_t output_zero_point,
63 const uint8_t output_min,
64 const uint8_t output_max,
65 const uint8_t* input,
66 const size_t input_stride,
67 void* packed_weights,
68 uint8_t* output,
69 const size_t output_stride,
70 pthreadpool_t threadpool)
71 {
72 const size_t groups = 1;
73 const size_t group_input_channels = input_channels;
74 const size_t group_output_channels = output_channels;
75 const uint32_t mr = pytorch_qnnp_params.q8conv.mr;
76 const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
77 const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
78 const size_t k_stride = (group_input_channels + (kr - 1)) & -kr;
79 const size_t n_stride = (group_output_channels + (nr - 1)) & -nr;
80
81 const size_t output_size = batch_size * 1;
82 union pytorch_qnnp_conv_quantization_params conv_quantization_params =
83 pytorch_qnnp_compute_conv_quantization_params(
84 input_zero_point, kernel_zero_points,
85 requantization_scales, output_zero_point, output_min, output_max);
86
87 struct q8gemm_context q8gemm_context = {
88 .k = group_input_channels,
89 .k_stride = k_stride,
90 .n = group_output_channels,
91 .n_stride = n_stride,
92 .a = input,
93 .a_stride = input_stride,
94 .packed_w = (uint8_t*) packed_weights,
95 .c = output,
96 .c_stride = output_stride,
97 .quantization_params = conv_quantization_params,
98 .ukernel = pytorch_qnnp_params.q8conv.gemm,
99 };
100
101 if (output_size == 0) {
102 // pthreadpool can tolerate a range of 0, but not a tile of 0.
103 // We use output_size as a tile size, so bail here if it's 0.
104 return pytorch_qnnp_status_success;
105 }
106
107 pthreadpool_compute_4d_tiled(
108 threadpool,
109 (pthreadpool_function_4d_tiled_t) compute_q8gemm,
110 &q8gemm_context,
111 groups,
112 1 * output_size,
113 output_size,
114 group_output_channels,
115 1,
116 output_size,
117 mr,
118 nr);
119
120 return pytorch_qnnp_status_success;
121 }
122 } // namespace qnnpack
123