xref: /aosp_15_r20/external/ComputeLibrary/src/graph/INode.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018,2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "arm_compute/graph/INode.h"
25*c217d954SCole Faust 
26*c217d954SCole Faust #include "arm_compute/core/Error.h"
27*c217d954SCole Faust #include "arm_compute/graph/Edge.h"
28*c217d954SCole Faust #include "arm_compute/graph/Graph.h"
29*c217d954SCole Faust #include "arm_compute/graph/Tensor.h"
30*c217d954SCole Faust 
31*c217d954SCole Faust namespace arm_compute
32*c217d954SCole Faust {
33*c217d954SCole Faust namespace graph
34*c217d954SCole Faust {
35*c217d954SCole Faust // *INDENT-OFF*
36*c217d954SCole Faust // clang-format off
INode()37*c217d954SCole Faust INode::INode()
38*c217d954SCole Faust     : _graph(nullptr), _id(EmptyNodeID), _common_params({ "", Target::UNSPECIFIED}),
39*c217d954SCole Faust       _outputs(), _input_edges(), _output_edges(), _assigned_target(Target::UNSPECIFIED)
40*c217d954SCole Faust       ,_post_op_info_list(std::list<std::unique_ptr<ConvPostOpInfo>> {})
41*c217d954SCole Faust {
42*c217d954SCole Faust }
43*c217d954SCole Faust // clang-format on
44*c217d954SCole Faust // *INDENT-ON*
45*c217d954SCole Faust 
validate() const46*c217d954SCole Faust Status INode::validate() const
47*c217d954SCole Faust {
48*c217d954SCole Faust     return Status{};
49*c217d954SCole Faust }
50*c217d954SCole Faust 
set_graph(Graph * g)51*c217d954SCole Faust void INode::set_graph(Graph *g)
52*c217d954SCole Faust {
53*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(g == nullptr);
54*c217d954SCole Faust     _graph = g;
55*c217d954SCole Faust }
56*c217d954SCole Faust 
set_id(NodeID id)57*c217d954SCole Faust void INode::set_id(NodeID id)
58*c217d954SCole Faust {
59*c217d954SCole Faust     _id = id;
60*c217d954SCole Faust }
61*c217d954SCole Faust 
set_common_node_parameters(NodeParams common_params)62*c217d954SCole Faust void INode::set_common_node_parameters(NodeParams common_params)
63*c217d954SCole Faust {
64*c217d954SCole Faust     _common_params = std::move(common_params);
65*c217d954SCole Faust }
66*c217d954SCole Faust 
set_requested_target(Target target)67*c217d954SCole Faust void INode::set_requested_target(Target target)
68*c217d954SCole Faust {
69*c217d954SCole Faust     _common_params.target = target;
70*c217d954SCole Faust }
71*c217d954SCole Faust 
set_assigned_target(Target target)72*c217d954SCole Faust void INode::set_assigned_target(Target target)
73*c217d954SCole Faust {
74*c217d954SCole Faust     _assigned_target = target;
75*c217d954SCole Faust }
76*c217d954SCole Faust 
set_output_tensor(TensorID tid,size_t idx)77*c217d954SCole Faust void INode::set_output_tensor(TensorID tid, size_t idx)
78*c217d954SCole Faust {
79*c217d954SCole Faust     if(tid != NullTensorID && (idx < _outputs.size()) && (_graph->tensor(tid) != nullptr))
80*c217d954SCole Faust     {
81*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_graph == nullptr);
82*c217d954SCole Faust         Tensor *updated_tensor = _graph->tensor(tid);
83*c217d954SCole Faust         _outputs[idx]          = tid;
84*c217d954SCole Faust 
85*c217d954SCole Faust         // Set tensor to all output edges of the node
86*c217d954SCole Faust         for(auto &output_edge_id : _output_edges)
87*c217d954SCole Faust         {
88*c217d954SCole Faust             auto output_edge = _graph->edge(output_edge_id);
89*c217d954SCole Faust             if(output_edge != nullptr)
90*c217d954SCole Faust             {
91*c217d954SCole Faust                 // Unbind edge from current tensor
92*c217d954SCole Faust                 auto current_output_tensor = output_edge->tensor();
93*c217d954SCole Faust                 current_output_tensor->unbind_edge(output_edge->id());
94*c217d954SCole Faust 
95*c217d954SCole Faust                 // Update tensor to edge and rebind tensor
96*c217d954SCole Faust                 output_edge->update_bound_tensor(updated_tensor);
97*c217d954SCole Faust                 updated_tensor->bind_edge(output_edge->id());
98*c217d954SCole Faust             }
99*c217d954SCole Faust         }
100*c217d954SCole Faust     }
101*c217d954SCole Faust }
102*c217d954SCole Faust 
id() const103*c217d954SCole Faust NodeID INode::id() const
104*c217d954SCole Faust {
105*c217d954SCole Faust     return _id;
106*c217d954SCole Faust }
107*c217d954SCole Faust 
name() const108*c217d954SCole Faust std::string INode::name() const
109*c217d954SCole Faust {
110*c217d954SCole Faust     return _common_params.name;
111*c217d954SCole Faust }
112*c217d954SCole Faust 
graph() const113*c217d954SCole Faust const Graph *INode::graph() const
114*c217d954SCole Faust {
115*c217d954SCole Faust     return _graph;
116*c217d954SCole Faust }
117*c217d954SCole Faust 
graph()118*c217d954SCole Faust Graph *INode::graph()
119*c217d954SCole Faust {
120*c217d954SCole Faust     return _graph;
121*c217d954SCole Faust }
122*c217d954SCole Faust 
outputs() const123*c217d954SCole Faust const std::vector<TensorID> &INode::outputs() const
124*c217d954SCole Faust {
125*c217d954SCole Faust     return _outputs;
126*c217d954SCole Faust }
127*c217d954SCole Faust 
input_edges() const128*c217d954SCole Faust const std::vector<EdgeID> &INode::input_edges() const
129*c217d954SCole Faust {
130*c217d954SCole Faust     return _input_edges;
131*c217d954SCole Faust }
132*c217d954SCole Faust 
output_edges() const133*c217d954SCole Faust const std::set<EdgeID> &INode::output_edges() const
134*c217d954SCole Faust {
135*c217d954SCole Faust     return _output_edges;
136*c217d954SCole Faust }
137*c217d954SCole Faust 
input_id(size_t idx) const138*c217d954SCole Faust TensorID INode::input_id(size_t idx) const
139*c217d954SCole Faust {
140*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
141*c217d954SCole Faust     Edge *e = _graph->edge(_input_edges[idx]);
142*c217d954SCole Faust     return (e != nullptr) ? e->tensor_id() : NullTensorID;
143*c217d954SCole Faust }
144*c217d954SCole Faust 
output_id(size_t idx) const145*c217d954SCole Faust TensorID INode::output_id(size_t idx) const
146*c217d954SCole Faust {
147*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
148*c217d954SCole Faust     return _outputs[idx];
149*c217d954SCole Faust }
150*c217d954SCole Faust 
input(size_t idx) const151*c217d954SCole Faust Tensor *INode::input(size_t idx) const
152*c217d954SCole Faust {
153*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(_graph == nullptr);
154*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
155*c217d954SCole Faust     Edge *e = _graph->edge(_input_edges[idx]);
156*c217d954SCole Faust     return (e != nullptr) ? e->tensor() : nullptr;
157*c217d954SCole Faust }
158*c217d954SCole Faust 
output(size_t idx) const159*c217d954SCole Faust Tensor *INode::output(size_t idx) const
160*c217d954SCole Faust {
161*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(_graph == nullptr);
162*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _outputs.size());
163*c217d954SCole Faust     return _graph->tensor(_outputs[idx]);
164*c217d954SCole Faust }
165*c217d954SCole Faust 
input_edge_id(size_t idx) const166*c217d954SCole Faust EdgeID INode::input_edge_id(size_t idx) const
167*c217d954SCole Faust {
168*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
169*c217d954SCole Faust     return _input_edges[idx];
170*c217d954SCole Faust }
171*c217d954SCole Faust 
input_edge(size_t idx) const172*c217d954SCole Faust Edge *INode::input_edge(size_t idx) const
173*c217d954SCole Faust {
174*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(_graph == nullptr);
175*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(idx >= _input_edges.size());
176*c217d954SCole Faust     return _graph->edge(_input_edges[idx]);
177*c217d954SCole Faust }
178*c217d954SCole Faust 
num_inputs() const179*c217d954SCole Faust size_t INode::num_inputs() const
180*c217d954SCole Faust {
181*c217d954SCole Faust     return _input_edges.size();
182*c217d954SCole Faust }
183*c217d954SCole Faust 
num_outputs() const184*c217d954SCole Faust size_t INode::num_outputs() const
185*c217d954SCole Faust {
186*c217d954SCole Faust     return _outputs.size();
187*c217d954SCole Faust }
188*c217d954SCole Faust 
common_node_params() const189*c217d954SCole Faust NodeParams INode::common_node_params() const
190*c217d954SCole Faust {
191*c217d954SCole Faust     return _common_params;
192*c217d954SCole Faust }
193*c217d954SCole Faust 
requested_target() const194*c217d954SCole Faust Target INode::requested_target() const
195*c217d954SCole Faust {
196*c217d954SCole Faust     return _common_params.target;
197*c217d954SCole Faust }
198*c217d954SCole Faust 
assigned_target() const199*c217d954SCole Faust Target INode::assigned_target() const
200*c217d954SCole Faust {
201*c217d954SCole Faust     return _assigned_target;
202*c217d954SCole Faust }
203*c217d954SCole Faust 
post_op_info_list() const204*c217d954SCole Faust const std::list<std::unique_ptr<ConvPostOpInfo>> &INode::post_op_info_list() const
205*c217d954SCole Faust {
206*c217d954SCole Faust     return _post_op_info_list;
207*c217d954SCole Faust }
208*c217d954SCole Faust 
post_op_info_list()209*c217d954SCole Faust std::list<std::unique_ptr<ConvPostOpInfo>> &INode::post_op_info_list()
210*c217d954SCole Faust {
211*c217d954SCole Faust     return _post_op_info_list;
212*c217d954SCole Faust }
213*c217d954SCole Faust } // namespace graph
214*c217d954SCole Faust } // namespace arm_compute