1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18
/*
 * Fills a 256-entry lookup table mapping every uint8 input code to its
 * quantized tanh value.
 *
 * For each input code q in [0, 255]:
 *   x      = input_scale * (q - input_zero_point)          (dequantize)
 *   y      = tanh(x) / output_scale + output_zero_point    (requantize)
 *   out[q] = lrintf(clamp(y, output_min, output_max))
 *
 * `lookup_table` must have room for UINT8_MAX + 1 (= 256) entries.
 */
static void init_tanh_q8_lookup_table(
    uint8_t* lookup_table,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max) {
  /* Derived from the output quantization parameters instead of hard-coded
     literals. For the only supported output parameters (scale 2/256, zero
     point 128) both evaluate to exactly 128.0f. */
  const float inv_output_scale = 1.0f / output_scale;
  const float output_offset = (float)(int32_t)output_zero_point;
  const float scaled_min = (float)(int32_t)output_min;
  const float scaled_max = (float)(int32_t)output_max;

  for (int32_t i = 0; i <= (int32_t)UINT8_MAX; i++) {
    const float x =
        input_scale * (float)(i - (int32_t)(uint32_t)input_zero_point);
    float scaled_tanh_x = inv_output_scale * tanhf(x) + output_offset;
    /* Clamp before rounding: tanh(x) == 1 maps to 256, which does not fit
       in uint8; the clamp also enforces the requested output range. */
    if (scaled_tanh_x < scaled_min) {
      scaled_tanh_x = scaled_min;
    }
    if (scaled_tanh_x > scaled_max) {
      scaled_tanh_x = scaled_max;
    }
    lookup_table[(uint32_t)i] = (uint8_t)lrintf(scaled_tanh_x);
  }
}

/*
 * Creates a quantized (uint8, NC layout) TanH operator.
 *
 * The operator is tagged as a lookup-table kernel
 * (pytorch_qnnp_ukernel_type_lut). The table is precomputed here for all
 * 256 possible input codes.
 *
 * @param channels          Number of channels; must be non-zero.
 * @param input_zero_point  Quantization zero point of the input.
 * @param input_scale       Quantization scale of the input; must be a
 *                          positive normal float.
 * @param output_zero_point Quantization zero point of the output; only 128
 *                          is supported.
 * @param output_scale      Quantization scale of the output; only 2/256
 *                          (the [-1, 1] range in 8 bits) is supported.
 * @param output_min        Lower clamp for output codes; must be < output_max.
 * @param output_max        Upper clamp for output codes.
 * @param flags             Currently unused.
 * @param tanh_out          On success, receives the newly allocated operator.
 *
 * @return pytorch_qnnp_status_success on success;
 *         pytorch_qnnp_status_uninitialized if QNNPACK is not initialized;
 *         pytorch_qnnp_status_invalid_parameter for zero channels, a
 *           non-positive/non-normal scale, or output_min >= output_max;
 *         pytorch_qnnp_status_unsupported_parameter for an output scale other
 *           than 2/256 or an output zero point other than 128;
 *         pytorch_qnnp_status_out_of_memory if an allocation fails.
 *
 * On failure, any partially built operator is released with
 * pytorch_qnnp_delete_operator and *tanh_out is left untouched.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_tanh_nc_q8(
    size_t channels,
    uint8_t input_zero_point,
    float input_scale,
    uint8_t output_zero_point,
    float output_scale,
    uint8_t output_min,
    uint8_t output_max,
    uint32_t flags,
    pytorch_qnnp_operator_t* tanh_out) {
  pytorch_qnnp_operator_t tanh_op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  /* No flags are currently interpreted by this operator. */
  (void)flags;

  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_tanh_nc_q8 failed because QNNPACK is not properly initialized");
    goto error;
  }

  status = pytorch_qnnp_status_invalid_parameter;

  if (channels == 0) {
    pytorch_qnnp_log_error(
        "failed to create TanH operator with %zu channels: number of channels must be non-zero",
        channels);
    goto error;
  }

  /* isnormal() also rejects zero, subnormals, infinities and NaN. */
  if (input_scale <= 0.0f || !isnormal(input_scale)) {
    pytorch_qnnp_log_error(
        "failed to create TanH operator with %.7g input scale: scale must be finite and positive",
        input_scale);
    goto error;
  }

  if (output_scale <= 0.0f || !isnormal(output_scale)) {
    pytorch_qnnp_log_error(
        "failed to create TanH operator with %.7g output scale: scale must be finite and positive",
        output_scale);
    goto error;
  }

  if (output_min >= output_max) {
    pytorch_qnnp_log_error(
        "failed to create TanH operator with [%" PRIu8 ", %" PRIu8
        "] output range: range min must be below range max",
        output_min,
        output_max);
    goto error;
  }

  status = pytorch_qnnp_status_unsupported_parameter;

  if (output_scale != 0x2.0p-8f) { // [-1, 1] range in 8 bits = 2.0 / 256
    pytorch_qnnp_log_error(
        "failed to create TanH operator with %.7g output scale: only output scale of 2/256 is supported",
        output_scale);
    goto error;
  }

  if (output_zero_point != 128) {
    pytorch_qnnp_log_error(
        "failed to create TanH operator with %" PRIu8
        " output zero point: only output zero point of 128 is supported",
        output_zero_point);
    goto error;
  }

  status = pytorch_qnnp_status_out_of_memory;

  tanh_op = calloc(1, sizeof(*tanh_op));
  if (tanh_op == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(*tanh_op));
    goto error;
  }

  /* One entry per possible uint8 input code. */
  tanh_op->lookup_table = malloc((UINT8_MAX + 1) * sizeof(uint8_t));
  if (tanh_op->lookup_table == NULL) {
    pytorch_qnnp_log_error(
        "failed to allocate 256 bytes for TanH lookup table");
    goto error;
  }

  init_tanh_q8_lookup_table(
      tanh_op->lookup_table,
      input_zero_point,
      input_scale,
      output_zero_point,
      output_scale,
      output_min,
      output_max);

  tanh_op->channels = channels;

  tanh_op->ukernel_type = pytorch_qnnp_ukernel_type_lut;
  tanh_op->format = pytorch_qnnp_format_quint8;

  *tanh_out = tanh_op;
  return pytorch_qnnp_status_success;

error:
  pytorch_qnnp_delete_operator(tanh_op);
  return status;
}
135
/*
 * Binds a batch of input/output buffers to a TanH operator created by
 * pytorch_qnnp_create_tanh_nc_q8.
 *
 * @param tanh          Operator to configure.
 * @param batch_size    Number of rows in the batch. When zero, only the batch
 *                      size is recorded and the buffer/stride fields keep
 *                      their previous values.
 * @param input         Input buffer.
 * @param input_stride  Stored as the operator's input_pixel_stride
 *                      (presumably elements between consecutive rows —
 *                      confirm against the run path).
 * @param output        Output buffer.
 * @param output_stride Stored as the operator's output_pixel_stride.
 *
 * @return pytorch_qnnp_status_uninitialized if QNNPACK is not initialized,
 *         otherwise pytorch_qnnp_status_success. No validation of the
 *         operator, buffers, or strides is performed here.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_tanh_nc_q8(
    pytorch_qnnp_operator_t tanh,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_tanh_nc_q8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  /* The batch size is recorded in every case; buffers are only bound when
     there is actually something to process. */
  tanh->batch_size = batch_size;
  if (batch_size != 0) {
    tanh->input = input;
    tanh->input_pixel_stride = input_stride;
    tanh->output = output;
    tanh->output_pixel_stride = output_stride;
  }

  return pytorch_qnnp_status_success;
}
162