1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <assert.h>
10 #include <math.h>
11 #include <stddef.h>
12 #include <stdint.h>
13 #include <stdlib.h>
14
15 #include <pytorch_qnnpack.h>
16 #include <qnnpack/log.h>
17 #include <qnnpack/operator.h>
18 #include <qnnpack/params.h>
19
/*
 * Creates a channel shuffle operator for uint8 (quint8-format) data.
 *
 * @param groups              number of channel groups; must be at least 2.
 * @param group_channels      number of channels in each group; must be non-zero.
 * @param flags               not read by this function.
 * @param channel_shuffle_out receives the newly created operator on success.
 * @return pytorch_qnnp_status_success on success;
 *         pytorch_qnnp_status_uninitialized if QNNPACK is not initialized;
 *         pytorch_qnnp_status_invalid_parameter if groups or group_channels
 *           is out of range;
 *         pytorch_qnnp_status_out_of_memory if the operator cannot be allocated.
 *
 * *channel_shuffle_out is written only on success. NOTE(review): the caller
 * presumably releases the operator with pytorch_qnnp_delete_operator — confirm.
 */
enum pytorch_qnnp_status pytorch_qnnp_create_channel_shuffle_nc_x8(
    size_t groups,
    size_t group_channels,
    uint32_t flags,
    pytorch_qnnp_operator_t* channel_shuffle_out) {
  /* Operator under construction; remains NULL until calloc succeeds. */
  pytorch_qnnp_operator_t op = NULL;
  enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized;

  (void)flags; /* currently unused */

  if (!pytorch_qnnp_params.initialized) {
    status = pytorch_qnnp_status_uninitialized;
    pytorch_qnnp_log_error(
        "pytorch_qnnp_create_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized");
    goto fail;
  }

  /* A shuffle needs at least two groups to have anything to permute. */
  if (groups < 2) {
    status = pytorch_qnnp_status_invalid_parameter;
    pytorch_qnnp_log_error(
        "failed to create channel shuffle operator with %zu groups: "
        "at least two groups required",
        groups);
    goto fail;
  }

  if (group_channels == 0) {
    status = pytorch_qnnp_status_invalid_parameter;
    pytorch_qnnp_log_error(
        "failed to create channel shuffle operator with %zu group channels: "
        "number of group channels must be non-zero",
        group_channels);
    goto fail;
  }

  /* calloc zero-initializes every field not explicitly set below. */
  op = calloc(1, sizeof(*op));
  if (op == NULL) {
    status = pytorch_qnnp_status_out_of_memory;
    pytorch_qnnp_log_error(
        "failed to allocate %zu bytes for pytorch_qnnp_operator structure",
        sizeof(struct pytorch_qnnp_operator));
    goto fail;
  }

  op->ukernel_type = pytorch_qnnp_ukernel_type_channel_shuffle;
  op->format = pytorch_qnnp_format_quint8;
  op->groups = groups;
  op->group_channels = group_channels;

  /* Ownership of op passes to the caller. */
  *channel_shuffle_out = op;
  return pytorch_qnnp_status_success;

fail:
  /* op is NULL on every path reaching this label; the call keeps the cleanup
     correct should fallible steps ever be added after the allocation. */
  pytorch_qnnp_delete_operator(op);
  return status;
}
75
/*
 * Binds batch geometry and input/output buffers to a channel shuffle operator.
 *
 * @param channel_shuffle_op operator previously created by
 *                           pytorch_qnnp_create_channel_shuffle_nc_x8.
 * @param batch_size         number of rows to process; 0 makes the operator a no-op.
 * @param input              start of the input rows.
 * @param input_stride       distance between consecutive input rows. NOTE(review):
 *                           assumed to be in uint8_t elements, per the pointer
 *                           type; confirm against the run path.
 * @param output             start of the output rows.
 * @param output_stride      distance between consecutive output rows (same unit).
 * @return pytorch_qnnp_status_success on success;
 *         pytorch_qnnp_status_uninitialized if QNNPACK is not initialized;
 *         pytorch_qnnp_status_invalid_parameter if either stride is smaller than
 *           the row width (groups * group_channels). In that case the operator
 *           is left unmodified.
 */
enum pytorch_qnnp_status pytorch_qnnp_setup_channel_shuffle_nc_x8(
    pytorch_qnnp_operator_t channel_shuffle_op,
    size_t batch_size,
    const uint8_t* input,
    size_t input_stride,
    uint8_t* output,
    size_t output_stride) {
  if (!pytorch_qnnp_params.initialized) {
    pytorch_qnnp_log_error(
        "pytorch_qnnp_setup_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized");
    return pytorch_qnnp_status_uninitialized;
  }

  if (batch_size == 0) {
    channel_shuffle_op->batch_size = 0;
    return pytorch_qnnp_status_success;
  }

  /* Row width as configured at create time. A stride below this would make
     consecutive rows overlap, so output rows would clobber each other. */
  const size_t channels =
      channel_shuffle_op->groups * channel_shuffle_op->group_channels;

  if (input_stride < channels) {
    pytorch_qnnp_log_error(
        "failed to setup channel shuffle operator with input stride of %zu: "
        "stride must be at least as large as the number of channels (%zu)",
        input_stride,
        channels);
    return pytorch_qnnp_status_invalid_parameter;
  }

  if (output_stride < channels) {
    pytorch_qnnp_log_error(
        "failed to setup channel shuffle operator with output stride of %zu: "
        "stride must be at least as large as the number of channels (%zu)",
        output_stride,
        channels);
    return pytorch_qnnp_status_invalid_parameter;
  }

  channel_shuffle_op->batch_size = batch_size;
  channel_shuffle_op->input = input;
  channel_shuffle_op->input_pixel_stride = input_stride;
  channel_shuffle_op->output = output;
  channel_shuffle_op->output_pixel_stride = output_stride;

  return pytorch_qnnp_status_success;
}
102