xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/fully-connected.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <assert.h>
10 #include <math.h>
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <string.h>
15 
16 #include <pytorch_qnnpack.h>
17 #include <qnnpack/log.h>
18 #include <qnnpack/math.h>
19 #include <qnnpack/operator.h>
20 #include <qnnpack/pack.h>
21 #include <qnnpack/params.h>
22 #include <qnnpack/requantization.h>
23 
/*
 * Creates a fully-connected (GEMM) operator for uint8 quantized data in NC
 * layout.
 *
 * The kernel is packed into a newly allocated buffer sized for the q8conv GEMM
 * micro-kernel: the output channel count is rounded up to a multiple of `nr`,
 * the input channel count to a multiple of `kr`, and every packed output
 * channel also reserves room for one int32 bias.
 *
 * @param input_channels        Number of input channels.
 * @param output_channels       Number of output channels.
 * @param input_zero_point      Zero point of the quantized input.
 * @param kernel_zero_points    Kernel zero points; element 0 is always read.
 *                              NOTE(review): with
 *                              PYTORCH_QNNPACK_RUNTIME_QUANTIZATION the whole
 *                              array is handed to pytorch_pack_q8gemm_w,
 *                              presumably one entry per output channel —
 *                              confirm against that function.
 * @param kernel                Quantized weights, packed by
 *                              pytorch_pack_q8gemm_w.
 * @param bias                  int32 bias, packed by pytorch_pack_q8gemm_w.
 * @param output_zero_point     Zero point of the quantized output.
 * @param output_min            Lower output clamp.
 * @param output_max            Upper output clamp.
 * @param flags                 Not read by this function.
 * @param requantization_scales One scale per output channel (indices
 *                              0..output_channels-1 are read). Each scale must
 *                              be positive and normal: not zero, subnormal,
 *                              infinite or NaN.
 * @param fully_connected_out   Receives the new operator on success; left
 *                              untouched on failure.
 * @return One of the following statuses:
 *         - pytorch_qnnp_status_success.
 *         - pytorch_qnnp_status_uninitialized if QNNPACK is not initialized.
 *         - pytorch_qnnp_status_unsupported_parameter for an invalid scale or
 *           an oversized channel count.
 *         - pytorch_qnnp_status_out_of_memory if allocation fails or its size
 *           would overflow.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_nc_q8(
    size_t input_channels,
    size_t output_channels,
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,
    const uint8_t* kernel,
    const int32_t* bias,
    uint8_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    const float* requantization_scales,
    pytorch_qnnp_operator_t* fully_connected_out) {
  pytorch_qnnp_operator_t fully_connected = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  /* size_t index: output_channels is size_t, and an int counter would be a
   * signed/unsigned comparison whose increment overflows past INT_MAX. */
  for (size_t i = 0; i < output_channels; ++i) {
    if (requantization_scales[i] <= 0.0f ||
        !isnormal(requantization_scales[i])) {
      pytorch_qnnp_log_error(
          "failed to create fully connected operator with %.7g requantization scale: scale must be finite and positive",
          requantization_scales[i]);
      goto error;
    }
  }

  const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
  const uint32_t kr = pytorch_qnnp_params.q8conv.kr;

  /* The padded strides below are stored as uint32_t. A channel count whose
   * round-up does not fit would be silently truncated and would undersize the
   * buffer that pytorch_pack_q8gemm_w writes into, so reject it up front. */
  if (output_channels > (size_t)(UINT32_MAX - (nr - 1)) ||
      input_channels > (size_t)(UINT32_MAX - (kr - 1))) {
    pytorch_qnnp_log_error(
        "failed to create fully connected operator with %zu input channels and %zu output channels: channel count is too large",
        input_channels,
        output_channels);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  fully_connected = calloc(1, sizeof(struct pytorch_qnnp_operator));
  if (fully_connected == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto error;
  }

  /* Round channel counts up to the micro-kernel tile sizes. The `& -x`
   * round-up is only correct when nr and kr are powers of two. */
  const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
  const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;

  /* Each packed output channel holds k_stride weight bytes plus one int32
   * bias slot. Guard the product against size_t overflow, which can only
   * happen where size_t is 32 bits wide. */
  const size_t packed_channel_size =
      (size_t)k_stride * sizeof(uint8_t) + sizeof(int32_t);
  if (packed_channel_size < sizeof(int32_t) ||
      (size_t)n_stride > SIZE_MAX / packed_channel_size) {
    pytorch_qnnp_log_error(
        "failed to allocate packed weights for %zu x %zu fully connected operator: size overflow",
        output_channels,
        input_channels);
    goto error;
  }
  const size_t packed_weights_size = (size_t)n_stride * packed_channel_size;

  fully_connected->packed_weights = malloc(packed_weights_size);
  if (fully_connected->packed_weights == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for packed weights",
        packed_weights_size);
    goto error;
  }
  /* Pre-fill the whole buffer with the kernel zero point. Presumably this
   * makes the padding beyond the real channels neutral; pytorch_pack_q8gemm_w
   * then writes the real weights and bias — NOTE(review): confirm in pack.h. */
  memset(
      fully_connected->packed_weights,
      kernel_zero_points[0],
      packed_weights_size);

  pytorch_pack_q8gemm_w(
      output_channels,
      input_channels,
      nr,
      nr,
      kr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
      input_zero_point,
      kernel_zero_points[0],
#endif
      kernel,
      bias,
#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
      kernel_zero_points,
#endif
      fully_connected->packed_weights);

  fully_connected->groups = 1;
  fully_connected->group_input_channels = input_channels;
  fully_connected->group_output_channels = output_channels;

  fully_connected->kernel_zero_point = kernel_zero_points[0];

  fully_connected->conv_quantization_params =
      pytorch_qnnp_compute_conv_quantization_params(
          input_zero_point,
          kernel_zero_points,
          requantization_scales,
          output_zero_point,
          output_min,
          output_max);

  fully_connected->ukernel_type = pytorch_qnnp_ukernel_type_gemm;
  fully_connected->format = pytorch_qnnp_format_quint8;

  *fully_connected_out = fully_connected;
  return pytorch_qnnp_status_success;

