xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/graph/Graph.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_GRAPH_GRAPH_H
25 #define ARM_COMPUTE_GRAPH_GRAPH_H
26 
27 #include "arm_compute/graph/Edge.h"
28 #include "arm_compute/graph/INode.h"
29 #include "arm_compute/graph/Tensor.h"
30 #include "arm_compute/graph/Types.h"
31 
32 #include "support/Mutex.h"
33 #include "support/ToolchainSupport.h"
34 
35 #include <map>
36 #include <memory>
37 #include <string>
38 #include <utility>
39 #include <vector>
40 
41 #ifndef BARE_METAL
42 #include <thread>
43 #endif /* BARE_METAL */
44 
45 namespace arm_compute
46 {
47 namespace graph
48 {
49 /** Graph class
50  *
51  * Represents a multiple source - multiple sink directed graph
52  */
53 class Graph final
54 {
55 public:
56     Graph() = default;
57     /** Constructor
58      *
59      * @param[in] id   Graph identification number. Can be used to differentiate between graphs. Default value 0
60      * @param[in] name Graph name. Default value empty string
61      */
62     Graph(GraphID id, std::string name);
63     /** Prevent instances of this class from being copied (As this class contains pointers) */
64     Graph(const Graph &) = delete;
65     /** Prevent instances of this class from being copy assigned (As this class contains pointers) */
66     Graph &operator=(const Graph &) = delete;
67     /** Prevent instances of this class from being moved (As this class contains non movable objects) */
68     Graph(Graph &&) = delete;
69     /** Prevent instances of this class from being moved (As this class contains non movable objects) */
70     Graph &operator=(Graph &&) = delete;
71     /** Adds a node to the graph
72      *
73      * @note Models a single output node
74      *
75      * @tparam NT Node operation
76      * @tparam Ts Arguments to operation
77      *
78      * @param[in] args Node arguments
79      *
80      * @return ID of the node
81      */
82     template <typename NT, typename... Ts>
83     NodeID add_node(Ts &&... args);
84     /** Remove the node with the given ID
85      *
86      * @param[in] nid ID of the node to remove
87      *
88      * @return True if the removal took place else false
89      */
90     bool remove_node(NodeID nid);
91     /** Adds a connection between two nodes
92      *
93      * @param[in] source     ID of the source node
94      * @param[in] source_idx Output index of the source node
95      * @param[in] sink       ID of the sink node
96      * @param[in] sink_idx   Input index of the sink node
97      *
98      * @return ID of this connection
99      */
100     EdgeID add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx);
101     /** Removes an edge (connection)
102      *
103      * @param[in] eid Connection to remove
104      *
105      * @return True if the removal took place else false
106      */
107     bool remove_connection(EdgeID eid);
108     /** Returns graph name
109      *
110      * @return Graph name
111      */
112     std::string name() const;
113     /** Returns graph id
114      *
115      * @return Graph id
116      */
117     GraphID id() const;
118     /** Returns graph input nodes
119      *
120      * @param[in] type Type of nodes to return
121      *
122      * @return vector containing the graph node of given type
123      */
124     const std::vector<NodeID> &nodes(NodeType type);
125     /** Returns nodes of graph
126      *
127      * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
128      *
129      * @return Nodes of graph
130      */
131     std::vector<std::unique_ptr<INode>> &nodes();
132     /** Returns nodes of graph
133      *
134      * @warning Nodes can be nullptr if they have been removed during the mutation steps of the graph
135      *
136      * @return Nodes of graph
137      */
138     const std::vector<std::unique_ptr<INode>> &nodes() const;
139     /** Returns edges of graph
140      *
141      * @warning Edges can be nullptr if they have been removed during the mutation steps of the graph
142      *
143      * @return Edges of graph
144      */
145     const std::vector<std::unique_ptr<Edge>> &edges() const;
146     /** Returns tensors of graph
147      *
148      * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
149      *
150      * @return Tensors of graph
151      */
152     std::vector<std::unique_ptr<Tensor>> &tensors();
153     /** Returns tensors of graph
154      *
155      * @warning Tensor can be nullptr if they have been removed during the mutation steps of the graph
156      *
157      * @return Tensors of graph
158      */
159     const std::vector<std::unique_ptr<Tensor>> &tensors() const;
160     /** Get node object given its id
161      *
162      * @warning Can be nullptr if node was removed during the mutation steps of the graph
163      *
164      * @param[in] id Node ID
165      *
166      * @return The actual node object
167      */
168     const INode *node(NodeID id) const;
169     /** Get node object given its id
170      *
171      * @warning Can be nullptr if node was removed during the mutation steps of the graph
172      *
173      * @param[in] id Node ID
174      *
175      * @return The actual node object
176      */
177     INode *node(NodeID id);
178     /** Get edge object given its id
179      *
180      * @warning Can be nullptr if node was removed during the mutation steps of the graph
181      *
182      * @param[in] id Edge ID
183      *
184      * @return The actual edge object
185      */
186     const Edge *edge(EdgeID id) const;
187     /** Get edge object given its id
188      *
189      * @warning Can be nullptr if node was removed during the mutation steps of the graph
190      *
191      * @param[in] id Edge ID
192      *
193      * @return The actual edge object
194      */
195     Edge *edge(EdgeID id);
196     /** Get tensor object given its id
197      *
198      * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
199      *
200      * @param[in] id Tensor ID
201      *
202      * @return The actual tensor object
203      */
204     const Tensor *tensor(TensorID id) const;
205     /** Get tensor object given its id
206      *
207      * @warning Can be nullptr if tensor was removed during the mutation steps of the graph
208      *
209      * @param[in] id Tensor ID
210      *
211      * @return The actual tensor object
212      */
213     Tensor *tensor(TensorID id);
214 
215 private:
216     /** Creates a tensor object
217      *
218      * @param[in] desc Tensor descriptor
219      *
220      * @return Tensor ID
221      */
222     TensorID create_tensor(const TensorDescriptor &desc = TensorDescriptor());
223 
224 private:
225     GraphID                              _id      = GraphID(0); /**< Graph id */
226     std::string                          _name    = {};         /**< Graph name */
227     std::vector<std::unique_ptr<INode>>  _nodes   = {};         /**< Graph nodes */
228     std::vector<std::unique_ptr<Edge>>   _edges   = {};         /**< Graph edges */
229     std::vector<std::unique_ptr<Tensor>> _tensors = {};         /**< Graph tensors */
230     std::map<NodeType, std::vector<NodeID>> _tagged_nodes = {}; /**< Graph nodes map with the node type as key */
231     arm_compute::Mutex _mtx = {};                               /**< Mutex used for graph construction */
232 };
233 
234 template <typename NT, typename... Ts>
add_node(Ts &&...args)235 inline NodeID Graph::add_node(Ts &&... args)
236 {
237     arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx);
238 
239     // Create node
240     NodeID nid  = _nodes.size();
241     auto   node = std::make_unique<NT>(std::forward<Ts>(args)...);
242     node->set_graph(this);
243     node->set_id(nid);
244 
245     // Keep track of input nodes
246     _tagged_nodes[node->type()].push_back(nid);
247 
248     // Associate a new tensor with each output
249     for(auto &output : node->_outputs)
250     {
251         output = create_tensor();
252     }
253 
254     // Propagate node shape if possible
255     node->forward_descriptors();
256 
257     // Add node to the graph nodes
258     _nodes.push_back(std::move(node));
259 
260     return nid;
261 }
262 } // namespace graph
263 } // namespace arm_compute
264 #endif /* ARM_COMPUTE_GRAPH_GRAPH_H */
265