/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <assert.h>
#include <math.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <pytorch_qnnpack.h>
#include <qnnpack/log.h>
#include <qnnpack/math.h>
#include <qnnpack/operator.h>
#include <qnnpack/pack.h>
#include <qnnpack/params.h>
#include <qnnpack/requantization.h>

enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_nc_q8(
    size_t input_channels,
    size_t output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    pytorch_qnnp_operator_t* fully_connected_out) {
  pytorch_qnnp_operator_t fully_connected = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

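  /*
   * Validate the per-output-channel requantization scales: isnormal() rejects
   * zero, infinities, NaN, and denormals, so each scale must be a positive,
   * finite, normal float.
   */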
  for (int i = 0; i < output_channels; ++i) {
    if (requantization_scales[i] <= 0.0f ||
        !isnormal(requantization_scales[i])) {
      pytorch_qnnp_log_error(
          "failed to create fully connected operator with %.7g requantization scale: scale must be finite and positive",
          requantization_scales[i]);
      goto error;
    }
  }

  status = pytorch_qnnp_status_out_of_memory;

  fully_connected = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (fully_connected == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
  const uint32_t kr = pytorch_qnnp_params.q8conv.kr;

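  /*
   * Round the channel counts up to multiples of the GEMM micro-kernel tile
   * sizes; the mask trick assumes nr and kr are powers of two.
   */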
  const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
  const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;

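  /*
   * Each padded output channel contributes k_stride weight bytes plus one
   * int32 bias to the packed buffer consumed by the Q8 GEMM micro-kernels.
   */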
  fully_connected->packed_weights =
      malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
  if (fully_connected->packed_weights == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for packed weights",
        n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
    goto error;
  }
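  /*
   * Fill the buffer with the kernel zero point so that padding elements act
   * as zero-valued weights once the zero point is subtracted.
   */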
  memset(
      fully_connected->packed_weights,
      kernel_zero_points[0],
      n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));

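  /*
   * Pack weights and biases into the blocked layout expected by the q8gemm
   * micro-kernels. Depending on the build, zero points are either folded into
   * the packed data here or passed through for runtime quantization.
   */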
  pytorch_pack_q8gemm_w(
      output_channels,
      input_channels,
      nr,
      nr,
      kr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
      input_zero_point,
      kernel_zero_points[0],
#endif
      kernel,
      bias,
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
      kernel_zero_points,
#endif
      fully_connected->packed_weights);

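  /*
   * A fully connected layer is run as a single-group GEMM: input_channels and
   * output_channels map directly onto the group channel counts.
   */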
  fully_connected->groups = 1;
  fully_connected->group_input_channels = input_channels;
  fully_connected->group_output_channels = output_channels;

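  /*
   * Only the first kernel zero point is recorded on the operator; the full
   * per-channel array is passed to the conv quantization params below.
   */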
  fully_connected->kernel_zero_point = kernel_zero_points[0];

  fully_connected->conv_quantization_params =
      pytorch_qnnp_compute_conv_quantization_params(
          input_zero_point,
          kernel_zero_points,
          requantization_scales,
          output_zero_point,
          output_min,
          output_max);

  fully_connected->ukernel_type = pytorch_qnnp_ukernel_type_gemm;
  fully_connected->format = pytorch_qnnp_format_quint8;

  *fully_connected_out = fully_connected;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(fully_connected);
  return status;
}

enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_nc_q8(
    pytorch_qnnp_operator_t fully_connected,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    fully_connected->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

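  /*
   * Reuse the convolution bookkeeping: the whole batch is presented as one
   * "image" of height batch_size and width 1, so the GEMM path processes one
   * row per input vector.
   */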
  fully_connected->batch_size = 1;
  fully_connected->input_height = batch_size;
  fully_connected->input_width = 1;
  fully_connected->input = input;
  fully_connected->input_pixel_stride = input_stride;

  fully_connected->output_height = batch_size;
  fully_connected->output_width = 1;
  fully_connected->output = output;
  fully_connected->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}