xref: /aosp_15_r20/external/XNNPACK/src/subgraph/maximum2.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2020 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <assert.h>
7 #include <math.h>
8 #include <stddef.h>
9 #include <stdint.h>
10 #include <string.h>
11 
12 #include <xnnpack.h>
13 #include <xnnpack/log.h>
14 #include <xnnpack/operator.h>
15 #include <xnnpack/params.h>
16 #include <xnnpack/subgraph.h>
17 #include <xnnpack/subgraph-validation.h>
18 
19 
create_maximum_operator(const struct xnn_node * node,const struct xnn_value * values,size_t num_values,struct xnn_operator_data * opdata,const struct xnn_caches * caches)20 static enum xnn_status create_maximum_operator(
21   const struct xnn_node* node,
22   const struct xnn_value* values,
23   size_t num_values,
24   struct xnn_operator_data* opdata,
25   const struct xnn_caches* caches)
26 {
27   assert(node->compute_type == xnn_compute_type_fp32);
28 
29   assert(node->num_inputs == 2);
30   const uint32_t input1_id = node->inputs[0];
31   assert(input1_id != XNN_INVALID_VALUE_ID);
32   assert(input1_id < num_values);
33   const uint32_t input2_id = node->inputs[1];
34   assert(input2_id != XNN_INVALID_VALUE_ID);
35   assert(input2_id < num_values);
36 
37   assert(node->num_outputs == 1);
38   const uint32_t output_id = node->outputs[0];
39   assert(output_id != XNN_INVALID_VALUE_ID);
40   assert(output_id < num_values);
41 
42   enum xnn_status status;
43   switch (node->compute_type) {
44 #ifndef XNN_NO_F16_OPERATORS
45     case xnn_compute_type_fp16:
46       status = xnn_create_maximum_nd_f16(
47         node->flags,
48         &opdata->operator_objects[0]);
49       break;
50 #endif  // !defined(XNN_NO_F16_OPERATORS)
51     case xnn_compute_type_fp32:
52       status = xnn_create_maximum_nd_f32(
53         node->flags,
54         &opdata->operator_objects[0]);
55       break;
56     default:
57       XNN_UNREACHABLE;
58   }
59   if (status == xnn_status_success) {
60     opdata->shape1.num_dims = values[input1_id].shape.num_dims;
61     opdata->shape2.num_dims = values[input2_id].shape.num_dims;
62     if (values[output_id].layout == xnn_layout_type_nchw) {
63       assert(values[input1_id].layout == xnn_layout_type_nchw);
64       assert(values[input2_id].layout == xnn_layout_type_nchw);
65       opdata->shape1.dim[0] = values[input1_id].shape.dim[0];
66       opdata->shape1.dim[1] = values[input1_id].shape.dim[values[input1_id].shape.num_dims - 1];
67       if (values[input1_id].shape.num_dims > 2) {
68         memcpy(&opdata->shape1.dim[2], &values[input1_id].shape.dim[1], (values[input1_id].shape.num_dims - 2) * sizeof(size_t));
69       }
70       opdata->shape2.dim[0] = values[input2_id].shape.dim[0];
71       opdata->shape2.dim[1] = values[input2_id].shape.dim[values[input2_id].shape.num_dims - 1];
72       if (values[input1_id].shape.num_dims > 2) {
73         memcpy(&opdata->shape2.dim[2], &values[input2_id].shape.dim[1], (values[input2_id].shape.num_dims - 2) * sizeof(size_t));
74       }
75     } else {
76       assert(values[output_id].layout == xnn_layout_type_nhwc);
77       assert(values[input1_id].layout == xnn_layout_type_nhwc);
78       assert(values[input2_id].layout == xnn_layout_type_nhwc);
79       memcpy(opdata->shape1.dim, values[input1_id].shape.dim, values[input1_id].shape.num_dims * sizeof(size_t));
80       memcpy(opdata->shape2.dim, values[input2_id].shape.dim, values[input2_id].shape.num_dims * sizeof(size_t));
81     }
82     opdata->inputs[0] = input1_id;
83     opdata->inputs[1] = input2_id;
84     opdata->outputs[0] = output_id;
85   }
86   return status;
87 }
88 
// Binds runtime buffers to the Maximum (ND) operator created for this node.
//
// Reads the two input IDs and the output ID stored in `opdata`, looks up the
// matching entries in `blobs`, and passes their data pointers together with
// the cached shapes to xnn_setup_maximum_nd_f16 or xnn_setup_maximum_nd_f32,
// selected by the type of operator_objects[0].
//
// opdata     - operator data holding the operator, shapes and value IDs.
// blobs      - blob table indexed by value ID.
// num_blobs  - number of entries in `blobs` (used only for assertions).
// threadpool - forwarded unchanged to the operator setup call.
//
// Returns the status reported by the operator setup call.
static enum xnn_status setup_maximum_operator(
  const struct xnn_operator_data* opdata,
  const struct xnn_blob* blobs,
  size_t num_blobs,
  pthreadpool_t threadpool)
{
  // Validate every value ID before touching the blob table.
  const uint32_t lhs_id = opdata->inputs[0];
  assert(lhs_id != XNN_INVALID_VALUE_ID);
  assert(lhs_id < num_blobs);

  const uint32_t rhs_id = opdata->inputs[1];
  assert(rhs_id != XNN_INVALID_VALUE_ID);
  assert(rhs_id < num_blobs);

  const uint32_t out_id = opdata->outputs[0];
  assert(out_id != XNN_INVALID_VALUE_ID);
  assert(out_id < num_blobs);

  // Every blob must already have a data pointer attached.
  const void* lhs = blobs[lhs_id].data;
  assert(lhs != NULL);

  const void* rhs = blobs[rhs_id].data;
  assert(rhs != NULL);

  void* out = blobs[out_id].data;
  assert(out != NULL);

  // Dispatch on the operator that was actually created for this node.
  switch (opdata->operator_objects[0]->type) {
#ifndef XNN_NO_F16_OPERATORS
    case xnn_operator_type_maximum_nd_f16:
      return xnn_setup_maximum_nd_f16(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        lhs, rhs, out,
        threadpool);
#endif  // !defined(XNN_NO_F16_OPERATORS)
    case xnn_operator_type_maximum_nd_f32:
      return xnn_setup_maximum_nd_f32(
        opdata->operator_objects[0],
        opdata->shape1.num_dims, opdata->shape1.dim,
        opdata->shape2.num_dims, opdata->shape2.dim,
        lhs, rhs, out,
        threadpool);
    default:
      XNN_UNREACHABLE;
  }
}
144 
// Adds a maximum2 node to `subgraph`.
//
// Each of the two inputs and the output is validated in turn: its ID is
// checked against the subgraph's value count, the value is checked to be of
// the dense type, and its datatype must be xnn_datatype_fp32. On success a new
// node is appended with compute type fp32 and the create/setup callbacks
// defined in this file.
//
// subgraph  - subgraph receiving the node.
// input1_id - value ID of the first input.
// input2_id - value ID of the second input.
// output_id - value ID of the output.
// flags     - stored on the node unchanged.
//
// Returns xnn_status_success on success; the failing check's status if any
// validation helper fails; xnn_status_invalid_parameter for a non-fp32
// datatype; xnn_status_out_of_memory if no node could be allocated.
enum xnn_status xnn_define_maximum2(
  xnn_subgraph_t subgraph,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags)
{
  enum xnn_status status = xnn_subgraph_check_xnnpack_initialized(xnn_node_type_maximum2);
  if (status != xnn_status_success) {
    return status;
  }

  // --- First input ---------------------------------------------------------
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_maximum2, input1_id, subgraph->num_values, 1);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* first = &subgraph->values[input1_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_maximum2, input1_id, first, 1);
  if (status != xnn_status_success) {
    return status;
  }

  if (first->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the first input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_maximum2), input1_id,
      xnn_datatype_to_string(first->datatype), first->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Second input --------------------------------------------------------
  status = xnn_subgraph_check_nth_input_node_id(
    xnn_node_type_maximum2, input2_id, subgraph->num_values, 2);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* second = &subgraph->values[input2_id];
  status = xnn_subgraph_check_nth_input_type_dense(xnn_node_type_maximum2, input2_id, second, 2);
  if (status != xnn_status_success) {
    return status;
  }

  if (second->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with the second input ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_maximum2), input2_id,
      xnn_datatype_to_string(second->datatype), second->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Output --------------------------------------------------------------
  status = xnn_subgraph_check_output_node_id(xnn_node_type_maximum2, output_id, subgraph->num_values);
  if (status != xnn_status_success) {
    return status;
  }

  const struct xnn_value* result = &subgraph->values[output_id];
  status = xnn_subgraph_check_output_type_dense(xnn_node_type_maximum2, output_id, result);
  if (status != xnn_status_success) {
    return status;
  }

  if (result->datatype != xnn_datatype_fp32) {
    xnn_log_error(
      "failed to define %s operator with output ID #%" PRIu32 ": unsupported Value datatype %s (%d)",
      xnn_node_type_to_string(xnn_node_type_maximum2), output_id,
      xnn_datatype_to_string(result->datatype), result->datatype);
    return xnn_status_invalid_parameter;
  }

  // --- Node registration ---------------------------------------------------
  struct xnn_node* new_node = xnn_subgraph_new_node(subgraph);
  if (new_node == NULL) {
    return xnn_status_out_of_memory;
  }

  new_node->type = xnn_node_type_maximum2;
  new_node->compute_type = xnn_compute_type_fp32;
  new_node->flags = flags;

  new_node->num_inputs = 2;
  new_node->inputs[0] = input1_id;
  new_node->inputs[1] = input2_id;

  new_node->num_outputs = 1;
  new_node->outputs[0] = output_id;

  new_node->create = create_maximum_operator;
  new_node->setup = setup_maximum_operator;

  return xnn_status_success;
}
242