xref: /aosp_15_r20/external/ComputeLibrary/src/graph/mutators/NodeFusionMutator.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph/mutators/NodeFusionMutator.h"
25 
26 #include "arm_compute/graph/GraphBuilder.h"
27 #include "arm_compute/graph/Logger.h"
28 #include "arm_compute/graph/Utils.h"
29 #include "arm_compute/graph/backends/BackendRegistry.h"
30 #include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationNode.h"
31 #include "arm_compute/graph/nodes/FusedConvolutionBatchNormalizationWithPostOpsNode.h"
32 #include "arm_compute/graph/nodes/FusedConvolutionWithPostOpNode.h"
33 #include "arm_compute/graph/nodes/Nodes.h"
34 
35 #include "src/graph/mutators/MutatorUtils.h"
36 
37 #include "support/Cast.h"
38 
39 #include <list>
40 #include <set>
41 
42 namespace arm_compute
43 {
44 namespace graph
45 {
46 namespace detail
47 {
transfer_driving_nodes_and_remove_old_node(Graph & g,INode * new_node,INode * old_node,bool add_output_tensor)48 void transfer_driving_nodes_and_remove_old_node(Graph &g, INode *new_node, INode *old_node, bool add_output_tensor)
49 {
50     if(new_node == nullptr || old_node == nullptr)
51     {
52         return;
53     }
54 
55     // Get driving nodes of last fusable node
56     std::vector<NodeIdxPair> last_driving_nodes = get_driving_nodes(*old_node);
57 
58     // Extract last fusable node accessor if any
59     if(old_node->output(0) == nullptr)
60     {
61         return;
62     }
63     auto old_node_accessor = old_node->output(0)->extract_accessor();
64 
65     // Remove node
66     g.remove_node(old_node->id());
67 
68     // Update fused node outputs
69     for(auto &driving_node : last_driving_nodes)
70     {
71         g.add_connection(new_node->id(), 0, driving_node.node_id, driving_node.index);
72         if(add_output_tensor)
73         {
74             configure_tensor(new_node->output(0));
75         }
76     }
77 
78     // Update accessor to fused node
79     new_node->output(0)->set_accessor(std::move(old_node_accessor));
80 }
81 
/** Fuse a ConvolutionLayerNode with a following BatchNormalizationLayerNode.
 *
 * Replaces the conv + batch-norm pair with a single FusedConvolutionBatchNormalizationNode,
 * re-wiring all inputs (conv input/weights/bias, bn mean/var/beta/gamma) to the fused node.
 *
 * @param[in, out] g           Graph to mutate
 * @param[in]      output_edge Edge connecting the convolution (producer) to the batch-norm (consumer)
 */
void fuse_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->producer());
    auto *bn_node   = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->consumer());

    // Not fusing if number of groups is greater than 1
    if(conv_node->num_groups() > 1)
    {
        return;
    }

    ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing convolution node with ID : " << output_edge->producer_id()
                                  << " with BatchNormalization Layer node with ID : " << output_edge->consumer_id() << std::endl);

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        const Target assigned_target = conv_node->assigned_target();

        // Extract conv inputs
        const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
        const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
        const auto   conv_info       = conv_node->convolution_info();
        const auto   conv_method     = conv_node->convolution_method();
        const auto   num_groups      = conv_node->num_groups();
        const auto   act_info        = bn_node->fused_activation();
        FastMathHint fast_math_hint  = conv_node->fast_math_hint();

        // Extract bn inputs (mean and var are mandatory; beta and gamma are optional, see below)
        const auto bn_mean_id = bn_node->input_edge(1)->producer_id();
        const auto bn_var_id  = bn_node->input_edge(2)->producer_id();

        const auto epsilon = bn_node->epsilon();

        // Create the fused node
        const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint, act_info);

        // Bias is optional: forward it only when present
        if(conv_node->input_edge(2) != nullptr)
        {
            auto conv_bias_id = conv_node->input_edge(2)->producer_id();
            g.add_connection(conv_bias_id, 0, fused_id, 2);
        }

        // Add connections from the conv/batch_norm inputs to the fused node
        g.add_connection(conv_input_id, 0, fused_id, 0);
        g.add_connection(conv_weights_id, 0, fused_id, 1);
        g.add_connection(bn_mean_id, 0, fused_id, 3);
        g.add_connection(bn_var_id, 0, fused_id, 4);

        // Optional beta input
        if(bn_node->input_edge(3) != nullptr)
        {
            const auto bn_beta_id = bn_node->input_edge(3)->producer_id();
            g.add_connection(bn_beta_id, 0, fused_id, 5);
        }

        // Optional gamma input
        if(bn_node->input_edge(4) != nullptr)
        {
            const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
            g.add_connection(bn_gamma_id, 0, fused_id, 6);
        }

        auto fused_node   = g.node(fused_id);
        // Cache the name before the bn node is destroyed below
        auto bn_node_name = bn_node->name();

        // Hand the bn node's consumers and output accessor over to the fused node, then delete it
        transfer_driving_nodes_and_remove_old_node(g, fused_node, bn_node, true);

        fused_node->set_assigned_target(assigned_target);
        fused_node->set_common_node_parameters(NodeParams{ conv_node->name() + "+" + bn_node_name, assigned_target });

        // Remove convolution node
        g.remove_node(conv_node->id());
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution with batch normalization due to the presence of an output accessor\n");
    }
}
161 
fuse_depthwise_convolution_with_batch_normalization(Graph & g,const Edge * output_edge)162 void fuse_depthwise_convolution_with_batch_normalization(Graph &g, const Edge *output_edge)
163 {
164     ARM_COMPUTE_ERROR_ON(output_edge == nullptr);
165 
166     auto *depth_conv_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(output_edge->producer());
167     auto *bn_node         = arm_compute::utils::cast::polymorphic_downcast<BatchNormalizationLayerNode *>(output_edge->consumer());
168 
169     ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing depthwise convolution node with ID : " << output_edge->producer_id()
170                                   << " with BatchNormalization Layer node with ID : " << output_edge->consumer_id() << std::endl);
171 
172     // Prevent fusion if fused node has an output accessor
173     if(depth_conv_node->output(0)->accessor() == nullptr)
174     {
175         const Target assigned_target = depth_conv_node->assigned_target();
176 
177         // Extract conv inputs
178         const auto depth_conv_input_id = depth_conv_node->input_edge(0)->producer_id();
179         const auto conv_weights_id     = depth_conv_node->input_edge(1)->producer_id();
180         const auto conv_info           = depth_conv_node->convolution_info();
181         const auto depth_conv_method   = depth_conv_node->depthwise_convolution_method();
182         const auto depth_multiplier    = depth_conv_node->depth_multiplier();
183         const auto act_info            = bn_node->fused_activation();
184 
185         // Extract bn inputs
186         const auto bn_mean_id  = bn_node->input_edge(1)->producer_id();
187         const auto bn_var_id   = bn_node->input_edge(2)->producer_id();
188         const auto bn_beta_id  = bn_node->input_edge(3)->producer_id();
189         const auto bn_gamma_id = bn_node->input_edge(4)->producer_id();
190         const auto epsilon     = bn_node->epsilon();
191 
192         // Create the fused node
193         const NodeID fused_id = g.add_node<FusedDepthwiseConvolutionBatchNormalizationNode>(epsilon, conv_info, depth_multiplier, depth_conv_method, act_info);
194 
195         if(depth_conv_node->input_edge(2) != nullptr)
196         {
197             const auto conv_bias_id = depth_conv_node->input_edge(2)->producer_id();
198             g.add_connection(conv_bias_id, 0, fused_id, 2);
199         }
200 
201         // Add connections from the conv/batch_norm inputs to the fused node
202         g.add_connection(depth_conv_input_id, 0, fused_id, 0);
203         g.add_connection(conv_weights_id, 0, fused_id, 1);
204         g.add_connection(bn_mean_id, 0, fused_id, 3);
205         g.add_connection(bn_var_id, 0, fused_id, 4);
206         g.add_connection(bn_beta_id, 0, fused_id, 5);
207         g.add_connection(bn_gamma_id, 0, fused_id, 6);
208 
209         auto fused_node   = g.node(fused_id);
210         auto bn_node_name = bn_node->name();
211 
212         transfer_driving_nodes_and_remove_old_node(g, fused_node, bn_node, true);
213 
214         fused_node->set_assigned_target(assigned_target);
215         fused_node->set_common_node_parameters(NodeParams{ depth_conv_node->name() + "+" + bn_node_name, assigned_target });
216 
217         // Remove convolution node
218         g.remove_node(depth_conv_node->id());
219     }
220     else
221     {
222         ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of depthwise convolution with batch normalization due to the presence of an output accessor\n");
223     }
224 }
225 
226 template <typename N>
fuse_node_with_activation(Graph & g,const Edge * output_edge,const std::set<Activation> & supported_fused_activations)227 void fuse_node_with_activation(Graph &g, const Edge *output_edge, const std::set<Activation> &supported_fused_activations)
228 {
229     ARM_COMPUTE_ERROR_ON(output_edge == nullptr);
230 
231     auto *n_node   = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->producer());
232     auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(output_edge->consumer());
233 
234     ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr || n_node->output(0) == nullptr);
235 
236     // Check if activation is supported for fusion
237     if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
238     {
239         return;
240     }
241 
242     // EltwiseLayerNode can only be fused when dataype is float
243     if(n_node->type() == NodeType::EltwiseLayer && !is_data_type_float(n_node->output(0)->desc().data_type))
244     {
245         return;
246     }
247 
248     ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing node with ID : " << output_edge->producer_id()
249                                   << " with Activation Layer node with ID : " << output_edge->consumer_id() << std::endl);
250 
251     // Prevent fusion if fused node has an output accessor
252     if(n_node->output(0)->accessor() == nullptr)
253     {
254         // Set activation info to fused node
255         n_node->set_fused_activation(act_node->activation_info());
256 
257         transfer_driving_nodes_and_remove_old_node(g, n_node, act_node, false);
258     }
259     else
260     {
261         ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of node with activation due to the presence of an output accessor\n");
262     }
263 }
264 
/** Fold a PadLayerNode into a following convolution-type node N.
 *
 * When the pad value is zero and the padding only touches the spatial (height/width)
 * dimensions, the padding amounts are absorbed into the convolution's PadStrideInfo
 * and the pad node is removed from the graph.
 *
 * @param[in, out] g           Graph to mutate
 * @param[in]      output_edge Edge connecting the pad node (producer) to the convolution (consumer)
 */
