xref: /aosp_15_r20/external/ComputeLibrary/src/graph/detail/ExecutionHelpers.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/graph/detail/ExecutionHelpers.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
27*c217d954SCole Faust #include "arm_compute/graph/GraphContext.h"
28*c217d954SCole Faust #include "arm_compute/graph/GraphManager.h"
29*c217d954SCole Faust #include "arm_compute/graph/Tensor.h"
30*c217d954SCole Faust #include "arm_compute/graph/Utils.h"
31*c217d954SCole Faust #include "arm_compute/graph/backends/BackendRegistry.h"
32*c217d954SCole Faust 
33*c217d954SCole Faust namespace arm_compute
34*c217d954SCole Faust {
35*c217d954SCole Faust namespace graph
36*c217d954SCole Faust {
37*c217d954SCole Faust namespace detail
38*c217d954SCole Faust {
validate_all_nodes(Graph & g)39*c217d954SCole Faust void validate_all_nodes(Graph &g)
40*c217d954SCole Faust {
41*c217d954SCole Faust     auto &nodes = g.nodes();
42*c217d954SCole Faust 
43*c217d954SCole Faust     // Create tasks
44*c217d954SCole Faust     for(auto &node : nodes)
45*c217d954SCole Faust     {
46*c217d954SCole Faust         if(node != nullptr)
47*c217d954SCole Faust         {
48*c217d954SCole Faust             Target                    assigned_target = node->assigned_target();
49*c217d954SCole Faust             backends::IDeviceBackend &backend         = backends::BackendRegistry::get().get_backend(assigned_target);
50*c217d954SCole Faust             Status                    status          = backend.validate_node(*node);
51*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
52*c217d954SCole Faust         }
53*c217d954SCole Faust     }
54*c217d954SCole Faust }
55*c217d954SCole Faust 
configure_all_tensors(Graph & g)56*c217d954SCole Faust void configure_all_tensors(Graph &g)
57*c217d954SCole Faust {
58*c217d954SCole Faust     auto &tensors = g.tensors();
59*c217d954SCole Faust 
60*c217d954SCole Faust     for(auto &tensor : tensors)
61*c217d954SCole Faust     {
62*c217d954SCole Faust         if(tensor && tensor->handle() == nullptr)
63*c217d954SCole Faust         {
64*c217d954SCole Faust             Target                         target  = tensor->desc().target;
65*c217d954SCole Faust             backends::IDeviceBackend      &backend = backends::BackendRegistry::get().get_backend(target);
66*c217d954SCole Faust             std::unique_ptr<ITensorHandle> handle  = backend.create_tensor(*tensor);
67*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
68*c217d954SCole Faust             tensor->set_handle(std::move(handle));
69*c217d954SCole Faust         }
70*c217d954SCole Faust     }
71*c217d954SCole Faust }
72*c217d954SCole Faust 
allocate_all_input_tensors(INode & node)73*c217d954SCole Faust void allocate_all_input_tensors(INode &node)
74*c217d954SCole Faust {
75*c217d954SCole Faust     for(unsigned int i = 0; i < node.num_inputs(); ++i)
76*c217d954SCole Faust     {
77*c217d954SCole Faust         Tensor *tensor = node.input(i);
78*c217d954SCole Faust         if(tensor != nullptr && !tensor->bound_edges().empty())
79*c217d954SCole Faust         {
80*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
81*c217d954SCole Faust             tensor->handle()->allocate();
82*c217d954SCole Faust         }
83*c217d954SCole Faust     }
84*c217d954SCole Faust }
85*c217d954SCole Faust 
allocate_all_output_tensors(INode & node)86*c217d954SCole Faust void allocate_all_output_tensors(INode &node)
87*c217d954SCole Faust {
88*c217d954SCole Faust     for(unsigned int i = 0; i < node.num_outputs(); ++i)
89*c217d954SCole Faust     {
90*c217d954SCole Faust         Tensor *tensor = node.output(i);
91*c217d954SCole Faust         if(tensor != nullptr && !tensor->bound_edges().empty())
92*c217d954SCole Faust         {
93*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
94*c217d954SCole Faust             tensor->handle()->allocate();
95*c217d954SCole Faust         }
96*c217d954SCole Faust     }
97*c217d954SCole Faust }
98*c217d954SCole Faust 
allocate_const_tensors(Graph & g)99*c217d954SCole Faust void allocate_const_tensors(Graph &g)
100*c217d954SCole Faust {
101*c217d954SCole Faust     for(auto &node : g.nodes())
102*c217d954SCole Faust     {
103*c217d954SCole Faust         if(node != nullptr)
104*c217d954SCole Faust         {
105*c217d954SCole Faust             switch(node->type())
106*c217d954SCole Faust             {
107*c217d954SCole Faust                 case NodeType::Const:
108*c217d954SCole Faust                 case NodeType::Input:
109*c217d954SCole Faust                     allocate_all_output_tensors(*node);
110*c217d954SCole Faust                     break;
111*c217d954SCole Faust                 case NodeType::Output:
112*c217d954SCole Faust                     allocate_all_input_tensors(*node);
113*c217d954SCole Faust                 default:
114*c217d954SCole Faust                     break;
115*c217d954SCole Faust             }
116*c217d954SCole Faust         }
117*c217d954SCole Faust     }
118*c217d954SCole Faust }
119*c217d954SCole Faust 
allocate_all_tensors(Graph & g)120*c217d954SCole Faust void allocate_all_tensors(Graph &g)
121*c217d954SCole Faust {
122*c217d954SCole Faust     auto &tensors = g.tensors();
123*c217d954SCole Faust 
124*c217d954SCole Faust     for(auto &tensor : tensors)
125*c217d954SCole Faust     {
126*c217d954SCole Faust         if(tensor && !tensor->bound_edges().empty() && tensor->handle() != nullptr && tensor->handle()->tensor().info()->is_resizable() && tensor->handle()->tensor().is_used())
127*c217d954SCole Faust         {
128*c217d954SCole Faust             tensor->handle()->allocate();
129*c217d954SCole Faust         }
130*c217d954SCole Faust     }
131*c217d954SCole Faust }
132*c217d954SCole Faust 
configure_all_nodes(Graph & g,GraphContext & ctx,const std::vector<NodeID> & node_order)133*c217d954SCole Faust ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx, const std::vector<NodeID> &node_order)
134*c217d954SCole Faust {
135*c217d954SCole Faust     ExecutionWorkload workload;
136*c217d954SCole Faust     workload.graph = &g;
137*c217d954SCole Faust     workload.ctx   = &ctx;
138*c217d954SCole Faust 
139*c217d954SCole Faust     // Reserve memory for tasks
140*c217d954SCole Faust     workload.tasks.reserve(node_order.size());
141*c217d954SCole Faust 
142*c217d954SCole Faust     // Create tasks
143*c217d954SCole Faust     for(auto &node_id : node_order)
144*c217d954SCole Faust     {
145*c217d954SCole Faust         auto node = g.node(node_id);
146*c217d954SCole Faust         if(node != nullptr)
147*c217d954SCole Faust         {
148*c217d954SCole Faust             Target                     assigned_target = node->assigned_target();
149*c217d954SCole Faust             backends::IDeviceBackend &backend         = backends::BackendRegistry::get().get_backend(assigned_target);
150*c217d954SCole Faust             std::unique_ptr<IFunction> func            = backend.configure_node(*node, ctx);
151*c217d954SCole Faust             if(func != nullptr || is_utility_node(node))
152*c217d954SCole Faust             {
153*c217d954SCole Faust                 workload.tasks.emplace_back(ExecutionTask(std::move(func), node));
154*c217d954SCole Faust             }
155*c217d954SCole Faust         }
156*c217d954SCole Faust     }
157*c217d954SCole Faust 
158*c217d954SCole Faust     // Add inputs and outputs
159*c217d954SCole Faust     for(auto &node : g.nodes())
160*c217d954SCole Faust     {
161*c217d954SCole Faust         if(node != nullptr && node->type() == NodeType::Input)
162*c217d954SCole Faust         {
163*c217d954SCole Faust             workload.inputs.push_back(node->output(0));
164*c217d954SCole Faust         }
165*c217d954SCole Faust 
166*c217d954SCole Faust         if(node != nullptr && node->type() == NodeType::Output)
167*c217d954SCole Faust         {
168*c217d954SCole Faust             workload.outputs.push_back(node->input(0));
169*c217d954SCole Faust             continue;
170*c217d954SCole Faust         }
171*c217d954SCole Faust     }
172*c217d954SCole Faust 
173*c217d954SCole Faust     return workload;
174*c217d954SCole Faust }
175*c217d954SCole Faust 
release_unused_tensors(Graph & g)176*c217d954SCole Faust void release_unused_tensors(Graph &g)
177*c217d954SCole Faust {
178*c217d954SCole Faust     for(auto &tensor : g.tensors())
179*c217d954SCole Faust     {
180*c217d954SCole Faust         if(tensor != nullptr && tensor->handle() != nullptr)
181*c217d954SCole Faust         {
182*c217d954SCole Faust             tensor->handle()->release_if_unused();
183*c217d954SCole Faust         }
184*c217d954SCole Faust     }
185*c217d954SCole Faust }
186*c217d954SCole Faust 
call_tensor_accessor(Tensor * tensor)187*c217d954SCole Faust void call_tensor_accessor(Tensor *tensor)
188*c217d954SCole Faust {
189*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(!tensor);
190*c217d954SCole Faust     tensor->call_accessor();
191*c217d954SCole Faust }
192*c217d954SCole Faust 
call_all_const_node_accessors(Graph & g)193*c217d954SCole Faust void call_all_const_node_accessors(Graph &g)
194*c217d954SCole Faust {
195*c217d954SCole Faust     auto &nodes = g.nodes();
196*c217d954SCole Faust 
197*c217d954SCole Faust     for(auto &node : nodes)
198*c217d954SCole Faust     {
199*c217d954SCole Faust         if(node != nullptr && node->type() == NodeType::Const && node->num_outputs())
200*c217d954SCole Faust         {
201*c217d954SCole Faust             if(!node->output(0)->bound_edges().empty())
202*c217d954SCole Faust             {
203*c217d954SCole Faust                 call_tensor_accessor(node->output(0));
204*c217d954SCole Faust             }
205*c217d954SCole Faust         }
206*c217d954SCole Faust     }
207*c217d954SCole Faust }
208*c217d954SCole Faust 
call_all_input_node_accessors(ExecutionWorkload & workload)209*c217d954SCole Faust bool call_all_input_node_accessors(ExecutionWorkload &workload)
210*c217d954SCole Faust {
211*c217d954SCole Faust     bool is_valid = true;
212*c217d954SCole Faust     std::for_each(std::begin(workload.inputs), std::end(workload.inputs), [&](Tensor * input_tensor)
213*c217d954SCole Faust     {
214*c217d954SCole Faust         bool valid_input = (input_tensor != nullptr) && input_tensor->call_accessor();
215*c217d954SCole Faust         is_valid         = is_valid && valid_input;
216*c217d954SCole Faust     });
217*c217d954SCole Faust     return is_valid;
218*c217d954SCole Faust }
219*c217d954SCole Faust 
prepare_all_tasks(ExecutionWorkload & workload)220*c217d954SCole Faust void prepare_all_tasks(ExecutionWorkload &workload)
221*c217d954SCole Faust {
222*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(workload.graph == nullptr);
223*c217d954SCole Faust     for(auto &task : workload.tasks)
224*c217d954SCole Faust     {
225*c217d954SCole Faust         task.prepare();
226*c217d954SCole Faust         release_unused_tensors(*workload.graph);
227*c217d954SCole Faust     }
228*c217d954SCole Faust }
229*c217d954SCole Faust 
call_all_tasks(ExecutionWorkload & workload)230*c217d954SCole Faust void call_all_tasks(ExecutionWorkload &workload)
231*c217d954SCole Faust {
232*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(workload.ctx == nullptr);
233*c217d954SCole Faust 
234*c217d954SCole Faust     // Acquire memory for the transition buffers
235*c217d954SCole Faust     for(auto &mm_ctx : workload.ctx->memory_managers())
236*c217d954SCole Faust     {
237*c217d954SCole Faust         if(mm_ctx.second.cross_group != nullptr)
238*c217d954SCole Faust         {
239*c217d954SCole Faust             mm_ctx.second.cross_group->acquire();
240*c217d954SCole Faust         }
241*c217d954SCole Faust     }
242*c217d954SCole Faust 
243*c217d954SCole Faust     // Execute tasks
244*c217d954SCole Faust     for(auto &task : workload.tasks)
245*c217d954SCole Faust     {
246*c217d954SCole Faust         task();
247*c217d954SCole Faust     }
248*c217d954SCole Faust 
249*c217d954SCole Faust     // Release memory for the transition buffers
250*c217d954SCole Faust     for(auto &mm_ctx : workload.ctx->memory_managers())
251*c217d954SCole Faust     {
252*c217d954SCole Faust         if(mm_ctx.second.cross_group != nullptr)
253*c217d954SCole Faust         {
254*c217d954SCole Faust             mm_ctx.second.cross_group->release();
255*c217d954SCole Faust         }
256*c217d954SCole Faust     }
257*c217d954SCole Faust }
258*c217d954SCole Faust 
call_all_output_node_accessors(ExecutionWorkload & workload)259*c217d954SCole Faust bool call_all_output_node_accessors(ExecutionWorkload &workload)
260*c217d954SCole Faust {
261*c217d954SCole Faust     bool is_valid = true;
262*c217d954SCole Faust     std::for_each(std::begin(workload.outputs), std::end(workload.outputs), [&](Tensor * output_tensor)
263*c217d954SCole Faust     {
264*c217d954SCole Faust         bool valid_output = (output_tensor != nullptr) && output_tensor->call_accessor();
265*c217d954SCole Faust         is_valid          = is_valid && valid_output;
266*c217d954SCole Faust     });
267*c217d954SCole Faust 
268*c217d954SCole Faust     sync_backends();
269*c217d954SCole Faust 
270*c217d954SCole Faust     return is_valid;
271*c217d954SCole Faust }
272*c217d954SCole Faust } // namespace detail
273*c217d954SCole Faust } // namespace graph
274*c217d954SCole Faust } // namespace arm_compute
275