xref: /aosp_15_r20/external/ComputeLibrary/src/graph/detail/CrossLayerMemoryManagerHelpers.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph/detail/CrossLayerMemoryManagerHelpers.h"
25 
26 #include "arm_compute/graph/Graph.h"
27 #include "arm_compute/graph/GraphContext.h"
28 #include "arm_compute/graph/GraphManager.h"
29 #include "arm_compute/graph/INode.h"
30 #include "arm_compute/graph/Tensor.h"
31 #include "arm_compute/graph/Types.h"
32 #include "arm_compute/graph/Utils.h"
33 #include "arm_compute/graph/backends/BackendRegistry.h"
34 
35 #include "arm_compute/core/ITensor.h"
36 #include "support/Cast.h"
37 
38 #include <algorithm>
39 #include <map>
40 
41 namespace arm_compute
42 {
43 namespace graph
44 {
45 namespace detail
46 {
47 namespace
48 {
49 using HandleCountPair     = std::pair<ITensorHandle *, unsigned int>;
50 using HandleCounter       = std::map<HandleCountPair::first_type, HandleCountPair::second_type>;
51 using TargetHandleCounter = std::map<Target, HandleCounter>;
52 
53 /** Holds managed IO tensor handles if a task */
54 struct TaskHandles
55 {
56     std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> input_handles  = {}; /**< Input handles to a task */
57     std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> output_handles = {}; /**< Output handles of a task */
58 };
59 
60 /** Returns memory group depending on handle backend type
61  *
62  * @param[in] ctx    Graph context
63  * @param[in] handle Tensor handle
64  *
65  * @return Memory groupb
66  */
get_memory_group_from_handle(GraphContext & ctx,ITensorHandle * handle)67 IMemoryGroup *get_memory_group_from_handle(GraphContext &ctx, ITensorHandle *handle)
68 {
69     ARM_COMPUTE_ERROR_ON(handle == nullptr);
70     return ctx.memory_management_ctx(handle->target())->cross_group.get();
71 }
72 
73 /** Get handles of const tensors of graph
74  *
75  * @param[in] g Graph
76  *
77  * @return Handles of const tensors of graph
78  */
get_const_handles(const Graph & g)79 std::set<ITensorHandle *> get_const_handles(const Graph &g)
80 {
81     std::set<NodeType> const_node_types = { NodeType::Input, NodeType::Output, NodeType::Const };
82 
83     std::set<ITensorHandle *> const_tensors;
84 
85     auto &nodes = g.nodes();
86     for(auto &node : nodes)
87     {
88         // If its a const node:
89         if(node != nullptr && const_node_types.find(node->type()) != std::end(const_node_types))
90         {
91             // TODO (geopin01) : Create IO iterator wrappers
92             // Add all its inputs / outputs to the list of constant handles
93             for(unsigned int i = 0; i < node->num_inputs(); ++i)
94             {
95                 if(node->input(i) != nullptr)
96                 {
97                     const_tensors.insert(node->input(i)->handle()->parent_handle());
98                 }
99             }
100             for(unsigned int i = 0; i < node->num_outputs(); ++i)
101             {
102                 if(node->output(i) != nullptr)
103                 {
104                     const_tensors.insert(node->output(i)->handle()->parent_handle());
105                 }
106             }
107         }
108     }
109 
110     return const_tensors;
111 }
112 
113 /** Builds a list of all the transition handles (Handles that are used to link two nodes)
114  *
115  * @param[in] ctx           Graph context
116  * @param[in] task          Workload task
117  * @param[in] const_tensors Constant tensors
118  *
119  * @return List of transition handles
120  */
get_transition_handles(GraphContext & ctx,ExecutionTask & task,const std::set<ITensorHandle * > & const_tensors)121 TaskHandles get_transition_handles(GraphContext                    &ctx,
122                                    ExecutionTask                   &task,
123                                    const std::set<ITensorHandle *> &const_tensors)
124 {
125     ARM_COMPUTE_ERROR_ON(task.node == nullptr || (task.task == nullptr && !is_utility_node(task.node)));
126     INode &node = *task.node;
127 
128     TaskHandles transition_handles;
129 
130     // Add input handles
131     for(unsigned int i = 0; i < node.input_edges().size(); ++i)
132     {
133         Edge *input_edge = node.input_edge(i);
134         // If this input is the output of another node
135         if(input_edge != nullptr && input_edge->tensor() != nullptr && const_tensors.find(input_edge->tensor()->handle()->parent_handle()) == std::end(const_tensors))
136         {
137             // Then add it to the list of transition buffers
138             ITensorHandle *tensor_handle = input_edge->tensor()->handle()->parent_handle();
139             IMemoryGroup *mm_group      = get_memory_group_from_handle(ctx, tensor_handle);
140             transition_handles.input_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
141         }
142     }
143 
144     // Add output handles
145     for(unsigned int i = 0; i < node.num_outputs(); ++i)
146     {
147         Tensor *output_tensor = node.output(i);
148         // If this output is used as an input for another node
149         if(output_tensor != nullptr && const_tensors.find(output_tensor->handle()->parent_handle()) == std::end(const_tensors))
150         {
151             ITensorHandle *tensor_handle = output_tensor->handle()->parent_handle();
152             IMemoryGroup *mm_group      = get_memory_group_from_handle(ctx, tensor_handle);
153             transition_handles.output_handles.emplace_back(std::make_pair(tensor_handle, mm_group));
154         }
155     }
156 
157     return transition_handles;
158 }
159 
160 /** Counts handles refcount for each input handle of each target
161  *
162  * @param[in]     task           Execution task containing the managed handles
163  * @param[in,out] handle_counter Data structure that keeps the handles reference count
164  */
count_input_handles_per_target(const TaskHandles & task_handles,TargetHandleCounter & handle_counter)165 void count_input_handles_per_target(const TaskHandles &task_handles, TargetHandleCounter &handle_counter)
166 {
167     for(const auto &handle : task_handles.input_handles)
168     {
169         ITensorHandle *key            = handle.first;
170         HandleCounter &target_counter = handle_counter[key->target()];
171         if(target_counter.find(key) == std::end(target_counter))
172         {
173             target_counter.emplace(std::make_pair(key, 1));
174         }
175         else
176         {
177             ++target_counter[key];
178         }
179     }
180 }
181 
182 /** Calculates the lifetime of each tensor handle
183  *
184  * @param[in, out] tasks_handles Tensor handles for each task
185  * @param[in]      hc            Data structure that keeps the handles reference count
186  */
configure_handle_lifetime(std::vector<TaskHandles> & tasks_handles,const HandleCounter & hc)187 void configure_handle_lifetime(std::vector<TaskHandles> &tasks_handles, const HandleCounter &hc)
188 {
189     // Identify max number of tensors in flight
190     HandleCounter tensors_in_flight;
191 
192     // Acquires the given handles and sets them as in flight if they aren't already
193     auto acquire = [&](std::vector<std::pair<ITensorHandle *, IMemoryGroup *>> &handles)
194     {
195         for(auto &handle : handles)
196         {
197             ITensorHandle *parent_handle = handle.first;
198             ARM_COMPUTE_ERROR_ON(parent_handle == nullptr);
199             // If the tensor is not already in flight:
200             if(tensors_in_flight.find(parent_handle) == std::end(tensors_in_flight))
201             {
202                 ARM_COMPUTE_ERROR_ON(hc.find(parent_handle) == std::end(hc));
203                 // Then add it to the list of in flight tensors
204                 tensors_in_flight.insert(std::make_pair(parent_handle, hc.at(parent_handle)));
205                 // Start of allocation's lifetime
206                 parent_handle->manage(handle.second);
207             }
208         }
209     };
210 
211     for(auto &task_handle : tasks_handles)
212     {
213         // Marking all the input and output tensors of the task as in flight
214         acquire(task_handle.input_handles);
215         acquire(task_handle.output_handles);
216 
217         // Releasing the input tensors
218         for(auto &input_handle : task_handle.input_handles)
219         {
220             ITensorHandle *ihandle = input_handle.first;
221             ARM_COMPUTE_ERROR_ON(ihandle == nullptr);
222             ARM_COMPUTE_ERROR_ON(tensors_in_flight.find(ihandle) == std::end(tensors_in_flight));
223             --tensors_in_flight[ihandle];
224             if(tensors_in_flight[ihandle] <= 0)
225             {
226                 // Remove tensor for tensors in flight
227                 tensors_in_flight.erase(ihandle);
228                 // End of allocation's lifetime
229                 ihandle->allocate();
230             }
231         }
232     }
233 }
234 } // namespace
235 
configure_transition_manager(Graph & g,GraphContext & ctx,ExecutionWorkload & workload)236 void configure_transition_manager(Graph &g, GraphContext &ctx, ExecutionWorkload &workload)
237 {
238     // Get const tensors (un-managed)
239     std::set<ITensorHandle *> const_tensors = get_const_handles(g);
240 
241     std::vector<TaskHandles> tasks_handles;
242     TargetHandleCounter      target_handle_count;
243 
244     // Count handles
245     for(auto &task : workload.tasks)
246     {
247         // Populates IO handles
248         tasks_handles.push_back(get_transition_handles(ctx, task, const_tensors));
249 
250         // Count handles
251         count_input_handles_per_target(tasks_handles.back(), target_handle_count);
252     }
253 
254     // Setup memory managers
255     for(auto &hc : target_handle_count)
256     {
257         MemoryManagerContext *mm_ctx = ctx.memory_management_ctx(hc.first);
258         if(mm_ctx != nullptr)
259         {
260             if(mm_ctx->cross_mm != nullptr && mm_ctx->cross_group != nullptr)
261             {
262                 // Manage and allocate tensors
263                 configure_handle_lifetime(tasks_handles, hc.second);
264             }
265         }
266     }
267 }
268 } // namespace detail
269 } // namespace graph
270 } // namespace arm_compute
271