template <typename N>
void fuse_pad_with_convolution(Graph &g, const Edge *output_edge)
{
    auto *pad_node  = arm_compute::utils::cast::polymorphic_downcast<PadLayerNode *>(output_edge->producer());
    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<N *>(output_edge->consumer());

    // Fusable only when the pad input is known, the pad node has no output accessor,
    // and the pad value is zero (a convolution can only provide zero padding)
    const Edge *input_edge = pad_node->input_edge(0);
    if(input_edge != nullptr && input_edge->tensor() != nullptr && pad_node->output(0)->accessor() == nullptr
       && pad_node->pad_value().get<float>() == 0.0)
    {
        const DataLayout  layout       = input_edge->tensor()->desc().layout;
        const PaddingList padding_list = pad_node->padding();

        const unsigned int height_index = get_dimension_idx(layout, DataLayoutDimension::HEIGHT);
        const unsigned int width_index  = get_dimension_idx(layout, DataLayoutDimension::WIDTH);

        // NOTE(review): assumes padding_list is indexed by tensor dimension index;
        // out-of-range dimensions are treated as "no padding"
        const PaddingInfo pad_w = width_index < padding_list.size() ? padding_list[width_index] : PaddingInfo(0, 0);
        const PaddingInfo pad_h = height_index < padding_list.size() ? padding_list[height_index] : PaddingInfo(0, 0);

        if(is_padding_in_height_or_width(layout, padding_list))
        {
            // Add paddings to the convolution node
            const PadStrideInfo conv_info = conv_node->convolution_info();
            const PadStrideInfo new_conv_info(
                conv_info.stride().first,
                conv_info.stride().second,
                conv_info.pad_left() + pad_w.first,
                conv_info.pad_right() + pad_w.second,
                conv_info.pad_top() + pad_h.first,
                conv_info.pad_bottom() + pad_h.second,
                conv_info.round());
            conv_node->set_convolution_info(new_conv_info);

            // Update drivers of the convolution node
            std::vector<NodeIdxPair> pad_driver_nodes = get_driver_nodes(*pad_node);
            g.remove_node(pad_node->id());

            // Re-connect the pad node's producers directly to the convolution input
            for(auto &driver_node : pad_driver_nodes)
            {
                g.add_connection(driver_node.node_id, driver_node.index, conv_node->id(), 0);
            }
        }
    }
}
310 
311 template <typename N1, typename N2, typename F, typename... Args>
fuse_layer(Graph & g,std::function<bool (INode &)> const & prec,const F fuse_fcn,Args &&...optional_arguments)312 void fuse_layer(Graph &g, std::function<bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
313 {
314     // Note that fused nodes may be added to the end of the node list.
315     // Instead of only looping over the original list of nodes, we loop over the current node list which could be growing.
316     // This is intentional as it probes the newly added fused nodes for further fusing opportunities.
317     for(unsigned int i = 0; i < g.nodes().size(); ++i)
318     {
319         auto node = g.node(i);
320         // Check if the node is of type N1 and not a branching node
321         if(node && node->type() == N1::node_type && node->output_edges().size() == 1)
322         {
323             const auto output_edge_id = *node->output_edges().begin();
324             const auto output_edge    = g.edge(output_edge_id);
325 
326             // Check if following node is a type N2 node
327             if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && (output_edge->consumer()->type() == N2::node_type) && prec(*output_edge->producer()))
328             {
329                 fuse_fcn(g, output_edge, optional_arguments...);
330             }
331         }
332     }
333 }
334 
/** Lookup table of the post-operator chains that may be fused after a convolution:
 *
 * | Main operator | Post operators             |
 * |:--------------|:---------------------------|
 * |conv           | add                        |
 * |conv           | act + add                  |
 * |conv           | add + act                  |
 * |conv           | act + add + act            |
 *
*/
// NOTE(review): "VALIDE" is a typo for "VALID"; kept as-is because renaming would touch every user.
#define MAX_VALIDE_COMBINATION 4
#define MAX_POST_OP_NUM 3
// Rows shorter than MAX_POST_OP_NUM are implicitly padded with value-initialized NodeType
// entries; check_post_op_type() only compares the first `len` columns of each row.
NodeType valide_post_op_type[MAX_VALIDE_COMBINATION][MAX_POST_OP_NUM] = { { EltwiseLayerNode::node_type },
    { EltwiseLayerNode::node_type, ActivationLayerNode::node_type },
    { ActivationLayerNode::node_type, EltwiseLayerNode::node_type },
    { ActivationLayerNode::node_type, EltwiseLayerNode::node_type, ActivationLayerNode::node_type }
};
352 
/** Check whether a sequence of post-op node types matches one of the supported combinations.
 *
 * @param[in] post_op_type Array of node types collected while walking the post-op chain
 * @param[in] len          Number of valid entries in @p post_op_type (must be 1..MAX_POST_OP_NUM)
 *
 * @return True if the first @p len entries equal the first @p len columns of some row of
 *         valide_post_op_type, false otherwise.
 *
 * NOTE(review): this is a prefix match against each row — e.g. a single ActivationLayer
 * entry matches the prefix of the {act, add} row — confirm that is intended.
 */
