/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24*c217d954SCole Faust #include "arm_compute/graph/detail/ExecutionHelpers.h"
25*c217d954SCole Faust
26*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
27*c217d954SCole Faust #include "arm_compute/graph/GraphContext.h"
28*c217d954SCole Faust #include "arm_compute/graph/GraphManager.h"
29*c217d954SCole Faust #include "arm_compute/graph/Tensor.h"
30*c217d954SCole Faust #include "arm_compute/graph/Utils.h"
31*c217d954SCole Faust #include "arm_compute/graph/backends/BackendRegistry.h"
32*c217d954SCole Faust
33*c217d954SCole Faust namespace arm_compute
34*c217d954SCole Faust {
35*c217d954SCole Faust namespace graph
36*c217d954SCole Faust {
37*c217d954SCole Faust namespace detail
38*c217d954SCole Faust {
validate_all_nodes(Graph & g)39*c217d954SCole Faust void validate_all_nodes(Graph &g)
40*c217d954SCole Faust {
41*c217d954SCole Faust auto &nodes = g.nodes();
42*c217d954SCole Faust
43*c217d954SCole Faust // Create tasks
44*c217d954SCole Faust for(auto &node : nodes)
45*c217d954SCole Faust {
46*c217d954SCole Faust if(node != nullptr)
47*c217d954SCole Faust {
48*c217d954SCole Faust Target assigned_target = node->assigned_target();
49*c217d954SCole Faust backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
50*c217d954SCole Faust Status status = backend.validate_node(*node);
51*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
52*c217d954SCole Faust }
53*c217d954SCole Faust }
54*c217d954SCole Faust }
55*c217d954SCole Faust
configure_all_tensors(Graph & g)56*c217d954SCole Faust void configure_all_tensors(Graph &g)
57*c217d954SCole Faust {
58*c217d954SCole Faust auto &tensors = g.tensors();
59*c217d954SCole Faust
60*c217d954SCole Faust for(auto &tensor : tensors)
61*c217d954SCole Faust {
62*c217d954SCole Faust if(tensor && tensor->handle() == nullptr)
63*c217d954SCole Faust {
64*c217d954SCole Faust Target target = tensor->desc().target;
65*c217d954SCole Faust backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(target);
66*c217d954SCole Faust std::unique_ptr<ITensorHandle> handle = backend.create_tensor(*tensor);
67*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
68*c217d954SCole Faust tensor->set_handle(std::move(handle));
69*c217d954SCole Faust }
70*c217d954SCole Faust }
71*c217d954SCole Faust }
72*c217d954SCole Faust
allocate_all_input_tensors(INode & node)73*c217d954SCole Faust void allocate_all_input_tensors(INode &node)
74*c217d954SCole Faust {
75*c217d954SCole Faust for(unsigned int i = 0; i < node.num_inputs(); ++i)
76*c217d954SCole Faust {
77*c217d954SCole Faust Tensor *tensor = node.input(i);
78*c217d954SCole Faust if(tensor != nullptr && !tensor->bound_edges().empty())
79*c217d954SCole Faust {
80*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
81*c217d954SCole Faust tensor->handle()->allocate();
82*c217d954SCole Faust }
83*c217d954SCole Faust }
84*c217d954SCole Faust }
85*c217d954SCole Faust
allocate_all_output_tensors(INode & node)86*c217d954SCole Faust void allocate_all_output_tensors(INode &node)
87*c217d954SCole Faust {
88*c217d954SCole Faust for(unsigned int i = 0; i < node.num_outputs(); ++i)
89*c217d954SCole Faust {
90*c217d954SCole Faust Tensor *tensor = node.output(i);
91*c217d954SCole Faust if(tensor != nullptr && !tensor->bound_edges().empty())
92*c217d954SCole Faust {
93*c217d954SCole Faust ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
94*c217d954SCole Faust tensor->handle()->allocate();
95*c217d954SCole Faust }
96*c217d954SCole Faust }
97*c217d954SCole Faust }
98*c217d954SCole Faust
allocate_const_tensors(Graph & g)99*c217d954SCole Faust void allocate_const_tensors(Graph &g)
100*c217d954SCole Faust {
101*c217d954SCole Faust for(auto &node : g.nodes())
102*c217d954SCole Faust {
103*c217d954SCole Faust if(node != nullptr)
104*c217d954SCole Faust {
105*c217d954SCole Faust switch(node->type())
106*c217d954SCole Faust {
107*c217d954SCole Faust case NodeType::Const:
108*c217d954SCole Faust case NodeType::Input:
109*c217d954SCole Faust allocate_all_output_tensors(*node);
110*c217d954SCole Faust break;
111*c217d954SCole Faust case NodeType::Output:
112*c217d954SCole Faust allocate_all_input_tensors(*node);
113*c217d954SCole Faust default:
114*c217d954SCole Faust break;
115*c217d954SCole Faust }
116*c217d954SCole Faust }
117*c217d954SCole Faust }
118*c217d954SCole Faust }
119*c217d954SCole Faust
allocate_all_tensors(Graph & g)120*c217d954SCole Faust void allocate_all_tensors(Graph &g)
121*c217d954SCole Faust {
122*c217d954SCole Faust auto &tensors = g.tensors();
123*c217d954SCole Faust
124*c217d954SCole Faust for(auto &tensor : tensors)
125*c217d954SCole Faust {
126*c217d954SCole Faust if(tensor && !tensor->bound_edges().empty() && tensor->handle() != nullptr && tensor->handle()->tensor().info()->is_resizable() && tensor->handle()->tensor().is_used())
127*c217d954SCole Faust {
128*c217d954SCole Faust tensor->handle()->allocate();
129*c217d954SCole Faust }
130*c217d954SCole Faust }
131*c217d954SCole Faust }
132*c217d954SCole Faust
configure_all_nodes(Graph & g,GraphContext & ctx,const std::vector<NodeID> & node_order)133*c217d954SCole Faust ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx, const std::vector<NodeID> &node_order)
134*c217d954SCole Faust {
135*c217d954SCole Faust ExecutionWorkload workload;
136*c217d954SCole Faust workload.graph = &g;
137*c217d954SCole Faust workload.ctx = &ctx;
138*c217d954SCole Faust
139*c217d954SCole Faust // Reserve memory for tasks
140*c217d954SCole Faust workload.tasks.reserve(node_order.size());
141*c217d954SCole Faust
142*c217d954SCole Faust // Create tasks
143*c217d954SCole Faust for(auto &node_id : node_order)
144*c217d954SCole Faust {
145*c217d954SCole Faust auto node = g.node(node_id);
146*c217d954SCole Faust if(node != nullptr)
147*c217d954SCole Faust {
148*c217d954SCole Faust Target assigned_target = node->assigned_target();
149*c217d954SCole Faust backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(assigned_target);
150*c217d954SCole Faust std::unique_ptr<IFunction> func = backend.configure_node(*node, ctx);
151*c217d954SCole Faust if(func != nullptr || is_utility_node(node))
152*c217d954SCole Faust {
153*c217d954SCole Faust workload.tasks.emplace_back(ExecutionTask(std::move(func), node));
154*c217d954SCole Faust }
155*c217d954SCole Faust }
156*c217d954SCole Faust }
157*c217d954SCole Faust
158*c217d954SCole Faust // Add inputs and outputs
159*c217d954SCole Faust for(auto &node : g.nodes())
160*c217d954SCole Faust {
161*c217d954SCole Faust if(node != nullptr && node->type() == NodeType::Input)
162*c217d954SCole Faust {
163*c217d954SCole Faust workload.inputs.push_back(node->output(0));
164*c217d954SCole Faust }
165*c217d954SCole Faust
166*c217d954SCole Faust if(node != nullptr && node->type() == NodeType::Output)
167*c217d954SCole Faust {
168*c217d954SCole Faust workload.outputs.push_back(node->input(0));
169*c217d954SCole Faust continue;
170*c217d954SCole Faust }
171*c217d954SCole Faust }
172*c217d954SCole Faust
173*c217d954SCole Faust return workload;
174*c217d954SCole Faust }
175*c217d954SCole Faust
release_unused_tensors(Graph & g)176*c217d954SCole Faust void release_unused_tensors(Graph &g)
177*c217d954SCole Faust {
178*c217d954SCole Faust for(auto &tensor : g.tensors())
179*c217d954SCole Faust {
180*c217d954SCole Faust if(tensor != nullptr && tensor->handle() != nullptr)
181*c217d954SCole Faust {
182*c217d954SCole Faust tensor->handle()->release_if_unused();
183*c217d954SCole Faust }
184*c217d954SCole Faust }
185*c217d954SCole Faust }
186*c217d954SCole Faust
call_tensor_accessor(Tensor * tensor)187*c217d954SCole Faust void call_tensor_accessor(Tensor *tensor)
188*c217d954SCole Faust {
189*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(!tensor);
190*c217d954SCole Faust tensor->call_accessor();
191*c217d954SCole Faust }
192*c217d954SCole Faust
call_all_const_node_accessors(Graph & g)193*c217d954SCole Faust void call_all_const_node_accessors(Graph &g)
194*c217d954SCole Faust {
195*c217d954SCole Faust auto &nodes = g.nodes();
196*c217d954SCole Faust
197*c217d954SCole Faust for(auto &node : nodes)
198*c217d954SCole Faust {
199*c217d954SCole Faust if(node != nullptr && node->type() == NodeType::Const && node->num_outputs())
200*c217d954SCole Faust {
201*c217d954SCole Faust if(!node->output(0)->bound_edges().empty())
202*c217d954SCole Faust {
203*c217d954SCole Faust call_tensor_accessor(node->output(0));
204*c217d954SCole Faust }
205*c217d954SCole Faust }
206*c217d954SCole Faust }
207*c217d954SCole Faust }
208*c217d954SCole Faust
call_all_input_node_accessors(ExecutionWorkload & workload)209*c217d954SCole Faust bool call_all_input_node_accessors(ExecutionWorkload &workload)
210*c217d954SCole Faust {
211*c217d954SCole Faust bool is_valid = true;
212*c217d954SCole Faust std::for_each(std::begin(workload.inputs), std::end(workload.inputs), [&](Tensor * input_tensor)
213*c217d954SCole Faust {
214*c217d954SCole Faust bool valid_input = (input_tensor != nullptr) && input_tensor->call_accessor();
215*c217d954SCole Faust is_valid = is_valid && valid_input;
216*c217d954SCole Faust });
217*c217d954SCole Faust return is_valid;
218*c217d954SCole Faust }
219*c217d954SCole Faust
prepare_all_tasks(ExecutionWorkload & workload)220*c217d954SCole Faust void prepare_all_tasks(ExecutionWorkload &workload)
221*c217d954SCole Faust {
222*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(workload.graph == nullptr);
223*c217d954SCole Faust for(auto &task : workload.tasks)
224*c217d954SCole Faust {
225*c217d954SCole Faust task.prepare();
226*c217d954SCole Faust release_unused_tensors(*workload.graph);
227*c217d954SCole Faust }
228*c217d954SCole Faust }
229*c217d954SCole Faust
call_all_tasks(ExecutionWorkload & workload)230*c217d954SCole Faust void call_all_tasks(ExecutionWorkload &workload)
231*c217d954SCole Faust {
232*c217d954SCole Faust ARM_COMPUTE_ERROR_ON(workload.ctx == nullptr);
233*c217d954SCole Faust
234*c217d954SCole Faust // Acquire memory for the transition buffers
235*c217d954SCole Faust for(auto &mm_ctx : workload.ctx->memory_managers())
236*c217d954SCole Faust {
237*c217d954SCole Faust if(mm_ctx.second.cross_group != nullptr)
238*c217d954SCole Faust {
239*c217d954SCole Faust mm_ctx.second.cross_group->acquire();
240*c217d954SCole Faust }
241*c217d954SCole Faust }
242*c217d954SCole Faust
243*c217d954SCole Faust // Execute tasks
244*c217d954SCole Faust for(auto &task : workload.tasks)
245*c217d954SCole Faust {
246*c217d954SCole Faust task();
247*c217d954SCole Faust }
248*c217d954SCole Faust
249*c217d954SCole Faust // Release memory for the transition buffers
250*c217d954SCole Faust for(auto &mm_ctx : workload.ctx->memory_managers())
251*c217d954SCole Faust {
252*c217d954SCole Faust if(mm_ctx.second.cross_group != nullptr)
253*c217d954SCole Faust {
254*c217d954SCole Faust mm_ctx.second.cross_group->release();
255*c217d954SCole Faust }
256*c217d954SCole Faust }
257*c217d954SCole Faust }
258*c217d954SCole Faust
call_all_output_node_accessors(ExecutionWorkload & workload)259*c217d954SCole Faust bool call_all_output_node_accessors(ExecutionWorkload &workload)
260*c217d954SCole Faust {
261*c217d954SCole Faust bool is_valid = true;
262*c217d954SCole Faust std::for_each(std::begin(workload.outputs), std::end(workload.outputs), [&](Tensor * output_tensor)
263*c217d954SCole Faust {
264*c217d954SCole Faust bool valid_output = (output_tensor != nullptr) && output_tensor->call_accessor();
265*c217d954SCole Faust is_valid = is_valid && valid_output;
266*c217d954SCole Faust });
267*c217d954SCole Faust
268*c217d954SCole Faust sync_backends();
269*c217d954SCole Faust
270*c217d954SCole Faust return is_valid;
271*c217d954SCole Faust }
272*c217d954SCole Faust } // namespace detail
273*c217d954SCole Faust } // namespace graph
274*c217d954SCole Faust } // namespace arm_compute
275