// xref: /aosp_15_r20/external/ComputeLibrary/src/graph/mutators/GroupedConvolutionMutator.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2018-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/mutators/GroupedConvolutionMutator.h"

#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/GraphBuilder.h"
#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/nodes/Nodes.h"

#include "support/Cast.h"

#include "support/StringSupport.h"

#include <set>

namespace arm_compute
{
namespace graph
{
namespace
{
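/** Replaces a grouped convolution with an equivalent sub-graph of regular convolutions.
 *
 * The input tensor is split into @p num_groups slices along its channel dimension and the weights
 * into @p num_groups slices along their batch (output filters) dimension; the bias, if present, is
 * split accordingly. One ConvolutionLayerNode is created per group and the per-group outputs are
 * concatenated back along the channel dimension.
 *
 * @param[in] g              Graph to add the sub-graph to
 * @param[in] params         Common node parameters of the original convolution node
 * @param[in] input          Input node (and output index) of the original convolution
 * @param[in] weights        Weights node ID
 * @param[in] bias           Bias node ID, or EmptyNodeID if no bias is used
 * @param[in] conv_info      Padding and stride information
 * @param[in] method         Convolution method hint forwarded to every group
 * @param[in] fused_act      Fused activation set on every group
 * @param[in] fast_math_hint Fast-math hint forwarded to every group
 * @param[in] num_groups     Number of groups
 *
 * @return Node ID of the concatenation node that produces the final output
 */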
NodeID create_grouped_convolution(Graph &g, const NodeParams &params, NodeIdxPair input, NodeID weights, NodeID bias,
                                  PadStrideInfo conv_info, ConvolutionMethod method, ActivationLayerInfo fused_act, FastMathHint fast_math_hint, unsigned int num_groups)
{
    bool has_bias = (bias != EmptyNodeID);

    // Split input
    const TensorDescriptor input_tensor_desc = get_tensor_descriptor(g, g.node(input.node_id)->outputs()[0]);
    const unsigned int     input_idx         = get_dimension_idx(input_tensor_desc.layout, DataLayoutDimension::CHANNEL);
    NodeID                 input_split       = GraphBuilder::add_split_node(g, params, input, num_groups, input_idx);

    // Split weights
    const TensorDescriptor weights_tensor_desc = get_tensor_descriptor(g, g.node(weights)->outputs()[0]);
    const unsigned int     batch_idx           = get_dimension_idx(weights_tensor_desc.layout, DataLayoutDimension::BATCHES);
    NodeID                 weights_split       = GraphBuilder::add_split_node(g, params, { weights, 0 }, num_groups, batch_idx);

    // Split bias
    NodeID bias_split = EmptyNodeID;
    if(has_bias)
    {
        // Split bias
        bias_split = GraphBuilder::add_split_node(g, params, { bias, 0 }, num_groups, 0);
    }

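    // Create one regular convolution node per group, each consuming the i-th slice of the
    // input, weights and (optional) bias splits; collect the outputs for concatenation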
    std::vector<NodeIdxPair> convolution_outputs;
    for(unsigned int i = 0; i < num_groups; ++i)
    {
        NodeParams group_params = params;
        NodeID     conv_nid     = g.add_node<ConvolutionLayerNode>(conv_info, 1, method, fast_math_hint);
        g.add_connection(input_split, i, conv_nid, 0);
        g.add_connection(weights_split, i, conv_nid, 1);
        if(has_bias)
        {
            g.add_connection(bias_split, i, conv_nid, 2);
        }

        // Add group name
        if(!group_params.name.empty())
        {
            group_params.name.append("_g" + arm_compute::support::cpp11::to_string(i));
        }

        // Set node parameters
        INode *node = g.node(conv_nid);
        ARM_COMPUTE_ERROR_ON(node == nullptr);
        node->set_common_node_parameters(group_params);

        // Down-cast node
        auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);
        conv_node->set_fused_activation(fused_act);

        convolution_outputs.push_back({ conv_nid, 0 });
    }

    // Depth concatenate output
    return GraphBuilder::add_concatenate_node(g, params, convolution_outputs, DataLayoutDimension::CHANNEL);
}
} // namespace

const char *GroupedConvolutionMutator::name()
{
    return "GroupedConvolutionMutator";
}

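// Declared as a backend mutation so that it runs once nodes have an assigned target; mutate()
// relies on this when it asks the assigned backend whether the grouped convolution is supported.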
IGraphMutator::MutationType GroupedConvolutionMutator::type() const
{
    return IGraphMutator::MutationType::Backend;
}

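// Finds ConvolutionLayerNodes with num_groups != 1 that the assigned backend fails to validate and
// replaces each of them with the split / per-group convolution / concatenation sub-graph built by
// create_grouped_convolution(). Consumers of the original node are reconnected to the new output,
// the output accessor is transferred, and the newly created tensors and nodes are configured and
// assigned to the original node's target.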
void GroupedConvolutionMutator::mutate(Graph &g)
{
    // Early exit if no Convolution layers exist in graph
    if(g.nodes(NodeType::ConvolutionLayer).empty())
    {
        return;
    }

    // Total nodes
    size_t total_nodes = g.nodes().size();

    // Iterate over convolution nodes
    for(unsigned int i = 0; i < total_nodes; ++i)
    {
        INode *node = g.node(i);
        if(node != nullptr && node->type() == NodeType::ConvolutionLayer && arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node)->num_groups() != 1)
        {
            // Validate node
            backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target());
            Status                    status  = backend.validate_node(*node);

            // If grouped convolution is not supported
            if(!bool(status))
            {
                // Down-cast node
                auto *conv_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(node);

                // Get internal convolution info
                // TODO (geopin01) : Create a descriptor or a clone interface
                const PadStrideInfo       conv_info       = conv_node->convolution_info();
                const ConvolutionMethod   conv_method     = conv_node->convolution_method();
                const ActivationLayerInfo fused_act_info  = conv_node->fused_activation();
                const FastMathHint        fast_math_hint  = conv_node->fast_math_hint();
                const unsigned int        num_groups      = conv_node->num_groups();
                const NodeParams          params          = conv_node->common_node_params();
                const Target              assigned_target = conv_node->assigned_target();

                // Extract node ids
                ARM_COMPUTE_ERROR_ON(conv_node->input_edge(0) == nullptr || conv_node->input_edge(1) == nullptr);
                const NodeID input_id   = conv_node->input_edge(0)->producer()->id();
                const NodeID weights_id = conv_node->input_edge(1)->producer()->id();
                const NodeID bias_id    = (conv_node->input_edge(2) != nullptr) ? conv_node->input_edge(2)->producer()->id() : EmptyNodeID;

                // Get driving nodes
                std::vector<NodeIdxPair> driving_nodes = get_driving_nodes(*node);

                // Extract the output accessor of the convolution node, if any
                auto node_accessor = conv_node->output(0)->extract_accessor();

                // Record the current tensor and node counts; anything appended after this point is new
                TensorID latest_tid = g.tensors().size();
                NodeID   latest_nid = g.nodes().size();

                // Create grouped convolution node
                NodeID grouped_conv_id = create_grouped_convolution(g, params, { input_id, 0 }, weights_id, bias_id,
                                                                    conv_info, conv_method, fused_act_info, fast_math_hint, num_groups);

                // Remove convolution node
                g.remove_node(node->id());

                // Reconnect the driving nodes to the grouped convolution output
                for(auto &driving_node : driving_nodes)
                {
                    g.add_connection(grouped_conv_id, 0, driving_node.node_id, driving_node.index);
                }

                // Transfer the original output accessor to the grouped convolution output
                g.node(grouped_conv_id)->output(0)->set_accessor(std::move(node_accessor));

                // Configure new tensors and nodes
                std::for_each(g.tensors().begin() + latest_tid, g.tensors().end(), [](std::unique_ptr<Tensor> &t)
                {
                    configure_tensor(t.get());
                });
                std::for_each(g.nodes().begin() + latest_nid, g.nodes().end(), [&assigned_target](std::unique_ptr<INode> &n)
                {
                    if(n != nullptr)
                    {
                        n->set_assigned_target(assigned_target);
                    }
                });
            }
        }
    }
}
} // namespace graph
} // namespace arm_compute
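// Usage note: graph mutators such as this one are normally collected into a pass manager rather
// than run by hand. A minimal sketch, assuming the PassManager interface declared in
// arm_compute/graph/PassManager.h:
//
//     arm_compute::graph::PassManager pm;
//     pm.append(std::make_unique<arm_compute::graph::GroupedConvolutionMutator>());
//     // ... later, typically during graph finalization, each appended mutator's
//     // mutate() is invoked on the graph.
//
// The exact append() signature may vary between releases; check PassManager.h for the version in use.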