bool check_post_op_type(NodeType *post_op_type, int len)
{
    if(len > MAX_POST_OP_NUM || len <= 0)
    {
        return false;
    }

    bool found = false;
    for(int i = 0; i < MAX_VALIDE_COMBINATION; ++i)
    {
        for(int j = 0; j < len; ++j)
        {
            // Mismatch at any column rules out this row
            if(post_op_type[j] != valide_post_op_type[i][j])
            {
                found = false;
                break;
            }
            found = true;
        }
        if(found)
            break;
    }

    return found;
}
378 
/** Attach the collected post operators to an already-created fused convolution node.
 *
 * For each node in @p post_op_node_list a matching ConvPostOpInfo entry is appended to
 * the fused node's post-op info list; the post-op nodes themselves are then removed from
 * the graph, the last one handing its consumers/accessor over to @p fused_node.
 *
 * @param[in, out] g                 Graph to mutate
 * @param[in, out] fused_node        Fused convolution node that receives the post-op descriptions
 * @param[in]      post_op_node_list Post-op nodes in execution order
 * @param[in]      prev_op_dst_pos   Input slot of the elementwise op fed by the previous operator
 */
void fuse_convolution_with_post_op(Graph &g, INode *fused_node, std::list<INode *> post_op_node_list, int prev_op_dst_pos)
{
    unsigned int op_idx = 0;
    // Fuse post operators with conv
    for(const auto &post_op : post_op_node_list)
    {
        switch(post_op->type())
        {
            case EltwiseLayerNode::node_type:
            {
                auto *eltwise_node = arm_compute::utils::cast::polymorphic_downcast<EltwiseLayerNode *>(post_op);
                ARM_COMPUTE_ERROR_ON(eltwise_node->output(0) == nullptr);

                fused_node->post_op_info_list().push_back(std::make_unique<ConvPostOpInfoEltwiseAdd>(prev_op_dst_pos, eltwise_node->convert_policy()));
                ARM_COMPUTE_LOG_GRAPH_VERBOSE(" with Elementwise Layer node with ID : " << post_op->id());
                break;
            }
            case ActivationLayerNode::node_type:
            {
                auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(post_op);
                ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr);

                fused_node->post_op_info_list().push_back(std::make_unique<ConvPostOpInfoActivation>(act_node->activation_info()));
                ARM_COMPUTE_LOG_GRAPH_VERBOSE(" with Activation Layer node with ID : " << post_op->id());
                break;
            }
            default:
            {
                // Other node types are silently skipped; get_post_op_list() should not produce them
                break;
            }
        }

        if(op_idx == post_op_node_list.size() - 1) // last fusable node
        {
            // The last post-op transfers its consumers and output accessor to the fused node
            transfer_driving_nodes_and_remove_old_node(g, fused_node, post_op, true);
        }
        else
        {
            // Remove node
            g.remove_node(post_op->id());
        }
        op_idx++;
    }
}
423 
get_post_op_list(Graph & g,int & eltwise_operand_id,int & prev_op_dst_pos,unsigned int conv_node_id,const std::set<Activation> & supported_fused_activations)424 std::list<INode *> get_post_op_list(Graph &g, int &eltwise_operand_id, int &prev_op_dst_pos, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
425 {
426     std::list<INode *> post_op_node_list    = {};
427     NodeID             prev_op_dst_id       = conv_node_id;
428     NodeType           post_op_type_list[3] = { NodeType::Dummy, NodeType::Dummy, NodeType::Dummy };
429     int                post_op_idx          = 0;
430 
431     // Get list of the connected nodes
432     auto current_node = g.node(conv_node_id);
433 
434     while(post_op_node_list.size() < 3)
435     {
436         // This convolution node must have only one output edge, otherwise this function would not have been called
437 
438         auto current_output_edge_id = current_node->output_edges().begin();
439         auto current_output_edge    = g.edge(*current_output_edge_id);
440         auto post_op_node           = current_output_edge->consumer();
441 
442         bool fusable_post_op = false;
443         if(post_op_node != nullptr && post_op_node->output_edges().size() > 0)
444         {
445             switch(post_op_node->type())
446             {
447                 case EltwiseLayerNode::node_type:
448                 {
449                     auto *eltwise_node = arm_compute::utils::cast::polymorphic_downcast<EltwiseLayerNode *>(post_op_node);
450                     ARM_COMPUTE_ERROR_ON(eltwise_node->output(0) == nullptr);
451                     if(eltwise_node->output(0)->accessor() == nullptr)
452                     {
453                         post_op_node_list.push_back(post_op_node);
454                         fusable_post_op                  = true;
455                         post_op_type_list[post_op_idx++] = eltwise_node->type();
456 
457                         // Extract elementwise inputs
458                         const auto eltwise_input_id_0 = eltwise_node->input_edge(0)->producer_id();
459                         const auto eltwise_input_id_1 = eltwise_node->input_edge(1)->producer_id();
460                         if(eltwise_input_id_0 == prev_op_dst_id)
461                         {
462                             eltwise_operand_id = eltwise_input_id_1;
463                             prev_op_dst_pos    = 0;
464                         }
465                         else if(eltwise_input_id_1 == prev_op_dst_id)
466                         {
467                             eltwise_operand_id = eltwise_input_id_0;
468                             prev_op_dst_pos    = 1;
469                         }
470                     }
471                     else
472                     {
473                         ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with elementwise due to the presence of an output accessor\n");
474                     }
475                     break;
476                 }
477                 case ActivationLayerNode::node_type:
478                 {
479                     auto *act_node = arm_compute::utils::cast::polymorphic_downcast<ActivationLayerNode *>(post_op_node);
480                     ARM_COMPUTE_ERROR_ON(act_node->output(0) == nullptr);
481                     // Check if activation is supported for fusion
482                     if(supported_fused_activations.count(act_node->activation_info().activation()) == 0)
483                     {
484                         break;
485                     }
486                     if(act_node->output(0)->accessor() == nullptr)
487                     {
488                         post_op_node_list.push_back(post_op_node);
489                         fusable_post_op                  = true;
490                         post_op_type_list[post_op_idx++] = act_node->type();
491                         prev_op_dst_id                   = act_node->id();
492                     }
493                     else
494                     {
495                         ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
496                     }
497                     break;
498                 }
499                 default:
500                 {
501                     break;
502                 }
503             }
504 
505             // Check if the node is not a branching node and current node is fusable
506             if(post_op_node->output_edges().size() == 1 && fusable_post_op == true)
507             {
508                 current_node = post_op_node;
509             }
510             else
511             {
512                 break;
513             }
514         }
515     }
516 
517     // Check whether it's valid post op list
518     if(post_op_node_list.size() > 0)
519     {
520         bool fuse_with_post_op = check_post_op_type(post_op_type_list, post_op_node_list.size());
521         if(!fuse_with_post_op)
522         {
523             post_op_node_list.clear();
524         }
525     }
526 
527     return post_op_node_list;
528 }
529 
530 /** Fuse below operators:
531  *
532  * | Main operator | Post operators             |
533  * |:--------------|:---------------------------|
534  * |conv           | add                        |
535  * |conv           | act + add                  |
536  * |conv           | add + act                  |
537  * |conv           | act + add + act            |
538  *
 * Note: currently only the GEMM convolution method supports fusion with post operators
540 */
void fuse_convolution_with_post_ops(Graph &g, const Edge *output_edge, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(output_edge->producer());
    ARM_COMPUTE_ERROR_ON(conv_node->output(0) == nullptr);

    // Post-op fusion is only implemented for the GEMM convolution method
    const ConvolutionMethod conv_algorithm = conv_node->convolution_method();
    if(conv_algorithm != ConvolutionMethod::GEMM)
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
        return;
    }

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        // If data type is FP32/FP16, data layout is NHWC, and filter size is 1x1, fuse convolution with post op, as Conv1x1 always leads to GEMM.
        const Edge *input_edge = conv_node->input_edge(1);
        if(input_edge != nullptr && input_edge->tensor() != nullptr)
        {
            const DataLayout  data_layout  = input_edge->tensor()->desc().layout;
            const DataType    data_type    = input_edge->tensor()->desc().data_type;
            const TensorShape tensor_shape = input_edge->tensor()->desc().shape;
            // NOTE(review): the message below also fires for layout/type/filter-size rejections,
            // not only for a non-GEMM method — consider a more specific message
            if((data_layout != DataLayout::NHWC) || (is_data_type_float(data_type) == false) || (tensor_shape.y() != 1) || (tensor_shape.z() != 1))
            {
                ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
                return;
            }
        }
        else
        {
            // Weights edge/tensor unavailable: the 1x1 requirement cannot be validated, bail out
            return;
        }

        // Get post op list
        int                eltwise_operand_id = 0;
        int                prev_op_dst_pos    = 0; // Previous operator dst's position in current operator
        std::list<INode *> post_op_node_list  = get_post_op_list(g, eltwise_operand_id, prev_op_dst_pos, conv_node_id, supported_fused_activations);

        if(post_op_node_list.size() == 0)
        {
            return;
        }
        else // Do convolution fusion with post op if there're one(elementwise), two or more operators
        {
            const Target assigned_target = conv_node->assigned_target();

            // Extract conv inputs
            const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
            const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
            const auto   conv_info       = conv_node->convolution_info();
            const auto   conv_method     = conv_node->convolution_method();
            const auto   num_groups      = conv_node->num_groups();
            FastMathHint fast_math_hint  = conv_node->fast_math_hint();

            // Create the fused node
            const NodeID fused_id = g.add_node<FusedConvolutionWithPostOpNode>(conv_info, num_groups, conv_method, fast_math_hint);
            ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing convolution node with ID : " << conv_node->id());

            // Add connections from the conv inputs to the fused node
            g.add_connection(conv_input_id, 0, fused_id, 0);
            g.add_connection(conv_weights_id, 0, fused_id, 1);
            // Bias is optional: forward it only when present
            if(conv_node->input_edge(2) != nullptr)
            {
                auto conv_bias_id = conv_node->input_edge(2)->producer_id();
                g.add_connection(conv_bias_id, 0, fused_id, 2);
            }
            // Adding the Element wise operand in case the post op is element wise operation
            auto it = std::find_if(post_op_node_list.begin(),
                                   post_op_node_list.end(),
                                   [&](const INode * nd)
            {
                return (nd->type() == graph::NodeType::EltwiseLayer);
            });

            if(it != post_op_node_list.end())
            {
                g.add_connection(eltwise_operand_id, 0, fused_id, 3);
            }
            g.remove_node(conv_node->id());

            // Update fused node outputs
            auto fused_node = g.node(fused_id);
            fused_node->set_assigned_target(assigned_target);

            // Fuse convolution with post op: attach the post-op descriptions and remove the nodes
            fuse_convolution_with_post_op(g, fused_node, post_op_node_list, prev_op_dst_pos);

            post_op_node_list.clear();
            ARM_COMPUTE_LOG_GRAPH_VERBOSE(std::endl);
        }
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
    }
}
639 
/** Fuse a chain of post operators (an optional elementwise addition and/or supported
 * activations) into a FusedConvolutionBatchNormalizationNode, creating a
 * FusedConvolutionBatchNormalizationWithPostOpsNode in its place.
 *
 * Fusion is only attempted when the convolution resolves to a GEMM-based implementation
 * and the conv+BN node has no output accessor attached.
 *
 * @param[in, out] g                           Graph to transform
 * @param[in]      output_edge                 Output edge of the FusedConvolutionBatchNormalization node
 * @param[in]      conv_node_id                ID of the conv+BN node used to collect its post op consumers
 * @param[in]      supported_fused_activations Activation functions allowed to be fused as post ops
 */
