1 /*
2 * Copyright (c) 2018,2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "arm_compute/graph/INode.h"
25
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/graph/Edge.h"
28 #include "arm_compute/graph/Graph.h"
29 #include "arm_compute/graph/Tensor.h"
30
31 namespace arm_compute
32 {
33 namespace graph
34 {
35 // *INDENT-OFF*
36 // clang-format off
INode()37 INode::INode()
38 : _graph(nullptr), _id(EmptyNodeID), _common_params({ "", Target::UNSPECIFIED}),
39 _outputs(), _input_edges(), _output_edges(), _assigned_target(Target::UNSPECIFIED)
40 ,_post_op_info_list(std::list<std::unique_ptr<ConvPostOpInfo>> {})
41 {
42 }
43 // clang-format on
44 // *INDENT-ON*
45
validate() const46 Status INode::validate() const
47 {
48 return Status{};
49 }
50
set_graph(Graph * g)51 void INode::set_graph(Graph *g)
52 {
53 ARM_COMPUTE_ERROR_ON(g == nullptr);
54 _graph = g;
55 }
56
set_id(NodeID id)57 void INode::set_id(NodeID id)
58 {
59 _id = id;
60 }
61
set_common_node_parameters(NodeParams common_params)62 void INode::set_common_node_parameters(NodeParams common_params)
63 {
64 _common_params = std::move(common_params);
65 }
66
set_requested_target(Target target)67 void INode::set_requested_target(Target target)
68 {
69 _common_params.target = target;
70 }
71
set_assigned_target(Target target)72 void INode::set_assigned_target(Target target)
73 {
74 _assigned_target = target;
75 }
76
set_output_tensor(TensorID tid,size_t idx)77 void INode::set_output_tensor(TensorID tid, size_t idx)
78 {
79 if(tid != NullTensorID && (idx < _outputs.size()) && (_graph->tensor(tid) != nullptr))
80 {
81 ARM_COMPUTE_ERROR_ON(_graph == nullptr);
82 Tensor *updated_tensor = _graph->tensor(tid);
83 _outputs[idx] = tid;
84
85 // Set tensor to all output edges of the node
86 for(auto &output_edge_id : _output_edges)
87 {
88 auto output_edge = _graph->edge(output_edge_id);
89 if(output_edge != nullptr)
90 {
91 // Unbind edge from current tensor
92 auto current_output_tensor = output_edge->tensor();
93 current_output_tensor->unbind_edge(output_edge->id());
94
95 // Update tensor to edge and rebind tensor
96 output_edge->update_bound_tensor(updated_tensor);
97 updated_tensor->bind_edge(output_edge->id());
98 }
99 }
100 }
101 }
102
id() const103 NodeID INode::id() const
104 {
105 return _id;
106 }
107
name() const108 std::string INode::name() const
109 {
110 return _common_params.name;
111 }
112
graph() const113 const Graph *INode::graph() const
114 {
115 return _graph;
116 }
117
graph()118 Graph *INode::graph()
119 {
120 return _graph;
121 }
122
outputs() const123 const std::vector<TensorID> &INode::outputs() const
124 {
125 return _outputs;
126 }
127
input_edges() const128 const std::vector<EdgeID> &INode::input_edges() const
129 {
130 return _input_edges;
131 }
132
output_edges() const133 const std::set<EdgeID> &INode::output_edges() const
134 {
135 return _output_edges;
136 }
137
input_id(size_t idx) const138 TensorID INode::input_id(size_t idx) const
139 {
140 ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
141 Edge *e = _graph->edge(_input_edges[idx]);
142 return (e != nullptr) ? e->tensor_id() : NullTensorID;
143 }
144
output_id(size_t idx) const145 TensorID INode::output_id(size_t idx) const
146 {
147 ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
148 return _outputs[idx];
149 }
150
input(size_t idx) const151 Tensor *INode::input(size_t idx) const
152 {
153 ARM_COMPUTE_ERROR_ON(_graph == nullptr);
154 ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
155 Edge *e = _graph->edge(_input_edges[idx]);
156 return (e != nullptr) ? e->tensor() : nullptr;
157 }
158
output(size_t idx) const159 Tensor *INode::output(size_t idx) const
160 {
161 ARM_COMPUTE_ERROR_ON(_graph == nullptr);
162 ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
163 return _graph->tensor(_outputs[idx]);
164 }
165
input_edge_id(size_t idx) const166 EdgeID INode::input_edge_id(size_t idx) const
167 {
168 ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
169 return _input_edges[idx];
170 }
171
input_edge(size_t idx) const172 Edge *INode::input_edge(size_t idx) const
173 {
174 ARM_COMPUTE_ERROR_ON(_graph == nullptr);
175 ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
176 return _graph->edge(_input_edges[idx]);
177 }
178
num_inputs() const179 size_t INode::num_inputs() const
180 {
181 return _input_edges.size();
182 }
183
num_outputs() const184 size_t INode::num_outputs() const
185 {
186 return _outputs.size();
187 }
188
common_node_params() const189 NodeParams INode::common_node_params() const
190 {
191 return _common_params;
192 }
193
requested_target() const194 Target INode::requested_target() const
195 {
196 return _common_params.target;
197 }
198
assigned_target() const199 Target INode::assigned_target() const
200 {
201 return _assigned_target;
202 }
203
post_op_info_list() const204 const std::list<std::unique_ptr<ConvPostOpInfo>> &INode::post_op_info_list() const
205 {
206 return _post_op_info_list;
207 }
208
post_op_info_list()209 std::list<std::unique_ptr<ConvPostOpInfo>> &INode::post_op_info_list()
210 {
211 return _post_op_info_list;
212 }
213 } // namespace graph
214 } // namespace arm_compute