error:
  /* Single cleanup path: releases the partially built operator (NULL when
   * failing before the calloc) and reports the status of the failed stage. */
  pytorch_qnnp_delete_operator(fully_connected);
  return status;
}
128 }
129 
/*
 * Binds input/output buffers and a batch size to a fully-connected operator
 * created by pytorch_qnnp_create_fully_connected_nc_q8.
 *
 * A non-empty batch is stored as one batch entry whose height is batch_size
 * rows and whose width is 1. The row strides are recorded as the pixel
 * strides. An empty batch only records batch_size = 0; no buffer fields are
 * touched.
 *
 * @param fully_connected Operator to configure.
 * @param batch_size      Number of rows to process; 0 is accepted.
 * @param input           Input rows, input_stride elements apart.
 * @param input_stride    Distance between consecutive input rows.
 * @param output          Output rows, output_stride elements apart.
 * @param output_stride   Distance between consecutive output rows.
 * @return pytorch_qnnp_status_success, or pytorch_qnnp_status_uninitialized
 *         if QNNPACK has not been initialized.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_nc_q8(
    pytorch_qnnp_operator_t fully_connected,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_fully_connected_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size != 0) {
    /* Fold the whole batch into a single entry of batch_size x 1. */
    fully_connected->batch_size = 1;

    fully_connected->input = input;
    fully_connected->input_height = batch_size;
    fully_connected->input_width = 1;
    fully_connected->input_pixel_stride = input_stride;

    fully_connected->output = output;
    fully_connected->output_height = batch_size;
    fully_connected->output_width = 1;
    fully_connected->output_pixel_stride = output_stride;
  } else {
    /* Nothing to bind: record the empty batch and leave buffers as they are. */
    fully_connected->batch_size = 0;
  }

  return pytorch_qnnp_status_success;
}
160 }
161