xref: /aosp_15_r20/external/ComputeLibrary/src/graph/Graph.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2020 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust namespace arm_compute
27*c217d954SCole Faust {
28*c217d954SCole Faust namespace graph
29*c217d954SCole Faust {
Graph(GraphID id,std::string name)30*c217d954SCole Faust Graph::Graph(GraphID id, std::string name)
31*c217d954SCole Faust     : _id(id), _name(std::move(name)), _nodes(), _edges(), _tensors(), _tagged_nodes(), _mtx()
32*c217d954SCole Faust {
33*c217d954SCole Faust }
34*c217d954SCole Faust 
remove_node(NodeID nid)35*c217d954SCole Faust bool Graph::remove_node(NodeID nid)
36*c217d954SCole Faust {
37*c217d954SCole Faust     if(nid >= _nodes.size())
38*c217d954SCole Faust     {
39*c217d954SCole Faust         return false;
40*c217d954SCole Faust     }
41*c217d954SCole Faust 
42*c217d954SCole Faust     std::unique_ptr<INode> &node = _nodes[nid];
43*c217d954SCole Faust 
44*c217d954SCole Faust     if(node)
45*c217d954SCole Faust     {
46*c217d954SCole Faust         // Remove input connections
47*c217d954SCole Faust         for(auto &input_eid : node->_input_edges)
48*c217d954SCole Faust         {
49*c217d954SCole Faust             remove_connection(input_eid);
50*c217d954SCole Faust         }
51*c217d954SCole Faust 
52*c217d954SCole Faust         // Remove output connections
53*c217d954SCole Faust         std::set<EdgeID> output_edges_copy = node->output_edges();
54*c217d954SCole Faust         for(auto &output_eid : output_edges_copy)
55*c217d954SCole Faust         {
56*c217d954SCole Faust             remove_connection(output_eid);
57*c217d954SCole Faust         }
58*c217d954SCole Faust 
59*c217d954SCole Faust         // Remove nid from tagged nodes
60*c217d954SCole Faust         std::vector<NodeID> &tnodes = _tagged_nodes.at(node->type());
61*c217d954SCole Faust         tnodes.erase(std::remove(tnodes.begin(), tnodes.end(), nid), tnodes.end());
62*c217d954SCole Faust     }
63*c217d954SCole Faust 
64*c217d954SCole Faust     node = nullptr;
65*c217d954SCole Faust 
66*c217d954SCole Faust     return true;
67*c217d954SCole Faust }
68*c217d954SCole Faust 
add_connection(NodeID source,size_t source_idx,NodeID sink,size_t sink_idx)69*c217d954SCole Faust EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx)
70*c217d954SCole Faust {
71*c217d954SCole Faust     arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx);
72*c217d954SCole Faust 
73*c217d954SCole Faust     // Check if node index is valid, if node exists and finally if the connection index is valid
74*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) || (source_idx >= _nodes[source]->num_outputs()));
75*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) || (sink_idx >= _nodes[sink]->num_inputs()));
76*c217d954SCole Faust 
77*c217d954SCole Faust     // Get nodes
78*c217d954SCole Faust     std::unique_ptr<INode> &source_node = _nodes[source];
79*c217d954SCole Faust     std::unique_ptr<INode> &sink_node   = _nodes[sink];
80*c217d954SCole Faust 
81*c217d954SCole Faust     // Check for duplicate connections (Check only sink node)
82*c217d954SCole Faust     Edge *sink_node_edge = sink_node->input_edge(sink_idx);
83*c217d954SCole Faust     if((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) && (sink_node_edge->producer_idx() == source_idx)
84*c217d954SCole Faust        && (sink_node_edge->consumer_id() == sink) && (sink_node_edge->consumer_idx() == sink_idx))
85*c217d954SCole Faust     {
86*c217d954SCole Faust         return sink_node_edge->id();
87*c217d954SCole Faust     }
88*c217d954SCole Faust 
89*c217d954SCole Faust     // Check if there is already a tensor associated with output if not create one
90*c217d954SCole Faust     TensorID tid = source_node->output_id(source_idx);
91*c217d954SCole Faust     if(tid == NullTensorID)
92*c217d954SCole Faust     {
93*c217d954SCole Faust         tid = create_tensor();
94*c217d954SCole Faust     }
95*c217d954SCole Faust     std::unique_ptr<Tensor> &tensor = _tensors[tid];
96*c217d954SCole Faust 
97*c217d954SCole Faust     // Create connections
98*c217d954SCole Faust     EdgeID eid        = _edges.size();
99*c217d954SCole Faust     auto   connection = std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get());
100*c217d954SCole Faust     _edges.push_back(std::move(connection));
101*c217d954SCole Faust 
102*c217d954SCole Faust     // Add connections to source and sink nodes
103*c217d954SCole Faust     source_node->_output_edges.insert(eid);
104*c217d954SCole Faust     sink_node->_input_edges[sink_idx] = eid;
105*c217d954SCole Faust 
106*c217d954SCole Faust     // Set tensor output node
107*c217d954SCole Faust     source_node->_outputs[source_idx] = tid;
108*c217d954SCole Faust 
109*c217d954SCole Faust     // Bind tensor to the edge
110*c217d954SCole Faust     tensor->bind_edge(eid);
111*c217d954SCole Faust 
112*c217d954SCole Faust     // Try and propagate shapes in sink node
113*c217d954SCole Faust     sink_node->forward_descriptors();
114*c217d954SCole Faust 
115*c217d954SCole Faust     return eid;
116*c217d954SCole Faust }
117*c217d954SCole Faust 
remove_connection(EdgeID eid)118*c217d954SCole Faust bool Graph::remove_connection(EdgeID eid)
119*c217d954SCole Faust {
120*c217d954SCole Faust     if(eid >= _edges.size())
121*c217d954SCole Faust     {
122*c217d954SCole Faust         return false;
123*c217d954SCole Faust     }
124*c217d954SCole Faust 
125*c217d954SCole Faust     std::unique_ptr<Edge> &edge = _edges[eid];
126*c217d954SCole Faust 
127*c217d954SCole Faust     // Remove node connections
128*c217d954SCole Faust     if(edge != nullptr)
129*c217d954SCole Faust     {
130*c217d954SCole Faust         // Get tensor bound to the edge
131*c217d954SCole Faust         if(edge->tensor() != nullptr)
132*c217d954SCole Faust         {
133*c217d954SCole Faust             edge->tensor()->unbind_edge(eid);
134*c217d954SCole Faust         }
135*c217d954SCole Faust 
136*c217d954SCole Faust         // Remove edges from source node
137*c217d954SCole Faust         if(edge->producer() != nullptr)
138*c217d954SCole Faust         {
139*c217d954SCole Faust             edge->producer()->_output_edges.erase(eid);
140*c217d954SCole Faust         }
141*c217d954SCole Faust 
142*c217d954SCole Faust         // Remove edges from sink node
143*c217d954SCole Faust         if((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size()))
144*c217d954SCole Faust         {
145*c217d954SCole Faust             edge->consumer()->_input_edges[edge->consumer_idx()] = EmptyEdgeID;
146*c217d954SCole Faust         }
147*c217d954SCole Faust     }
148*c217d954SCole Faust 
149*c217d954SCole Faust     // Clear edge
150*c217d954SCole Faust     edge = nullptr;
151*c217d954SCole Faust 
152*c217d954SCole Faust     return true;
153*c217d954SCole Faust }
154*c217d954SCole Faust 
create_tensor(const TensorDescriptor & desc)155*c217d954SCole Faust TensorID Graph::create_tensor(const TensorDescriptor &desc)
156*c217d954SCole Faust {
157*c217d954SCole Faust     TensorID tid    = _tensors.size();
158*c217d954SCole Faust     auto     tensor = std::make_unique<Tensor>(tid, desc);
159*c217d954SCole Faust     _tensors.push_back(std::move(tensor));
160*c217d954SCole Faust 
161*c217d954SCole Faust     return tid;
162*c217d954SCole Faust }
163*c217d954SCole Faust 
name() const164*c217d954SCole Faust std::string Graph::name() const
165*c217d954SCole Faust {
166*c217d954SCole Faust     return _name;
167*c217d954SCole Faust }
168*c217d954SCole Faust 
id() const169*c217d954SCole Faust GraphID Graph::id() const
170*c217d954SCole Faust {
171*c217d954SCole Faust     return _id;
172*c217d954SCole Faust }
173*c217d954SCole Faust 
nodes(NodeType type)174*c217d954SCole Faust const std::vector<NodeID> &Graph::nodes(NodeType type)
175*c217d954SCole Faust {
176*c217d954SCole Faust     return _tagged_nodes[type];
177*c217d954SCole Faust }
178*c217d954SCole Faust 
nodes()179*c217d954SCole Faust std::vector<std::unique_ptr<INode>> &Graph::nodes()
180*c217d954SCole Faust {
181*c217d954SCole Faust     return _nodes;
182*c217d954SCole Faust }
183*c217d954SCole Faust 
nodes() const184*c217d954SCole Faust const std::vector<std::unique_ptr<INode>> &Graph::nodes() const
185*c217d954SCole Faust {
186*c217d954SCole Faust     return _nodes;
187*c217d954SCole Faust }
188*c217d954SCole Faust 
edges() const189*c217d954SCole Faust const std::vector<std::unique_ptr<Edge>> &Graph::edges() const
190*c217d954SCole Faust {
191*c217d954SCole Faust     return _edges;
192*c217d954SCole Faust }
193*c217d954SCole Faust 
tensors()194*c217d954SCole Faust std::vector<std::unique_ptr<Tensor>> &Graph::tensors()
195*c217d954SCole Faust {
196*c217d954SCole Faust     return _tensors;
197*c217d954SCole Faust }
198*c217d954SCole Faust 
tensors() const199*c217d954SCole Faust const std::vector<std::unique_ptr<Tensor>> &Graph::tensors() const
200*c217d954SCole Faust {
201*c217d954SCole Faust     return _tensors;
202*c217d954SCole Faust }
203*c217d954SCole Faust 
node(NodeID id) const204*c217d954SCole Faust const INode *Graph::node(NodeID id) const
205*c217d954SCole Faust {
206*c217d954SCole Faust     return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
207*c217d954SCole Faust }
208*c217d954SCole Faust 
node(NodeID id)209*c217d954SCole Faust INode *Graph::node(NodeID id)
210*c217d954SCole Faust {
211*c217d954SCole Faust     return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
212*c217d954SCole Faust }
213*c217d954SCole Faust 
edge(EdgeID id) const214*c217d954SCole Faust const Edge *Graph::edge(EdgeID id) const
215*c217d954SCole Faust {
216*c217d954SCole Faust     return (id >= _edges.size()) ? nullptr : _edges[id].get();
217*c217d954SCole Faust }
218*c217d954SCole Faust 
edge(EdgeID id)219*c217d954SCole Faust Edge *Graph::edge(EdgeID id)
220*c217d954SCole Faust {
221*c217d954SCole Faust     return (id >= _edges.size()) ? nullptr : _edges[id].get();
222*c217d954SCole Faust }
223*c217d954SCole Faust 
tensor(TensorID id) const224*c217d954SCole Faust const Tensor *Graph::tensor(TensorID id) const
225*c217d954SCole Faust {
226*c217d954SCole Faust     return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
227*c217d954SCole Faust }
228*c217d954SCole Faust 
tensor(TensorID id)229*c217d954SCole Faust Tensor *Graph::tensor(TensorID id)
230*c217d954SCole Faust {
231*c217d954SCole Faust     return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
232*c217d954SCole Faust }
233*c217d954SCole Faust } // namespace graph
234*c217d954SCole Faust } // namespace arm_compute