void fuse_convolution_batch_normalization_with_post_ops(Graph &g, const Edge *output_edge, unsigned int conv_node_id, const std::set<Activation> &supported_fused_activations)
{
    ARM_COMPUTE_ERROR_ON(output_edge == nullptr);

    auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<FusedConvolutionBatchNormalizationNode *>(output_edge->producer());
    ARM_COMPUTE_ERROR_ON(conv_node->output(0) == nullptr);
    // Post op fusion is only implemented for the GEMM convolution path
    const ConvolutionMethod conv_algorithm = conv_node->convolution_method();
    if(conv_algorithm != ConvolutionMethod::GEMM)
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
        return;
    }

    // Prevent fusion if fused node has an output accessor
    if(conv_node->output(0)->accessor() == nullptr)
    {
        // If data type is FP32/FP16, data layout is NHWC, and filter size is 1x1, fuse convolution with post op, as Conv1x1 always leads to GEMM.
        // input_edge(1) is the weights tensor of the convolution
        const Edge *input_edge = conv_node->input_edge(1);
        if(input_edge != nullptr && input_edge->tensor() != nullptr)
        {
            const DataLayout  data_layout  = input_edge->tensor()->desc().layout;
            const DataType    data_type    = input_edge->tensor()->desc().data_type;
            const TensorShape tensor_shape = input_edge->tensor()->desc().shape;
            // y/z of the weights shape are the spatial filter dims here; both must be 1 (1x1 kernel).
            // NOTE(review): the log message below says "non GEMM convolution" but this branch
            // actually rejects non-NHWC / non-float / non-1x1 cases — message reused from above.
            if((data_layout != DataLayout::NHWC) || (is_data_type_float(data_type) == false) || (tensor_shape.y() != 1) || (tensor_shape.z() != 1))
            {
                ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to non GEMM convolution\n");
                return;
            }
        }
        else
        {
            // No weights edge/tensor to inspect: bail out without fusing
            return;
        }

        // Get post op list
        // get_post_op_list walks the consumers of conv_node_id and returns the fusible chain;
        // eltwise_operand_id / prev_op_dst_pos are out-params filled by the helper.
        int                eltwise_operand_id = 0;
        int                prev_op_dst_pos    = 0; // Previous operator dst's position in current operator
        std::list<INode *> post_op_node_list  = get_post_op_list(g, eltwise_operand_id, prev_op_dst_pos, conv_node_id, supported_fused_activations);

        if(post_op_node_list.size() == 0)
        {
            return;
        }
        else // Do convolution fusion with post op if there're one(elementwise), two or more operators
        {
            const Target assigned_target = conv_node->assigned_target();

            // Extract conv inputs
            // Input slot layout of FusedConvolutionBatchNormalizationNode:
            // 0: src, 1: weights, 2: bias (optional), 3: BN mean, 4: BN var,
            // 5: BN beta (optional), 6: BN gamma (optional)
            const auto   conv_input_id   = conv_node->input_edge(0)->producer_id();
            const auto   conv_weights_id = conv_node->input_edge(1)->producer_id();
            const auto   bn_mean_id      = conv_node->input_edge(3)->producer_id();
            const auto   bn_var_id       = conv_node->input_edge(4)->producer_id();
            const auto   conv_info       = conv_node->convolution_info();
            const auto   conv_method     = conv_node->convolution_method();
            const auto   num_groups      = conv_node->num_groups();
            FastMathHint fast_math_hint  = conv_node->fast_math_hint();

            // Create the fused node

            const float  epsilon  = conv_node->epsilon();
            const NodeID fused_id = g.add_node<FusedConvolutionBatchNormalizationWithPostOpsNode>(epsilon, conv_info, num_groups, conv_method, fast_math_hint);

            ARM_COMPUTE_LOG_GRAPH_VERBOSE("Fusing FusedConvolutionBatchNormalization node with ID : " << conv_node->id());

            // Add connections from the conv inputs to the fused node
            g.add_connection(conv_input_id, 0, fused_id, 0);
            g.add_connection(conv_weights_id, 0, fused_id, 1);

            // Bias is optional; only reconnect it when present
            if(conv_node->input_edge(2) != nullptr)
            {
                auto conv_bias_id = conv_node->input_edge(2)->producer_id();
                g.add_connection(conv_bias_id, 0, fused_id, 2);
            }
            g.add_connection(bn_mean_id, 0, fused_id, 3);
            g.add_connection(bn_var_id, 0, fused_id, 4);

            // Move connections of old FusedConvolutionBatchNormalization to the fused node
            if(conv_node->input_edge(5) != nullptr)
            {
                const auto bn_beta_id = conv_node->input_edge(5)->producer_id();
                g.add_connection(bn_beta_id, 0, fused_id, 5);
            }

            if(conv_node->input_edge(6) != nullptr)
            {
                const auto bn_gamma_id = conv_node->input_edge(6)->producer_id();
                g.add_connection(bn_gamma_id, 0, fused_id, 6);
            }

            // Adding the Element wise operand in case the post op is element wise operation
            // (the second operand of the eltwise op is wired into slot 7 of the fused node)
            auto it = std::find_if(post_op_node_list.begin(),
                                   post_op_node_list.end(),
                                   [&](const INode * nd)
            {
                return (nd->type() == graph::NodeType::EltwiseLayer);
            });

            if(it != post_op_node_list.end())
            {
                g.add_connection(eltwise_operand_id, 0, fused_id, 7);
            }

            // Update fused node outputs
            auto fused_node = g.node(fused_id);
            fused_node->set_assigned_target(assigned_target);

            auto conv_node_name = conv_node->name();

            // collect the post ops names
            std::string post_ops_name = "";
            for(auto &post_op : post_op_node_list)
            {
                post_ops_name += post_op->name();
            }
            fused_node->set_common_node_parameters(NodeParams{ conv_node->name() + "+" + post_ops_name, assigned_target });

            // Fuse convolution with post op
            // (transfers output edges/descriptor to fused_node and drops the post op nodes)
            fuse_convolution_with_post_op(g, fused_node, post_op_node_list, prev_op_dst_pos);

            post_op_node_list.clear();
            // Remove the now fully replaced conv+BN node from the graph
            g.remove_node(conv_node->id());
            ARM_COMPUTE_LOG_GRAPH_VERBOSE(std::endl);
        }
    }
    else
    {
        ARM_COMPUTE_LOG_GRAPH_VERBOSE("Prevented fusion of convolution node with post ops due to the presence of an output accessor\n");
    }
}
769 
770 template <typename N1, typename F, typename... Args>
fuse_layer(Graph & g,std::function<bool (INode &)> const & prec,const F fuse_fcn,Args &&...optional_arguments)771 void fuse_layer(Graph &g, std::function<bool(INode &)> const &prec, const F fuse_fcn, Args &&... optional_arguments)
772 {
773     // Note that fused nodes may be added to the end of the node list.
774     // Instead of only looping over the original list of nodes, we loop over the current node list which could be growing.
775     // This is intentional as it probes the newly added fused nodes for further fusing opportunities.
776     for(unsigned int i = 0; i < g.nodes().size(); ++i)
777     {
778         auto node = g.node(i);
779         // Check if the node is of type N1 and not a branching node
780         if(node && node->type() == N1::node_type && node->output_edges().size() == 1)
781         {
782             const auto output_edge_id = *node->output_edges().begin();
783             const auto output_edge    = g.edge(output_edge_id);
784 
785             // Check if it's the correct target
786             if((output_edge != nullptr) && (output_edge->consumer() != nullptr) && prec(*output_edge->producer()))
787             {
788                 fuse_fcn(g, output_edge, i, optional_arguments...);
789             }
790         }
791     }
792 }
793 } // namespace detail
794 
name()795 const char *NodeFusionMutator::name()
796 {
797     return "NodeFusionMutator";
798 }
799 
type() const800 IGraphMutator::MutationType NodeFusionMutator::type() const
801 {
802     return IGraphMutator::MutationType::Backend;
803 }
804 
mutate(Graph & g)805 void NodeFusionMutator::mutate(Graph &g)
806 {
807     // Supported activations when fusing
808     const std::set<Activation> supported_fused_activations = { Activation::ABS, Activation::BOUNDED_RELU, Activation::ELU,
809                                                                Activation::HARD_SWISH, Activation::IDENTITY, Activation::LEAKY_RELU,
810                                                                Activation::LINEAR, Activation::LOGISTIC, Activation::LU_BOUNDED_RELU,
811                                                                Activation::RELU, Activation::SOFT_RELU, Activation::SQRT,
812                                                                Activation::SQUARE, Activation::TANH
813                                                              };
814 
815     // Preconditions
816     auto empty_prec = [](INode &)
817     {
818         return true;
819     };
820     auto cl_target_prec = [](INode & n)
821     {
822         return n.assigned_target() == Target::CL;
823     };
824     auto qs8_prec = [&g](INode & n)
825     {
826         ARM_COMPUTE_ERROR_ON(n.output(0) == nullptr);
827 
828         const auto output_edge_id = *n.output_edges().begin();
829         const auto output_edge    = g.edge(output_edge_id);
830         // To perform fusion the two nodes must have same output quantization information
831         const bool same_qinfo     = n.output(0)->desc().quant_info == output_edge->producer()->output(0)->desc().quant_info;
832         const bool output_qasymm8 = n.output(0)->desc().data_type == DataType::QASYMM8;
833 
834         return (output_qasymm8 && same_qinfo) || !output_qasymm8;
835     };
836 
837     // Fusion mutations
838 
839     detail::fuse_layer<PadLayerNode, ConvolutionLayerNode>(g, empty_prec, detail::fuse_pad_with_convolution<ConvolutionLayerNode>);
840     detail::fuse_layer<PadLayerNode, DepthwiseConvolutionLayerNode>(g, empty_prec, detail::fuse_pad_with_convolution<DepthwiseConvolutionLayerNode>);
841     // The fusion of PostOps to ConvolutionLayer:
842     // It must occur after the fusion of PadLayer into ConvolutionLayer
843     // It must occur before the fusion of normal ActivationLayer into ConvolutionLayer as it takes precedence
844     detail::fuse_layer<ConvolutionLayerNode>(g, cl_target_prec, detail::fuse_convolution_with_post_ops, supported_fused_activations);
845     detail::fuse_layer<BatchNormalizationLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<BatchNormalizationLayerNode>, supported_fused_activations);
846     detail::fuse_layer<ConvolutionLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<ConvolutionLayerNode>, supported_fused_activations);
847     detail::fuse_layer<DepthwiseConvolutionLayerNode, ActivationLayerNode>(g, qs8_prec, detail::fuse_node_with_activation<DepthwiseConvolutionLayerNode>, supported_fused_activations);
848     detail::fuse_layer<FullyConnectedLayerNode, ActivationLayerNode>(g, empty_prec, detail::fuse_node_with_activation<FullyConnectedLayerNode>, supported_fused_activations);
849     detail::fuse_layer<EltwiseLayerNode, ActivationLayerNode>(g, cl_target_prec, detail::fuse_node_with_activation<EltwiseLayerNode>, supported_fused_activations);
850     // The fusion of BatchNormalizationLayer must occur after the fusion of ActivationLayer. Because FusedConvolutionBatchNormalizationNode assumes the BatchNormalization is already fused with activation, if any
851     detail::fuse_layer<ConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_convolution_with_batch_normalization);
852     detail::fuse_layer<DepthwiseConvolutionLayerNode, BatchNormalizationLayerNode>(g, empty_prec, detail::fuse_depthwise_convolution_with_batch_normalization);
853     detail::fuse_layer<FusedConvolutionBatchNormalizationNode>(g, cl_target_prec, detail::fuse_convolution_batch_normalization_with_post_ops, supported_fused_activations);
854 }
855 } // namespace graph
856 } // namespace arm_compute
857