/*
 * Copyright (c) 2018-2019,2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_GRAPH_INODE_H 25*c217d954SCole Faust #define ARM_COMPUTE_GRAPH_INODE_H 26*c217d954SCole Faust 27*c217d954SCole Faust #include "arm_compute/core/Error.h" 28*c217d954SCole Faust #include "arm_compute/graph/LayerDescriptors.h" 29*c217d954SCole Faust #include "arm_compute/graph/TensorDescriptor.h" 30*c217d954SCole Faust #include "arm_compute/graph/Types.h" 31*c217d954SCole Faust 32*c217d954SCole Faust #include <list> 33*c217d954SCole Faust #include <set> 34*c217d954SCole Faust 35*c217d954SCole Faust namespace arm_compute 36*c217d954SCole Faust { 37*c217d954SCole Faust namespace graph 38*c217d954SCole Faust { 39*c217d954SCole Faust // Forward declarations 40*c217d954SCole Faust class Graph; 41*c217d954SCole Faust class Edge; 42*c217d954SCole Faust class INodeVisitor; 43*c217d954SCole Faust class Tensor; 44*c217d954SCole Faust 45*c217d954SCole Faust /** Node interface */ 46*c217d954SCole Faust class INode 47*c217d954SCole Faust { 48*c217d954SCole Faust public: 49*c217d954SCole Faust /** Constructor */ 50*c217d954SCole Faust INode(); 51*c217d954SCole Faust /** Destructor **/ 52*c217d954SCole Faust virtual ~INode() = default; 53*c217d954SCole Faust /** Prevent instances of this class from being copied (As this class contains pointers) */ 54*c217d954SCole Faust INode(const INode &) = delete; 55*c217d954SCole Faust /** Prevent instances of this class from being copy assigned (As this class contains pointers) */ 56*c217d954SCole Faust INode &operator=(const INode &) = delete; 57*c217d954SCole Faust /** Allow instances of this class to be moved */ 58*c217d954SCole Faust INode(INode &&) = default; 59*c217d954SCole Faust /** Allow instances of this class to be move assigned */ 60*c217d954SCole Faust INode &operator=(INode &&) = default; 61*c217d954SCole Faust /** Validate node 62*c217d954SCole Faust * 63*c217d954SCole Faust * @return Status containing any errors 64*c217d954SCole Faust */ 65*c217d954SCole Faust 
virtual Status validate() const; 66*c217d954SCole Faust /** Returns node's type 67*c217d954SCole Faust * 68*c217d954SCole Faust * @return Node's type 69*c217d954SCole Faust */ 70*c217d954SCole Faust virtual NodeType type() const = 0; 71*c217d954SCole Faust /** Accepts a node visitor 72*c217d954SCole Faust * 73*c217d954SCole Faust * @param[in] v Visitor to accept 74*c217d954SCole Faust */ 75*c217d954SCole Faust virtual void accept(INodeVisitor &v) = 0; 76*c217d954SCole Faust /** Forwards descriptor information to outputs if possible 77*c217d954SCole Faust * 78*c217d954SCole Faust * @return True if descriptor information could be forwarded otherwise false 79*c217d954SCole Faust */ 80*c217d954SCole Faust virtual bool forward_descriptors() = 0; 81*c217d954SCole Faust /** Calculates output configuration 82*c217d954SCole Faust * 83*c217d954SCole Faust * @param[in] idx Output index to configure 84*c217d954SCole Faust * 85*c217d954SCole Faust * @return Output descriptor configuration 86*c217d954SCole Faust */ 87*c217d954SCole Faust virtual TensorDescriptor configure_output(size_t idx) const = 0; 88*c217d954SCole Faust /** Returns node's name 89*c217d954SCole Faust * 90*c217d954SCole Faust * @return Node name 91*c217d954SCole Faust */ 92*c217d954SCole Faust std::string name() const; 93*c217d954SCole Faust /** Returns node's ID 94*c217d954SCole Faust * 95*c217d954SCole Faust * @return Node's ID 96*c217d954SCole Faust */ 97*c217d954SCole Faust NodeID id() const; 98*c217d954SCole Faust /** Returns node's Graph 99*c217d954SCole Faust * 100*c217d954SCole Faust * @return Node's graph 101*c217d954SCole Faust */ 102*c217d954SCole Faust const Graph *graph() const; 103*c217d954SCole Faust /** Returns node's Graph 104*c217d954SCole Faust * 105*c217d954SCole Faust * @return Node's graph 106*c217d954SCole Faust */ 107*c217d954SCole Faust Graph *graph(); 108*c217d954SCole Faust /** Sets the graph that this node is registered to 109*c217d954SCole Faust * 110*c217d954SCole Faust * 
@param[in] g Back reference to graph 111*c217d954SCole Faust */ 112*c217d954SCole Faust void set_graph(Graph *g); 113*c217d954SCole Faust /** Sets the node id 114*c217d954SCole Faust * 115*c217d954SCole Faust * @param[in] id Node id 116*c217d954SCole Faust */ 117*c217d954SCole Faust void set_id(NodeID id); 118*c217d954SCole Faust /** Sets common node parameters 119*c217d954SCole Faust * 120*c217d954SCole Faust * @param[in] common_params Common node parameters to set 121*c217d954SCole Faust */ 122*c217d954SCole Faust void set_common_node_parameters(NodeParams common_params); 123*c217d954SCole Faust /** Sets target preference 124*c217d954SCole Faust * 125*c217d954SCole Faust * @note This is not the target that the graph executor might choose, its just an indication 126*c217d954SCole Faust * 127*c217d954SCole Faust * @param[in] target Target preference 128*c217d954SCole Faust */ 129*c217d954SCole Faust void set_requested_target(Target target); 130*c217d954SCole Faust /** Sets the final execution target 131*c217d954SCole Faust * 132*c217d954SCole Faust * @note GraphManager might change this target 133*c217d954SCole Faust * 134*c217d954SCole Faust * @param[in] target Final execution target 135*c217d954SCole Faust */ 136*c217d954SCole Faust void set_assigned_target(Target target); 137*c217d954SCole Faust /** Sets the output tensor of at a given index 138*c217d954SCole Faust * 139*c217d954SCole Faust * @note All edges will get updated 140*c217d954SCole Faust * 141*c217d954SCole Faust * @param[in] tid Tensor ID 142*c217d954SCole Faust * @param[in] idx Output index 143*c217d954SCole Faust */ 144*c217d954SCole Faust void set_output_tensor(TensorID tid, size_t idx); 145*c217d954SCole Faust /** Returns inputs of the node 146*c217d954SCole Faust * 147*c217d954SCole Faust * @return Inputs of the node 148*c217d954SCole Faust */ 149*c217d954SCole Faust const std::vector<TensorID> &inputs() const; 150*c217d954SCole Faust /** Returns outputs of the node 151*c217d954SCole Faust * 
152*c217d954SCole Faust * @return Outputs of the node 153*c217d954SCole Faust */ 154*c217d954SCole Faust const std::vector<TensorID> &outputs() const; 155*c217d954SCole Faust /** Returns input edge set 156*c217d954SCole Faust * 157*c217d954SCole Faust * @return Set of input edges 158*c217d954SCole Faust */ 159*c217d954SCole Faust const std::vector<EdgeID> &input_edges() const; 160*c217d954SCole Faust /** Returns output edge set 161*c217d954SCole Faust * 162*c217d954SCole Faust * @return Set of output edges 163*c217d954SCole Faust */ 164*c217d954SCole Faust const std::set<EdgeID> &output_edges() const; 165*c217d954SCole Faust /** Returns the tensor ID of a given input of the node 166*c217d954SCole Faust * 167*c217d954SCole Faust * @note Precondition : idx should be a valid input index 168*c217d954SCole Faust * 169*c217d954SCole Faust * @param[in] idx Index of the node input 170*c217d954SCole Faust * 171*c217d954SCole Faust * @return TensorID of the requested input 172*c217d954SCole Faust */ 173*c217d954SCole Faust TensorID input_id(size_t idx) const; 174*c217d954SCole Faust /** Returns the tensor ID of a given output of the node 175*c217d954SCole Faust * 176*c217d954SCole Faust * @note Precondition : idx should be a valid output index 177*c217d954SCole Faust * 178*c217d954SCole Faust * @param[in] idx Index of the node output 179*c217d954SCole Faust * 180*c217d954SCole Faust * @return TensorID of the requested output 181*c217d954SCole Faust */ 182*c217d954SCole Faust TensorID output_id(size_t idx) const; 183*c217d954SCole Faust /** Returns the tensor of a given input of the node 184*c217d954SCole Faust * 185*c217d954SCole Faust * @note Precondition : idx should be a valid input index 186*c217d954SCole Faust * 187*c217d954SCole Faust * @param[in] idx Index of the node input 188*c217d954SCole Faust * 189*c217d954SCole Faust * @return Tensor of the requested input 190*c217d954SCole Faust */ 191*c217d954SCole Faust Tensor *input(size_t idx) const; 192*c217d954SCole Faust 
/** Returns the tensor of a given output of the node 193*c217d954SCole Faust * 194*c217d954SCole Faust * @note Precondition : idx should be a valid output index 195*c217d954SCole Faust * 196*c217d954SCole Faust * @param[in] idx Index of the node output 197*c217d954SCole Faust * 198*c217d954SCole Faust * @return Tensor of the requested output 199*c217d954SCole Faust */ 200*c217d954SCole Faust Tensor *output(size_t idx) const; 201*c217d954SCole Faust /** Returns the edge ID of a given input of the node 202*c217d954SCole Faust * 203*c217d954SCole Faust * @note Precondition : idx should be a valid input index 204*c217d954SCole Faust * 205*c217d954SCole Faust * @param[in] idx Index of the node input 206*c217d954SCole Faust * 207*c217d954SCole Faust * @return EdgeID of the requested input 208*c217d954SCole Faust */ 209*c217d954SCole Faust EdgeID input_edge_id(size_t idx) const; 210*c217d954SCole Faust /** Returns the edge of a given input of the node 211*c217d954SCole Faust * 212*c217d954SCole Faust * @note Precondition : idx should be a valid input index 213*c217d954SCole Faust * 214*c217d954SCole Faust * @param[in] idx Index of the node input 215*c217d954SCole Faust * 216*c217d954SCole Faust * @return Edge of the requested input 217*c217d954SCole Faust */ 218*c217d954SCole Faust Edge *input_edge(size_t idx) const; 219*c217d954SCole Faust /** Returns number of inputs of the node 220*c217d954SCole Faust * 221*c217d954SCole Faust * @return Number of inputs 222*c217d954SCole Faust */ 223*c217d954SCole Faust size_t num_inputs() const; 224*c217d954SCole Faust /** Returns number of outputs of the node 225*c217d954SCole Faust * 226*c217d954SCole Faust * @return Number of outputs 227*c217d954SCole Faust */ 228*c217d954SCole Faust size_t num_outputs() const; 229*c217d954SCole Faust /** Returns common node parameters 230*c217d954SCole Faust * 231*c217d954SCole Faust * @return Common node parameters 232*c217d954SCole Faust */ 233*c217d954SCole Faust NodeParams common_node_params() 
const; 234*c217d954SCole Faust /** Returns requested target for this node 235*c217d954SCole Faust * 236*c217d954SCole Faust * @return Requested execution target 237*c217d954SCole Faust */ 238*c217d954SCole Faust Target requested_target() const; 239*c217d954SCole Faust /** Returns assigned target for this node 240*c217d954SCole Faust * 241*c217d954SCole Faust * @return Assigned target of this node 242*c217d954SCole Faust */ 243*c217d954SCole Faust Target assigned_target() const; 244*c217d954SCole Faust /** Post operator info list 245*c217d954SCole Faust * 246*c217d954SCole Faust * @return Post operator info list 247*c217d954SCole Faust */ 248*c217d954SCole Faust const std::list<std::unique_ptr<ConvPostOpInfo>> &post_op_info_list() const; 249*c217d954SCole Faust /** Post operator info list 250*c217d954SCole Faust * 251*c217d954SCole Faust * @return Post operator info list 252*c217d954SCole Faust */ 253*c217d954SCole Faust std::list<std::unique_ptr<ConvPostOpInfo>> &post_op_info_list(); 254*c217d954SCole Faust 255*c217d954SCole Faust protected: 256*c217d954SCole Faust friend class Graph; 257*c217d954SCole Faust 258*c217d954SCole Faust protected: 259*c217d954SCole Faust Graph *_graph; /**< Backward reference to graph owning the node */ 260*c217d954SCole Faust NodeID _id; /**< Node ID */ 261*c217d954SCole Faust NodeParams _common_params; /**< Node common params */ 262*c217d954SCole Faust std::vector<TensorID> _outputs; /**< Output of the node */ 263*c217d954SCole Faust std::vector<EdgeID> _input_edges; /**< Inputs edge set */ 264*c217d954SCole Faust std::set<EdgeID> _output_edges; /**< Output edge set */ 265*c217d954SCole Faust Target _assigned_target; /**< Assigned target by the Graph executor */ 266*c217d954SCole Faust std::list<std::unique_ptr<ConvPostOpInfo>> _post_op_info_list; /**< Post operator info list */ 267*c217d954SCole Faust }; 268*c217d954SCole Faust } // namespace graph 269*c217d954SCole Faust } // namespace arm_compute 270*c217d954SCole Faust #endif /* 
ARM_COMPUTE_GRAPH_INODE_H */ 271