/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/Utils.h"

#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
#include "arm_compute/graph/mutators/GraphMutators.h"

namespace arm_compute
{
namespace graph
{
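// Returns true if a backend is registered for the given target and that backend reports itself as supported.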
bool is_target_supported(Target target)
{
    return backends::BackendRegistry::get().contains(target) && backends::BackendRegistry::get().find_backend(target)->is_backend_supported();
}

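// Picks a default target, preferring NEON over CL, and aborts if no supported backend is registered.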
Target get_default_target()
{
    if(is_target_supported(Target::NEON))
    {
        return Target::NEON;
    }
    if(is_target_supported(Target::CL))
    {
        return Target::CL;
    }
    ARM_COMPUTE_ERROR("No backend exists!");
}

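// Overrides the assigned target of every node and tensor in the graph.
//
// Example (sketch, assuming `g` is a Graph already populated by the frontend):
//   force_target_to_graph(g, Target::CL);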
void force_target_to_graph(Graph &g, Target target)
{
    auto &nodes = g.nodes();
    for(auto &node : nodes)
    {
        if(node)
        {
            node->set_assigned_target(target);
        }
    }

    auto &tensors = g.tensors();
    for(auto &tensor : tensors)
    {
        if(tensor)
        {
            tensor->desc().target = target;
        }
    }
}

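// Builds the default mutator pipeline: IR-level passes first (optional synthetic
// data type conversion, node fusion, grouped convolution splitting, in-place
// optimisation), followed by backend-level passes (sub-tensor creation and
// execution method selection). The resulting pass manager is later run over the
// graph, e.g. during graph finalization.
//
// Example (sketch, assuming the caller owns a GraphConfig `cfg`):
//   PassManager pm = create_default_pass_manager(get_default_target(), cfg);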
PassManager create_default_pass_manager(Target target, const GraphConfig &cfg)
{
    ARM_COMPUTE_UNUSED(target);
    PassManager pm;

    // Passes that mutate graph IR
    if(cfg.use_synthetic_type)
    {
        switch(cfg.synthetic_type)
        {
            case DataType::QASYMM8:
            case DataType::QASYMM8_SIGNED:
            {
                pm.append(std::make_unique<SyntheticDataTypeMutator>(cfg.synthetic_type));
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Unsupported DataType for SyntheticDataTypeMutator");
                break;
            }
        }
    }
    pm.append(std::make_unique<NodeFusionMutator>());
    pm.append(std::make_unique<GroupedConvolutionMutator>());
    pm.append(std::make_unique<InPlaceOperationMutator>());

    // Passes that mutate backend information
    pm.append(std::make_unique<DepthConcatSubTensorMutator>());
    pm.append(std::make_unique<SplitLayerSubTensorMutator>());
    pm.append(std::make_unique<NodeExecutionMethodMutator>());

    return pm;
}

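// Releases the backend-specific resources held in the context for every supported backend.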
void release_default_graph_context(GraphContext &ctx)
{
    for(const auto &backend : backends::BackendRegistry::get().backends())
    {
        if(backend.second->is_backend_supported())
        {
            backend.second->release_backend_context(ctx);
        }
    }
}

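// Calls sync() on every registered backend that exposes an allocator, so that
// pending backend work is finished before the caller proceeds.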
void sync_backends()
{
    for(const auto &backend : backends::BackendRegistry::get().backends())
    {
        if(backend.second->backend_allocator())
        {
            backend.second->sync();
        }
    }
}

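// Sets up the backend context for the requested target, if that backend is registered and supported.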
void setup_requested_backend_context(GraphContext &ctx, Target target)
{
    if(backends::BackendRegistry::get().contains(target))
    {
        const auto &backend = backends::BackendRegistry::get().find_backend(target);
        if(backend->is_backend_supported())
        {
            backend->setup_backend_context(ctx);
        }
    }
}

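// Returns the size of the requested dimension, looked up through the descriptor's layout.
// For example, for an NHWC descriptor `desc` the CHANNEL dimension maps to shape index 0,
// so get_dimension_size(desc, DataLayoutDimension::CHANNEL) returns desc.shape[0].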
size_t get_dimension_size(const TensorDescriptor &descriptor, const DataLayoutDimension data_layout_dimension)
{
    ARM_COMPUTE_ERROR_ON_MSG(descriptor.layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!");
    return descriptor.shape[get_dimension_idx(descriptor.layout, data_layout_dimension)];
}

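// Maps a logical dimension (WIDTH, HEIGHT, CHANNEL, BATCHES) to its shape index for the given layout.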
size_t get_dimension_idx(DataLayout data_layout, const DataLayoutDimension data_layout_dimension)
{
    ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!");

    /* Return the shape index of the requested dimension for the given layout:
     *  NCHW: W=0, H=1, C=2, N=3
     *  NHWC: C=0, W=1, H=2, N=3
     */
    switch(data_layout_dimension)
    {
        case DataLayoutDimension::CHANNEL:
            return (data_layout == DataLayout::NCHW) ? 2 : 0;
        case DataLayoutDimension::HEIGHT:
            return (data_layout == DataLayout::NCHW) ? 1 : 2;
        case DataLayoutDimension::WIDTH:
            return (data_layout == DataLayout::NCHW) ? 0 : 1;
        case DataLayoutDimension::BATCHES:
            return 3;
        default:
            break;
    }
    ARM_COMPUTE_ERROR("Data layout index not supported!");
}

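// Collects the consumers reached through the node's output edges, as (node id, input index) pairs.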
std::vector<NodeIdxPair> get_driving_nodes(const INode &node)
{
    std::vector<NodeIdxPair> driving_nodes;

    const Graph *g = node.graph();
    ARM_COMPUTE_ERROR_ON(g == nullptr);

    for(auto &output_edge_id : node.output_edges())
    {
        auto output_edge = g->edge(output_edge_id);
        if(output_edge != nullptr)
        {
            ARM_COMPUTE_ERROR_ON(output_edge->consumer() == nullptr);
            driving_nodes.push_back({ output_edge->consumer_id(), output_edge->consumer_idx() });
        }
    }

    return driving_nodes;
}

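// Collects the producers feeding the node's input edges, as (node id, output index) pairs.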
std::vector<NodeIdxPair> get_driver_nodes(const INode &node)
{
    std::vector<NodeIdxPair> driver_nodes;

    const Graph *g = node.graph();
    ARM_COMPUTE_ERROR_ON(g == nullptr);

    for(auto &input_edge_id : node.input_edges())
    {
        auto input_edge = g->edge(input_edge_id);
        if(input_edge != nullptr)
        {
            ARM_COMPUTE_ERROR_ON(input_edge->producer() == nullptr);
            driver_nodes.push_back({ input_edge->producer_id(), input_edge->producer_idx() });
        }
    }

    return driver_nodes;
}

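// Creates a backend tensor handle for the tensor (using the backend of its assigned target)
// and attaches it, unless the tensor already has a handle.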
void configure_tensor(Tensor *tensor)
{
    if(tensor != nullptr && tensor->handle() == nullptr)
    {
        Target                         target  = tensor->desc().target;
        backends::IDeviceBackend      &backend = backends::BackendRegistry::get().get_backend(target);
        std::unique_ptr<ITensorHandle> handle  = backend.create_tensor(*tensor);
        ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
        tensor->set_handle(std::move(handle));
    }
}

} // namespace graph
} // namespace arm_compute