1*c217d954SCole Faust /*
2*c217d954SCole Faust * Copyright (c) 2018-2020 Arm Limited.
3*c217d954SCole Faust *
4*c217d954SCole Faust * SPDX-License-Identifier: MIT
5*c217d954SCole Faust *
6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust *
13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust * copies or substantial portions of the Software.
15*c217d954SCole Faust *
16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust * SOFTWARE.
23*c217d954SCole Faust */
24*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
25*c217d954SCole Faust
26*c217d954SCole Faust namespace arm_compute
27*c217d954SCole Faust {
28*c217d954SCole Faust namespace graph
29*c217d954SCole Faust {
Graph(GraphID id,std::string name)30*c217d954SCole Faust Graph::Graph(GraphID id, std::string name)
31*c217d954SCole Faust : _id(id), _name(std::move(name)), _nodes(), _edges(), _tensors(), _tagged_nodes(), _mtx()
32*c217d954SCole Faust {
33*c217d954SCole Faust }
34*c217d954SCole Faust
remove_node(NodeID nid)35*c217d954SCole Faust bool Graph::remove_node(NodeID nid)
36*c217d954SCole Faust {
37*c217d954SCole Faust if(nid >= _nodes.size())
38*c217d954SCole Faust {
39*c217d954SCole Faust return false;
40*c217d954SCole Faust }
41*c217d954SCole Faust
42*c217d954SCole Faust std::unique_ptr<INode> &node = _nodes[nid];
43*c217d954SCole Faust
44*c217d954SCole Faust if(node)
45*c217d954SCole Faust {
46*c217d954SCole Faust // Remove input connections
47*c217d954SCole Faust for(auto &input_eid : node->_input_edges)
48*c217d954SCole Faust {
49*c217d954SCole Faust remove_connection(input_eid);
50*c217d954SCole Faust }
51*c217d954SCole Faust
52*c217d954SCole Faust // Remove output connections
53*c217d954SCole Faust std::set<EdgeID> output_edges_copy = node->output_edges();
54*c217d954SCole Faust for(auto &output_eid : output_edges_copy)
55*c217d954SCole Faust {
56*c217d954SCole Faust remove_connection(output_eid);
57*c217d954SCole Faust }
58*c217d954SCole Faust
59*c217d954SCole Faust // Remove nid from tagged nodes
60*c217d954SCole Faust std::vector<NodeID> &tnodes = _tagged_nodes.at(node->type());
61*c217d954SCole Faust tnodes.erase(std::remove(tnodes.begin(), tnodes.end(), nid), tnodes.end());
62*c217d954SCole Faust }
63*c217d954SCole Faust
64*c217d954SCole Faust node = nullptr;
65*c217d954SCole Faust
66*c217d954SCole Faust return true;
67*c217d954SCole Faust }
68*c217d954SCole Faust
add_connection(NodeID source,size_t source_idx,NodeID sink,size_t sink_idx)69*c217d954SCole Faust EdgeID Graph::add_connection(NodeID source, size_t source_idx, NodeID sink, size_t sink_idx)
70*c217d954SCole Faust {
71*c217d954SCole Faust arm_compute::lock_guard<arm_compute::Mutex> lock(_mtx);
72*c217d954SCole Faust
73*c217d954SCole Faust // Check if node index is valid, if node exists and finally if the connection index is valid
74*c217d954SCole Faust ARM_COMPUTE_ERROR_ON((source >= _nodes.size()) || (_nodes[source] == nullptr) || (source_idx >= _nodes[source]->num_outputs()));
75*c217d954SCole Faust ARM_COMPUTE_ERROR_ON((sink >= _nodes.size()) || (_nodes[sink] == nullptr) || (sink_idx >= _nodes[sink]->num_inputs()));
76*c217d954SCole Faust
77*c217d954SCole Faust // Get nodes
78*c217d954SCole Faust std::unique_ptr<INode> &source_node = _nodes[source];
79*c217d954SCole Faust std::unique_ptr<INode> &sink_node = _nodes[sink];
80*c217d954SCole Faust
81*c217d954SCole Faust // Check for duplicate connections (Check only sink node)
82*c217d954SCole Faust Edge *sink_node_edge = sink_node->input_edge(sink_idx);
83*c217d954SCole Faust if((sink_node_edge != nullptr) && (sink_node_edge->producer_id() == source) && (sink_node_edge->producer_idx() == source_idx)
84*c217d954SCole Faust && (sink_node_edge->consumer_id() == sink) && (sink_node_edge->consumer_idx() == sink_idx))
85*c217d954SCole Faust {
86*c217d954SCole Faust return sink_node_edge->id();
87*c217d954SCole Faust }
88*c217d954SCole Faust
89*c217d954SCole Faust // Check if there is already a tensor associated with output if not create one
90*c217d954SCole Faust TensorID tid = source_node->output_id(source_idx);
91*c217d954SCole Faust if(tid == NullTensorID)
92*c217d954SCole Faust {
93*c217d954SCole Faust tid = create_tensor();
94*c217d954SCole Faust }
95*c217d954SCole Faust std::unique_ptr<Tensor> &tensor = _tensors[tid];
96*c217d954SCole Faust
97*c217d954SCole Faust // Create connections
98*c217d954SCole Faust EdgeID eid = _edges.size();
99*c217d954SCole Faust auto connection = std::make_unique<Edge>(eid, source_node.get(), source_idx, sink_node.get(), sink_idx, tensor.get());
100*c217d954SCole Faust _edges.push_back(std::move(connection));
101*c217d954SCole Faust
102*c217d954SCole Faust // Add connections to source and sink nodes
103*c217d954SCole Faust source_node->_output_edges.insert(eid);
104*c217d954SCole Faust sink_node->_input_edges[sink_idx] = eid;
105*c217d954SCole Faust
106*c217d954SCole Faust // Set tensor output node
107*c217d954SCole Faust source_node->_outputs[source_idx] = tid;
108*c217d954SCole Faust
109*c217d954SCole Faust // Bind tensor to the edge
110*c217d954SCole Faust tensor->bind_edge(eid);
111*c217d954SCole Faust
112*c217d954SCole Faust // Try and propagate shapes in sink node
113*c217d954SCole Faust sink_node->forward_descriptors();
114*c217d954SCole Faust
115*c217d954SCole Faust return eid;
116*c217d954SCole Faust }
117*c217d954SCole Faust
remove_connection(EdgeID eid)118*c217d954SCole Faust bool Graph::remove_connection(EdgeID eid)
119*c217d954SCole Faust {
120*c217d954SCole Faust if(eid >= _edges.size())
121*c217d954SCole Faust {
122*c217d954SCole Faust return false;
123*c217d954SCole Faust }
124*c217d954SCole Faust
125*c217d954SCole Faust std::unique_ptr<Edge> &edge = _edges[eid];
126*c217d954SCole Faust
127*c217d954SCole Faust // Remove node connections
128*c217d954SCole Faust if(edge != nullptr)
129*c217d954SCole Faust {
130*c217d954SCole Faust // Get tensor bound to the edge
131*c217d954SCole Faust if(edge->tensor() != nullptr)
132*c217d954SCole Faust {
133*c217d954SCole Faust edge->tensor()->unbind_edge(eid);
134*c217d954SCole Faust }
135*c217d954SCole Faust
136*c217d954SCole Faust // Remove edges from source node
137*c217d954SCole Faust if(edge->producer() != nullptr)
138*c217d954SCole Faust {
139*c217d954SCole Faust edge->producer()->_output_edges.erase(eid);
140*c217d954SCole Faust }
141*c217d954SCole Faust
142*c217d954SCole Faust // Remove edges from sink node
143*c217d954SCole Faust if((edge->consumer() != nullptr) && (edge->consumer_idx() < edge->consumer()->_input_edges.size()))
144*c217d954SCole Faust {
145*c217d954SCole Faust edge->consumer()->_input_edges[edge->consumer_idx()] = EmptyEdgeID;
146*c217d954SCole Faust }
147*c217d954SCole Faust }
148*c217d954SCole Faust
149*c217d954SCole Faust // Clear edge
150*c217d954SCole Faust edge = nullptr;
151*c217d954SCole Faust
152*c217d954SCole Faust return true;
153*c217d954SCole Faust }
154*c217d954SCole Faust
create_tensor(const TensorDescriptor & desc)155*c217d954SCole Faust TensorID Graph::create_tensor(const TensorDescriptor &desc)
156*c217d954SCole Faust {
157*c217d954SCole Faust TensorID tid = _tensors.size();
158*c217d954SCole Faust auto tensor = std::make_unique<Tensor>(tid, desc);
159*c217d954SCole Faust _tensors.push_back(std::move(tensor));
160*c217d954SCole Faust
161*c217d954SCole Faust return tid;
162*c217d954SCole Faust }
163*c217d954SCole Faust
name() const164*c217d954SCole Faust std::string Graph::name() const
165*c217d954SCole Faust {
166*c217d954SCole Faust return _name;
167*c217d954SCole Faust }
168*c217d954SCole Faust
id() const169*c217d954SCole Faust GraphID Graph::id() const
170*c217d954SCole Faust {
171*c217d954SCole Faust return _id;
172*c217d954SCole Faust }
173*c217d954SCole Faust
nodes(NodeType type)174*c217d954SCole Faust const std::vector<NodeID> &Graph::nodes(NodeType type)
175*c217d954SCole Faust {
176*c217d954SCole Faust return _tagged_nodes[type];
177*c217d954SCole Faust }
178*c217d954SCole Faust
nodes()179*c217d954SCole Faust std::vector<std::unique_ptr<INode>> &Graph::nodes()
180*c217d954SCole Faust {
181*c217d954SCole Faust return _nodes;
182*c217d954SCole Faust }
183*c217d954SCole Faust
nodes() const184*c217d954SCole Faust const std::vector<std::unique_ptr<INode>> &Graph::nodes() const
185*c217d954SCole Faust {
186*c217d954SCole Faust return _nodes;
187*c217d954SCole Faust }
188*c217d954SCole Faust
edges() const189*c217d954SCole Faust const std::vector<std::unique_ptr<Edge>> &Graph::edges() const
190*c217d954SCole Faust {
191*c217d954SCole Faust return _edges;
192*c217d954SCole Faust }
193*c217d954SCole Faust
tensors()194*c217d954SCole Faust std::vector<std::unique_ptr<Tensor>> &Graph::tensors()
195*c217d954SCole Faust {
196*c217d954SCole Faust return _tensors;
197*c217d954SCole Faust }
198*c217d954SCole Faust
tensors() const199*c217d954SCole Faust const std::vector<std::unique_ptr<Tensor>> &Graph::tensors() const
200*c217d954SCole Faust {
201*c217d954SCole Faust return _tensors;
202*c217d954SCole Faust }
203*c217d954SCole Faust
node(NodeID id) const204*c217d954SCole Faust const INode *Graph::node(NodeID id) const
205*c217d954SCole Faust {
206*c217d954SCole Faust return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
207*c217d954SCole Faust }
208*c217d954SCole Faust
node(NodeID id)209*c217d954SCole Faust INode *Graph::node(NodeID id)
210*c217d954SCole Faust {
211*c217d954SCole Faust return (id >= _nodes.size()) ? nullptr : _nodes[id].get();
212*c217d954SCole Faust }
213*c217d954SCole Faust
edge(EdgeID id) const214*c217d954SCole Faust const Edge *Graph::edge(EdgeID id) const
215*c217d954SCole Faust {
216*c217d954SCole Faust return (id >= _edges.size()) ? nullptr : _edges[id].get();
217*c217d954SCole Faust }
218*c217d954SCole Faust
edge(EdgeID id)219*c217d954SCole Faust Edge *Graph::edge(EdgeID id)
220*c217d954SCole Faust {
221*c217d954SCole Faust return (id >= _edges.size()) ? nullptr : _edges[id].get();
222*c217d954SCole Faust }
223*c217d954SCole Faust
tensor(TensorID id) const224*c217d954SCole Faust const Tensor *Graph::tensor(TensorID id) const
225*c217d954SCole Faust {
226*c217d954SCole Faust return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
227*c217d954SCole Faust }
228*c217d954SCole Faust
tensor(TensorID id)229*c217d954SCole Faust Tensor *Graph::tensor(TensorID id)
230*c217d954SCole Faust {
231*c217d954SCole Faust return (id >= _tensors.size()) ? nullptr : _tensors[id].get();
232*c217d954SCole Faust }
233*c217d954SCole Faust } // namespace graph
234*c217d954SCole Faust } // namespace arm_compute