xref: /aosp_15_r20/external/ComputeLibrary/src/graph/detail/CrossLayerMemoryManagerHelpers.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2020 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/graph/detail/CrossLayerMemoryManagerHelpers.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
27*c217d954SCole Faust #include "arm_compute/graph/GraphContext.h"
28*c217d954SCole Faust #include "arm_compute/graph/GraphManager.h"
29*c217d954SCole Faust #include "arm_compute/graph/INode.h"
30*c217d954SCole Faust #include "arm_compute/graph/Tensor.h"
31*c217d954SCole Faust #include "arm_compute/graph/Types.h"
32*c217d954SCole Faust #include "arm_compute/graph/Utils.h"
33*c217d954SCole Faust #include "arm_compute/graph/backends/BackendRegistry.h"
34*c217d954SCole Faust 
35*c217d954SCole Faust #include "arm_compute/core/ITensor.h"
36*c217d954SCole Faust #include "support/Cast.h"
37*c217d954SCole Faust 
38*c217d954SCole Faust #include <algorithm>
39*c217d954SCole Faust #include <map>
40*c217d954SCole Faust 
41*c217d954SCole Faust namespace arm_compute
42*c217d954SCole Faust {
43*c217d954SCole Faust namespace graph
44*c217d954SCole Faust {
45*c217d954SCole Faust namespace detail
46*c217d954SCole Faust {
47*c217d954SCole Faust namespace
48*c217d954SCole Faust {
49*c217d954SCole Faust using HandleCountPair     = std::pair<ITensorHandle *, unsigned int>;
50*c217d954SCole Faust using HandleCounter       = std::map<HandleCountPair::first_type, HandleCountPair::second_type>;
51*c217d954SCole Faust using TargetHandleCounter = std::map<Target, HandleCounter>;
52*c217d954SCole Faust 
53*c217d954SCole Faust /** Holds managed IO tensor handles if a task */
54*c217d954SCole Faust struct TaskHandles
55*c217d954SCole Faust {
56*c217d954SCole Faust     std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> input_handles  = {}; /**< Input handles to a task */
57*c217d954SCole Faust     std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> output_handles = {}; /**< Output handles of a task */
58*c217d954SCole Faust };
59*c217d954SCole Faust 
60*c217d954SCole Faust /** Returns memory group depending on handle backend type
61*c217d954SCole Faust  *
62*c217d954SCole Faust  * @param[in] ctx    Graph context
63*c217d954SCole Faust  * @param[in] handle Tensor handle
64*c217d954SCole Faust  *
65*c217d954SCole Faust  * @return Memory groupb
66*c217d954SCole Faust  */
get_memory_group_from_handle(GraphContext & ctx,ITensorHandle * handle)67*c217d954SCole Faust IMemoryGroup *get_memory_group_from_handle(GraphContext &ctx, ITensorHandle *handle)
68*c217d954SCole Faust {
69*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(handle == nullptr);
70*c217d954SCole Faust     return ctx.memory_management_ctx(handle->target())->cross_group.get();
71*c217d954SCole Faust }
72*c217d954SCole Faust 
73*c217d954SCole Faust /** Get handles of const tensors of graph
74*c217d954SCole Faust  *
75*c217d954SCole Faust  * @param[in] g Graph
76*c217d954SCole Faust  *
77*c217d954SCole Faust  * @return Handles of const tensors of graph
78*c217d954SCole Faust  */
get_const_handles(const Graph & g)79*c217d954SCole Faust std::set<ITensorHandle *> get_const_handles(const Graph &g)
80*c217d954SCole Faust {
81*c217d954SCole Faust     std::set<NodeType> const_node_types = { NodeType::Input, NodeType::Output, NodeType::Const };
82*c217d954SCole Faust 
83*c217d954SCole Faust     std::set<ITensorHandle *> const_tensors;
84*c217d954SCole Faust 
85*c217d954SCole Faust     auto &nodes = g.nodes();
86*c217d954SCole Faust     for(auto &node : nodes)
87*c217d954SCole Faust     {
88*c217d954SCole Faust         // If its a const node:
89*c217d954SCole Faust         if(node != nullptr && const_node_types.find(node->type()) != std::end(const_node_types))
90*c217d954SCole Faust         {
91*c217d954SCole Faust             // TODO (geopin01) : Create IO iterator wrappers
92*c217d954SCole Faust             // Add all its inputs / outputs to the list of constant handles
93*c217d954SCole Faust             for(unsigned int i = 0; i < node->num_inputs(); ++i)
94*c217d954SCole Faust             {
95*c217d954SCole Faust                 if(node->input(i) != nullptr)
96*c217d954SCole Faust                 {
97*c217d954SCole Faust                     const_tensors.insert(node->input(i)->handle()->parent_handle());
98*c217d954SCole Faust                 }
99*c217d954SCole Faust             }
100*c217d954SCole Faust             for(unsigned int i = 0; i < node->num_outputs(); ++i)
101*c217d954SCole Faust             {
102*c217d954SCole Faust                 if(node->output(i) != nullptr)
103*c217d954SCole Faust                 {
104*c217d954SCole Faust                     const_tensors.insert(node->output(i)->handle()->parent_handle());
105*c217d954SCole Faust                 }
106*c217d954SCole Faust             }
107*c217d954SCole Faust         }
108*c217d954SCole Faust     }
109*c217d954SCole Faust 
110*c217d954SCole Faust     return const_tensors;
111*c217d954SCole Faust }
112*c217d954SCole Faust 
113*c217d954SCole Faust /** Builds a list of all the transition handles (Handles that are used to link two nodes)
114*c217d954SCole Faust  *
115*c217d954SCole Faust  * @param[in] ctx           Graph context
116*c217d954SCole Faust  * @param[in] task          Workload task
117*c217d954SCole Faust  * @param[in] const_tensors Constant tensors
118*c217d954SCole Faust  *
119*c217d954SCole Faust  * @return List of transition handles
120*c217d954SCole Faust  */
get_transition_handles(GraphContext & ctx,ExecutionTask & task,const std::set<ITensorHandle * > & const_tensors)121*c217d954SCole Faust TaskHandles get_transition_handles(GraphContext                    &ctx,
122*c217d954SCole Faust                                    ExecutionTask                   &task,
123*c217d954SCole Faust                                    const std::set<ITensorHandle *> &const_tensors)
124*c217d954SCole Faust {
125*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(task.node == nullptr || (task.task == nullptr && !is_utility_node(task.node)));
126*c217d954SCole Faust     INode &node = *task.node;
127*c217d954SCole Faust 
128*c217d954SCole Faust     TaskHandles transition_handles;
129*c217d954SCole Faust 
130*c217d954SCole Faust     // Add input handles
131*c217d954SCole Faust     for(unsigned int i = 0; i < node.input_edges().size(); ++i)
132*c217d954SCole Faust     {
133*c217d954SCole Faust         Edge *input_edge = node.input_edge(i);
134*c217d954SCole Faust         // If this input is the output of another node
135*c217d954SCole Faust         if(input_edge != nullptr && input_edge->tensor() != nullptr && const_tensors.find(input_edge->tensor()->handle()->parent_handle()) == std::end(const_tensors))
136*c217d954SCole Faust         {
137*c217d954SCole Faust             // Then add it to the list of transition buffers
138*c217d954SCole Faust             ITensorHandle *tensor_handle = input_edge->tensor()->handle()->parent_handle();
139*c217d954SCole Faust             IMemoryGroup *mm_group      = get_memory_group_from_handle(ctx, tensor_handle);
140*c217d954SCole Faust             transition_handles.input_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
141*c217d954SCole Faust         }
142*c217d954SCole Faust     }
143*c217d954SCole Faust 
144*c217d954SCole Faust     // Add output handles
145*c217d954SCole Faust     for(unsigned int i = 0; i < node.num_outputs(); ++i)
146*c217d954SCole Faust     {
147*c217d954SCole Faust         Tensor *output_tensor = node.output(i);
148*c217d954SCole Faust         // If this output is used as an input for another node
149*c217d954SCole Faust         if(output_tensor != nullptr && const_tensors.find(output_tensor->handle()->parent_handle()) == std::end(const_tensors))
150*c217d954SCole Faust         {
151*c217d954SCole Faust             ITensorHandle *tensor_handle = output_tensor->handle()->parent_handle();
152*c217d954SCole Faust             IMemoryGroup *mm_group      = get_memory_group_from_handle(ctx, tensor_handle);
153*c217d954SCole Faust             transition_handles.output_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
154*c217d954SCole Faust         }
155*c217d954SCole Faust     }
156*c217d954SCole Faust 
157*c217d954SCole Faust     return transition_handles;
158*c217d954SCole Faust }
159*c217d954SCole Faust 
160*c217d954SCole Faust /** Counts handles refcount for each input handle of each target
161*c217d954SCole Faust  *
162*c217d954SCole Faust  * @param[in]     task           Execution task containing the managed handles
163*c217d954SCole Faust  * @param[in,out] handle_counter Data structure that keeps the handles reference count
164*c217d954SCole Faust  */
count_input_handles_per_target(const TaskHandles & task_handles,TargetHandleCounter & handle_counter)165*c217d954SCole Faust void count_input_handles_per_target(const TaskHandles &task_handles, TargetHandleCounter &handle_counter)
166*c217d954SCole Faust {
167*c217d954SCole Faust     for(const auto &handle : task_handles.input_handles)
168*c217d954SCole Faust     {
169*c217d954SCole Faust         ITensorHandle *key            = handle.first;
170*c217d954SCole Faust         HandleCounter &target_counter = handle_counter[key->target()];
171*c217d954SCole Faust         if(target_counter.find(key) == std::end(target_counter))
172*c217d954SCole Faust         {
173*c217d954SCole Faust             target_counter.emplace(std::make_pair(key, 1));
174*c217d954SCole Faust         }
175*c217d954SCole Faust         else
176*c217d954SCole Faust         {
177*c217d954SCole Faust             ++target_counter[key];
178*c217d954SCole Faust         }
179*c217d954SCole Faust     }
180*c217d954SCole Faust }
181*c217d954SCole Faust 
182*c217d954SCole Faust /** Calculates the lifetime of each tensor handle
183*c217d954SCole Faust  *
184*c217d954SCole Faust  * @param[in, out] tasks_handles Tensor handles for each task
185*c217d954SCole Faust  * @param[in]      hc            Data structure that keeps the handles reference count
186*c217d954SCole Faust  */
configure_handle_lifetime(std::vector<TaskHandles> & tasks_handles,const HandleCounter & hc)187*c217d954SCole Faust void configure_handle_lifetime(std::vector<TaskHandles> &tasks_handles, const HandleCounter &hc)
188*c217d954SCole Faust {
189*c217d954SCole Faust     // Identify max number of tensors in flight
190*c217d954SCole Faust     HandleCounter tensors_in_flight;
191*c217d954SCole Faust 
192*c217d954SCole Faust     // Acquires the given handles and sets them as in flight if they aren't already
193*c217d954SCole Faust     auto acquire = [&](std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> &handles)
194*c217d954SCole Faust     {
195*c217d954SCole Faust         for(auto &handle : handles)
196*c217d954SCole Faust         {
197*c217d954SCole Faust             ITensorHandle *parent_handle = handle.first;
198*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON(parent_handle == nullptr);
199*c217d954SCole Faust             // If the tensor is not already in flight:
200*c217d954SCole Faust             if(tensors_in_flight.find(parent_handle) == std::end(tensors_in_flight))
201*c217d954SCole Faust             {
202*c217d954SCole Faust                 ARM_COMPUTE_ERROR_ON(hc.find(parent_handle) == std::end(hc));
203*c217d954SCole Faust                 // Then add it to the list of in flight tensors
204*c217d954SCole Faust                 tensors_in_flight.insert(std::make_pair(parent_handle, hc.at(parent_handle)));
205*c217d954SCole Faust                 // Start of allocation's lifetime
206*c217d954SCole Faust                 parent_handle->manage(handle.second);
207*c217d954SCole Faust             }
208*c217d954SCole Faust         }
209*c217d954SCole Faust     };
210*c217d954SCole Faust 
211*c217d954SCole Faust     for(auto &task_handle : tasks_handles)
212*c217d954SCole Faust     {
213*c217d954SCole Faust         // Marking all the input and output tensors of the task as in flight
214*c217d954SCole Faust         acquire(task_handle.input_handles);
215*c217d954SCole Faust         acquire(task_handle.output_handles);
216*c217d954SCole Faust 
217*c217d954SCole Faust         // Releasing the input tensors
218*c217d954SCole Faust         for(auto &input_handle : task_handle.input_handles)
219*c217d954SCole Faust         {
220*c217d954SCole Faust             ITensorHandle *ihandle = input_handle.first;
221*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON(ihandle == nullptr);
222*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON(tensors_in_flight.find(ihandle) == std::end(tensors_in_flight));
223*c217d954SCole Faust             --tensors_in_flight[ihandle];
224*c217d954SCole Faust             if(tensors_in_flight[ihandle] <= 0)
225*c217d954SCole Faust             {
226*c217d954SCole Faust                 // Remove tensor for tensors in flight
227*c217d954SCole Faust                 tensors_in_flight.erase(ihandle);
228*c217d954SCole Faust                 // End of allocation's lifetime
229*c217d954SCole Faust                 ihandle->allocate();
230*c217d954SCole Faust             }
231*c217d954SCole Faust         }
232*c217d954SCole Faust     }
233*c217d954SCole Faust }
234*c217d954SCole Faust } // namespace
235*c217d954SCole Faust 
configure_transition_manager(Graph & g,GraphContext & ctx,ExecutionWorkload & workload)236*c217d954SCole Faust void configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload)
237*c217d954SCole Faust {
238*c217d954SCole Faust     // Get const tensors (un-managed)
239*c217d954SCole Faust     std::set<ITensorHandle *> const_tensors = get_const_handles(g);
240*c217d954SCole Faust 
241*c217d954SCole Faust     std::vector<TaskHandles> tasks_handles;
242*c217d954SCole Faust     TargetHandleCounter      target_handle_count;
243*c217d954SCole Faust 
244*c217d954SCole Faust     // Count handles
245*c217d954SCole Faust     for(auto &task : workload.tasks)
246*c217d954SCole Faust     {
247*c217d954SCole Faust         // Populates IO handles
248*c217d954SCole Faust         tasks_handles.push_back(get_transition_handles(ctx, task, const_tensors));
249*c217d954SCole Faust 
250*c217d954SCole Faust         // Count handles
251*c217d954SCole Faust         count_input_handles_per_target(tasks_handles.back(), target_handle_count);
252*c217d954SCole Faust     }
253*c217d954SCole Faust 
254*c217d954SCole Faust     // Setup memory managers
255*c217d954SCole Faust     for(auto &hc : target_handle_count)
256*c217d954SCole Faust     {
257*c217d954SCole Faust         MemoryManagerContext *mm_ctx = ctx.memory_management_ctx(hc.first);
258*c217d954SCole Faust         if(mm_ctx != nullptr)
259*c217d954SCole Faust         {
260*c217d954SCole Faust             if(mm_ctx->cross_mm != nullptr && mm_ctx->cross_group != nullptr)
261*c217d954SCole Faust             {
262*c217d954SCole Faust                 // Manage and allocate tensors
263*c217d954SCole Faust                 configure_handle_lifetime(tasks_handles, hc.second);
264*c217d954SCole Faust             }
265*c217d954SCole Faust         }
266*c217d954SCole Faust     }
267*c217d954SCole Faust }
268*c217d954SCole Faust } // namespace detail
269*c217d954SCole Faust } // namespace graph
270*c217d954SCole Faust } // namespace arm_compute
271