xref: /aosp_15_r20/external/ComputeLibrary/arm_compute/graph/frontend/Layers.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2018-2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_GRAPH_LAYERS_H
25*c217d954SCole Faust #define ARM_COMPUTE_GRAPH_LAYERS_H
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "arm_compute/graph/GraphBuilder.h"
28*c217d954SCole Faust #include "arm_compute/graph/Types.h"
29*c217d954SCole Faust #include "arm_compute/graph/frontend/ILayer.h"
30*c217d954SCole Faust #include "arm_compute/graph/frontend/IStream.h"
31*c217d954SCole Faust #include "arm_compute/graph/frontend/SubStream.h"
32*c217d954SCole Faust 
33*c217d954SCole Faust #include "arm_compute/core/utils/misc/Utility.h"
34*c217d954SCole Faust 
35*c217d954SCole Faust #include <memory>
36*c217d954SCole Faust #include <string>
37*c217d954SCole Faust 
38*c217d954SCole Faust namespace arm_compute
39*c217d954SCole Faust {
40*c217d954SCole Faust namespace graph
41*c217d954SCole Faust {
42*c217d954SCole Faust namespace frontend
43*c217d954SCole Faust {
44*c217d954SCole Faust /** Input Layer */
45*c217d954SCole Faust class InputLayer final : public ILayer
46*c217d954SCole Faust {
47*c217d954SCole Faust public:
48*c217d954SCole Faust     /** Construct an input layer.
49*c217d954SCole Faust      *
50*c217d954SCole Faust      * @param[in] desc     Description of input tensor.
51*c217d954SCole Faust      * @param[in] accessor Accessor to get input tensor data from.
52*c217d954SCole Faust      */
InputLayer(TensorDescriptor desc,ITensorAccessorUPtr accessor)53*c217d954SCole Faust     InputLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
54*c217d954SCole Faust         : _desc(desc), _accessor(std::move(accessor))
55*c217d954SCole Faust     {
56*c217d954SCole Faust     }
57*c217d954SCole Faust 
create_layer(IStream & s)58*c217d954SCole Faust     NodeID create_layer(IStream &s) override
59*c217d954SCole Faust     {
60*c217d954SCole Faust         NodeParams common_params = { name(), s.hints().target_hint };
61*c217d954SCole Faust         return GraphBuilder::add_input_node(s.graph(), common_params, _desc, std::move(_accessor));
62*c217d954SCole Faust     }
63*c217d954SCole Faust 
64*c217d954SCole Faust private:
65*c217d954SCole Faust     TensorDescriptor    _desc;
66*c217d954SCole Faust     ITensorAccessorUPtr _accessor;
67*c217d954SCole Faust };
68*c217d954SCole Faust 
69*c217d954SCole Faust /** Constant Layer */
70*c217d954SCole Faust class ConstantLayer final : public ILayer
71*c217d954SCole Faust {
72*c217d954SCole Faust public:
73*c217d954SCole Faust     /** Construct a constant layer.
74*c217d954SCole Faust      *
75*c217d954SCole Faust      * @param[in] desc     Description of input tensor.
76*c217d954SCole Faust      * @param[in] accessor Accessor to get input tensor data from.
77*c217d954SCole Faust      */
ConstantLayer(TensorDescriptor desc,ITensorAccessorUPtr accessor)78*c217d954SCole Faust     ConstantLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
79*c217d954SCole Faust         : _desc(desc), _accessor(std::move(accessor))
80*c217d954SCole Faust     {
81*c217d954SCole Faust     }
82*c217d954SCole Faust 
create_layer(IStream & s)83*c217d954SCole Faust     NodeID create_layer(IStream &s) override
84*c217d954SCole Faust     {
85*c217d954SCole Faust         NodeParams common_params = { name(), s.hints().target_hint };
86*c217d954SCole Faust         return GraphBuilder::add_const_node(s.graph(), common_params, _desc, std::move(_accessor));
87*c217d954SCole Faust     }
88*c217d954SCole Faust 
89*c217d954SCole Faust private:
90*c217d954SCole Faust     TensorDescriptor    _desc;
91*c217d954SCole Faust     ITensorAccessorUPtr _accessor;
92*c217d954SCole Faust };
93*c217d954SCole Faust 
94*c217d954SCole Faust /** Output Layer */
95*c217d954SCole Faust class OutputLayer final : public ILayer
96*c217d954SCole Faust {
97*c217d954SCole Faust public:
98*c217d954SCole Faust     /** Construct an output layer.
99*c217d954SCole Faust      *
100*c217d954SCole Faust      * @param[in] accessor       Accessor to give output tensor data to.
101*c217d954SCole Faust      * @param[in] connection_idx (Optional) Input connection index
102*c217d954SCole Faust      */
103*c217d954SCole Faust     OutputLayer(ITensorAccessorUPtr accessor, unsigned int connection_idx = 0)
_accessor(std::move (accessor))104*c217d954SCole Faust         : _accessor(std::move(accessor)), _connection_idx(connection_idx)
105*c217d954SCole Faust     {
106*c217d954SCole Faust     }
107*c217d954SCole Faust 
create_layer(IStream & s)108*c217d954SCole Faust     NodeID create_layer(IStream &s) override
109*c217d954SCole Faust     {
110*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
111*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), _connection_idx };
112*c217d954SCole Faust         return GraphBuilder::add_output_node(s.graph(), common_params, input, std::move(_accessor));
113*c217d954SCole Faust     }
114*c217d954SCole Faust 
115*c217d954SCole Faust private:
116*c217d954SCole Faust     ITensorAccessorUPtr _accessor;
117*c217d954SCole Faust     unsigned int        _connection_idx;
118*c217d954SCole Faust };
119*c217d954SCole Faust 
120*c217d954SCole Faust /** Activation Layer */
121*c217d954SCole Faust class ActivationLayer final : public ILayer
122*c217d954SCole Faust {
123*c217d954SCole Faust public:
124*c217d954SCole Faust     /** Construct an activation layer.
125*c217d954SCole Faust      *
126*c217d954SCole Faust      * @param[in] act_info       Activation information
127*c217d954SCole Faust      * @param[in] out_quant_info (Optional) Output quantization info
128*c217d954SCole Faust      */
129*c217d954SCole Faust     ActivationLayer(ActivationLayerInfo    act_info,
130*c217d954SCole Faust                     const QuantizationInfo out_quant_info = QuantizationInfo())
_act_info(act_info)131*c217d954SCole Faust         : _act_info(act_info),
132*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
133*c217d954SCole Faust     {
134*c217d954SCole Faust     }
135*c217d954SCole Faust 
create_layer(IStream & s)136*c217d954SCole Faust     NodeID create_layer(IStream &s) override
137*c217d954SCole Faust     {
138*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
139*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
140*c217d954SCole Faust         return GraphBuilder::add_activation_node(s.graph(), common_params, input, _act_info, std::move(_out_quant_info));
141*c217d954SCole Faust     }
142*c217d954SCole Faust 
143*c217d954SCole Faust private:
144*c217d954SCole Faust     ActivationLayerInfo    _act_info;
145*c217d954SCole Faust     const QuantizationInfo _out_quant_info;
146*c217d954SCole Faust };
147*c217d954SCole Faust 
148*c217d954SCole Faust /** ArgMinMax Layer */
149*c217d954SCole Faust class ArgMinMaxLayer final : public ILayer
150*c217d954SCole Faust {
151*c217d954SCole Faust public:
152*c217d954SCole Faust     /** Construct an activation layer.
153*c217d954SCole Faust      *
154*c217d954SCole Faust      * @param[in] op             Reduction Operation: min or max
155*c217d954SCole Faust      * @param[in] axis           Axis to perform reduction along
156*c217d954SCole Faust      * @param[in] out_data_type  (Optional) Output tensor data type
157*c217d954SCole Faust      * @param[in] out_quant_info (Optional) Output quantization info
158*c217d954SCole Faust      */
159*c217d954SCole Faust     ArgMinMaxLayer(ReductionOperation     op,
160*c217d954SCole Faust                    unsigned int           axis,
161*c217d954SCole Faust                    DataType               out_data_type  = DataType::UNKNOWN,
162*c217d954SCole Faust                    const QuantizationInfo out_quant_info = QuantizationInfo())
_op(op)163*c217d954SCole Faust         : _op(op),
164*c217d954SCole Faust           _axis(axis),
165*c217d954SCole Faust           _out_data_type(out_data_type),
166*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
167*c217d954SCole Faust     {
168*c217d954SCole Faust     }
169*c217d954SCole Faust 
170*c217d954SCole Faust     /** Create layer and add to the given stream.
171*c217d954SCole Faust      *
172*c217d954SCole Faust      * @param[in] s Stream to add layer to.
173*c217d954SCole Faust      *
174*c217d954SCole Faust      * @return ID of the created node.
175*c217d954SCole Faust      */
create_layer(IStream & s)176*c217d954SCole Faust     NodeID create_layer(IStream &s) override
177*c217d954SCole Faust     {
178*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
179*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
180*c217d954SCole Faust         return GraphBuilder::add_arg_min_max_node(s.graph(), common_params, input, _op, _axis, _out_data_type, std::move(_out_quant_info));
181*c217d954SCole Faust     }
182*c217d954SCole Faust 
183*c217d954SCole Faust private:
184*c217d954SCole Faust     ReductionOperation _op;
185*c217d954SCole Faust     unsigned int       _axis;
186*c217d954SCole Faust     DataType           _out_data_type;
187*c217d954SCole Faust     QuantizationInfo   _out_quant_info;
188*c217d954SCole Faust };
189*c217d954SCole Faust 
190*c217d954SCole Faust /** Batchnormalization Layer */
191*c217d954SCole Faust class BatchNormalizationLayer final : public ILayer
192*c217d954SCole Faust {
193*c217d954SCole Faust public:
194*c217d954SCole Faust     /** Construct a batch normalization layer.
195*c217d954SCole Faust      *
196*c217d954SCole Faust      * @param[in] mean    Accessor to get mean tensor data from.
197*c217d954SCole Faust      * @param[in] var     Accessor to get var tensor data from.
198*c217d954SCole Faust      * @param[in] gamma   (Optional) Accessor to get gamma tensor data from. Default: nullptr.
199*c217d954SCole Faust      * @param[in] beta    (Optional) Accessor to get beta tensor data from. Default: nullptr.
200*c217d954SCole Faust      * @param[in] epsilon (Optional) Epsilon value. Default: 0.001.
201*c217d954SCole Faust      */
202*c217d954SCole Faust     BatchNormalizationLayer(ITensorAccessorUPtr mean,
203*c217d954SCole Faust                             ITensorAccessorUPtr var,
204*c217d954SCole Faust                             ITensorAccessorUPtr gamma   = nullptr,
205*c217d954SCole Faust                             ITensorAccessorUPtr beta    = nullptr,
206*c217d954SCole Faust                             float               epsilon = 0.001f)
_mean(std::move (mean))207*c217d954SCole Faust         : _mean(std::move(mean)), _var(std::move(var)), _gamma(std::move(gamma)), _beta(std::move(beta)), _epsilon(epsilon)
208*c217d954SCole Faust     {
209*c217d954SCole Faust     }
210*c217d954SCole Faust 
create_layer(IStream & s)211*c217d954SCole Faust     NodeID create_layer(IStream &s) override
212*c217d954SCole Faust     {
213*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_mean == nullptr);
214*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_var == nullptr);
215*c217d954SCole Faust 
216*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
217*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
218*c217d954SCole Faust         return GraphBuilder::add_batch_normalization_node(s.graph(), common_params, input, _epsilon,
219*c217d954SCole Faust                                                           std::move(_mean), std::move(_var), std::move(_beta), std::move(_gamma));
220*c217d954SCole Faust     }
221*c217d954SCole Faust 
222*c217d954SCole Faust private:
223*c217d954SCole Faust     ITensorAccessorUPtr _mean;
224*c217d954SCole Faust     ITensorAccessorUPtr _var;
225*c217d954SCole Faust     ITensorAccessorUPtr _gamma;
226*c217d954SCole Faust     ITensorAccessorUPtr _beta;
227*c217d954SCole Faust     float               _epsilon;
228*c217d954SCole Faust };
229*c217d954SCole Faust 
230*c217d954SCole Faust /** Bounding Box Transform Layer */
231*c217d954SCole Faust class BoundingBoxTransformLayer final : public ILayer
232*c217d954SCole Faust {
233*c217d954SCole Faust public:
234*c217d954SCole Faust     /** Construct a bounding box transform layer.
235*c217d954SCole Faust      *
236*c217d954SCole Faust      * @param[in] sub_stream_input  Graph sub-stream for the input
237*c217d954SCole Faust      * @param[in] sub_stream_deltas Graph sub-stream for the deltas
238*c217d954SCole Faust      * @param[in] info              Contains BoundingBox operation information described in @ref BoundingBoxTransformInfo.
239*c217d954SCole Faust      */
BoundingBoxTransformLayer(SubStream && sub_stream_input,SubStream && sub_stream_deltas,BoundingBoxTransformInfo info)240*c217d954SCole Faust     BoundingBoxTransformLayer(SubStream &&sub_stream_input, SubStream &&sub_stream_deltas, BoundingBoxTransformInfo info)
241*c217d954SCole Faust         : _ss_input(sub_stream_input), _ss_deltas(sub_stream_deltas), _bbox_info(info)
242*c217d954SCole Faust     {
243*c217d954SCole Faust     }
244*c217d954SCole Faust 
245*c217d954SCole Faust     /** Create layer and add to the given stream.
246*c217d954SCole Faust      *
247*c217d954SCole Faust      * @param[in] s Stream to add layer to.
248*c217d954SCole Faust      *
249*c217d954SCole Faust      * @return ID of the created node.
250*c217d954SCole Faust      */
create_layer(IStream & s)251*c217d954SCole Faust     NodeID create_layer(IStream &s) override
252*c217d954SCole Faust     {
253*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
254*c217d954SCole Faust         NodeIdxPair input         = { _ss_input.tail_node(), 0 };
255*c217d954SCole Faust         NodeIdxPair deltas        = { _ss_deltas.tail_node(), 0 };
256*c217d954SCole Faust         return GraphBuilder::add_bounding_box_transform_node(s.graph(), common_params, input, deltas, _bbox_info);
257*c217d954SCole Faust     }
258*c217d954SCole Faust 
259*c217d954SCole Faust private:
260*c217d954SCole Faust     SubStream                _ss_input;
261*c217d954SCole Faust     SubStream                _ss_deltas;
262*c217d954SCole Faust     BoundingBoxTransformInfo _bbox_info;
263*c217d954SCole Faust };
264*c217d954SCole Faust 
265*c217d954SCole Faust /** Channel Shuffle Layer */
266*c217d954SCole Faust class ChannelShuffleLayer final : public ILayer
267*c217d954SCole Faust {
268*c217d954SCole Faust public:
269*c217d954SCole Faust     /** Construct a Channel Shuffle layer.
270*c217d954SCole Faust      *
271*c217d954SCole Faust      * @param[in] num_groups Number of groups
272*c217d954SCole Faust      */
ChannelShuffleLayer(unsigned int num_groups)273*c217d954SCole Faust     ChannelShuffleLayer(unsigned int num_groups)
274*c217d954SCole Faust         : _num_groups(num_groups)
275*c217d954SCole Faust     {
276*c217d954SCole Faust     }
277*c217d954SCole Faust 
create_layer(IStream & s)278*c217d954SCole Faust     NodeID create_layer(IStream &s) override
279*c217d954SCole Faust     {
280*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
281*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
282*c217d954SCole Faust         return GraphBuilder::add_channel_shuffle_node(s.graph(), common_params, input, _num_groups);
283*c217d954SCole Faust     }
284*c217d954SCole Faust 
285*c217d954SCole Faust private:
286*c217d954SCole Faust     unsigned int _num_groups;
287*c217d954SCole Faust };
288*c217d954SCole Faust 
289*c217d954SCole Faust /** Concat Layer */
290*c217d954SCole Faust class ConcatLayer final : public ILayer
291*c217d954SCole Faust {
292*c217d954SCole Faust public:
293*c217d954SCole Faust     /** Construct a concatenation layer
294*c217d954SCole Faust      *
295*c217d954SCole Faust      * @param[in] sub_stream1      First graph branch
296*c217d954SCole Faust      * @param[in] sub_stream2      Second graph branch
297*c217d954SCole Faust      * @param[in] rest_sub_streams Rest sub-graph branches
298*c217d954SCole Faust      */
299*c217d954SCole Faust     template <typename... Ts>
ConcatLayer(SubStream && sub_stream1,SubStream && sub_stream2,Ts &&...rest_sub_streams)300*c217d954SCole Faust     ConcatLayer(SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
301*c217d954SCole Faust         : _sub_streams(), _concat_descriptor(DataLayoutDimension::CHANNEL)
302*c217d954SCole Faust     {
303*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream1)));
304*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream2)));
305*c217d954SCole Faust 
306*c217d954SCole Faust         utility::for_each([&](SubStream && sub_stream)
307*c217d954SCole Faust         {
308*c217d954SCole Faust             _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
309*c217d954SCole Faust         },
310*c217d954SCole Faust         std::move(rest_sub_streams)...);
311*c217d954SCole Faust     }
312*c217d954SCole Faust     /** Construct a concatenation layer
313*c217d954SCole Faust      *
314*c217d954SCole Faust      * @param[in] concat_descriptor Concat layer descriptor
315*c217d954SCole Faust      * @param[in] sub_stream1       First graph branch
316*c217d954SCole Faust      * @param[in] sub_stream2       Second graph branch
317*c217d954SCole Faust      * @param[in] rest_sub_streams  Rest sub-graph branches
318*c217d954SCole Faust      */
319*c217d954SCole Faust     template <typename... Ts>
ConcatLayer(descriptors::ConcatLayerDescriptor concat_descriptor,SubStream && sub_stream1,SubStream && sub_stream2,Ts &&...rest_sub_streams)320*c217d954SCole Faust     ConcatLayer(descriptors::ConcatLayerDescriptor concat_descriptor, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
321*c217d954SCole Faust         : _sub_streams(), _concat_descriptor(concat_descriptor)
322*c217d954SCole Faust     {
323*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream1)));
324*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream2)));
325*c217d954SCole Faust 
326*c217d954SCole Faust         utility::for_each([&](SubStream && sub_stream)
327*c217d954SCole Faust         {
328*c217d954SCole Faust             _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
329*c217d954SCole Faust         },
330*c217d954SCole Faust         std::move(rest_sub_streams)...);
331*c217d954SCole Faust     }
332*c217d954SCole Faust     /** Construct a concat layer
333*c217d954SCole Faust      *
334*c217d954SCole Faust      * @param[in] sub_stream Sub-stream
335*c217d954SCole Faust      */
336*c217d954SCole Faust     template <typename... Ts>
ConcatLayer(SubStream && sub_stream)337*c217d954SCole Faust     ConcatLayer(SubStream &&sub_stream)
338*c217d954SCole Faust         : _sub_streams(), _concat_descriptor(DataLayoutDimension::CHANNEL)
339*c217d954SCole Faust     {
340*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
341*c217d954SCole Faust     }
create_layer(IStream & s)342*c217d954SCole Faust     NodeID create_layer(IStream &s) override
343*c217d954SCole Faust     {
344*c217d954SCole Faust         NodeID     nid           = EmptyNodeID;
345*c217d954SCole Faust         NodeParams common_params = { name(), s.hints().target_hint };
346*c217d954SCole Faust         if(_sub_streams.size() == 1 && _sub_streams.at(0) != nullptr)
347*c217d954SCole Faust         {
348*c217d954SCole Faust             nid = _sub_streams[0]->tail_node();
349*c217d954SCole Faust         }
350*c217d954SCole Faust         else
351*c217d954SCole Faust         {
352*c217d954SCole Faust             // Collect tail nodes and concatenate
353*c217d954SCole Faust             std::vector<NodeIdxPair> nodes;
354*c217d954SCole Faust             for(auto &ss : _sub_streams)
355*c217d954SCole Faust             {
356*c217d954SCole Faust                 if(ss && (ss->tail_node() != EmptyNodeID))
357*c217d954SCole Faust                 {
358*c217d954SCole Faust                     const auto tail_node = s.graph().node(ss->tail_node());
359*c217d954SCole Faust                     if(tail_node != nullptr && tail_node->type() != NodeType::Output)
360*c217d954SCole Faust                     {
361*c217d954SCole Faust                         nodes.push_back({ ss->tail_node(), 0 });
362*c217d954SCole Faust                     }
363*c217d954SCole Faust                 }
364*c217d954SCole Faust             }
365*c217d954SCole Faust             nid = GraphBuilder::add_concatenate_node(s.graph(), common_params, nodes, _concat_descriptor);
366*c217d954SCole Faust         }
367*c217d954SCole Faust         return nid;
368*c217d954SCole Faust     }
369*c217d954SCole Faust 
370*c217d954SCole Faust private:
371*c217d954SCole Faust     std::vector<std::unique_ptr<SubStream>> _sub_streams;
372*c217d954SCole Faust     descriptors::ConcatLayerDescriptor      _concat_descriptor;
373*c217d954SCole Faust };
374*c217d954SCole Faust 
375*c217d954SCole Faust /** Convolution Layer */
376*c217d954SCole Faust class ConvolutionLayer final : public ILayer
377*c217d954SCole Faust {
378*c217d954SCole Faust public:
379*c217d954SCole Faust     /** Construct a convolution layer.
380*c217d954SCole Faust      *
381*c217d954SCole Faust      * @param[in] conv_width         Convolution width.
382*c217d954SCole Faust      * @param[in] conv_height        Convolution height.
383*c217d954SCole Faust      * @param[in] ofm                Output feature map.
384*c217d954SCole Faust      * @param[in] weights            Accessor to get kernel weights from.
385*c217d954SCole Faust      * @param[in] bias               Accessor to get kernel bias from.
386*c217d954SCole Faust      * @param[in] conv_info          Padding and stride information.
387*c217d954SCole Faust      * @param[in] num_groups         (Optional) Number of groups. Default: 1.
388*c217d954SCole Faust      * @param[in] weights_quant_info (Optional) Weights quantization information
389*c217d954SCole Faust      * @param[in] out_quant_info     (Optional) Output quantization info
390*c217d954SCole Faust      */
391*c217d954SCole Faust     ConvolutionLayer(unsigned int           conv_width,
392*c217d954SCole Faust                      unsigned int           conv_height,
393*c217d954SCole Faust                      unsigned int           ofm,
394*c217d954SCole Faust                      ITensorAccessorUPtr    weights,
395*c217d954SCole Faust                      ITensorAccessorUPtr    bias,
396*c217d954SCole Faust                      PadStrideInfo          conv_info,
397*c217d954SCole Faust                      unsigned int           num_groups         = 1,
398*c217d954SCole Faust                      const QuantizationInfo weights_quant_info = QuantizationInfo(),
399*c217d954SCole Faust                      const QuantizationInfo out_quant_info     = QuantizationInfo())
_conv_width(conv_width)400*c217d954SCole Faust         : _conv_width(conv_width),
401*c217d954SCole Faust           _conv_height(conv_height),
402*c217d954SCole Faust           _ofm(ofm),
403*c217d954SCole Faust           _conv_info(std::move(conv_info)),
404*c217d954SCole Faust           _num_groups(num_groups),
405*c217d954SCole Faust           _weights(std::move(weights)),
406*c217d954SCole Faust           _bias(std::move(bias)),
407*c217d954SCole Faust           _weights_quant_info(std::move(weights_quant_info)),
408*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
409*c217d954SCole Faust     {
410*c217d954SCole Faust     }
411*c217d954SCole Faust 
create_layer(IStream & s)412*c217d954SCole Faust     NodeID create_layer(IStream &s) override
413*c217d954SCole Faust     {
414*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
415*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
416*c217d954SCole Faust         return GraphBuilder::add_convolution_node(s.graph(), common_params, input,
417*c217d954SCole Faust                                                   Size2D(_conv_width, _conv_height), _ofm, _conv_info, _num_groups,
418*c217d954SCole Faust                                                   s.hints().convolution_method_hint, s.hints().fast_math_hint,
419*c217d954SCole Faust                                                   std::move(_weights), std::move(_bias), std::move(_weights_quant_info), std::move(_out_quant_info));
420*c217d954SCole Faust     }
421*c217d954SCole Faust 
422*c217d954SCole Faust private:
423*c217d954SCole Faust     unsigned int           _conv_width;
424*c217d954SCole Faust     unsigned int           _conv_height;
425*c217d954SCole Faust     unsigned int           _ofm;
426*c217d954SCole Faust     const PadStrideInfo    _conv_info;
427*c217d954SCole Faust     unsigned int           _num_groups;
428*c217d954SCole Faust     ITensorAccessorUPtr    _weights;
429*c217d954SCole Faust     ITensorAccessorUPtr    _bias;
430*c217d954SCole Faust     const QuantizationInfo _weights_quant_info;
431*c217d954SCole Faust     const QuantizationInfo _out_quant_info;
432*c217d954SCole Faust };
433*c217d954SCole Faust 
434*c217d954SCole Faust /** Deconvolution Layer */
435*c217d954SCole Faust class DeconvolutionLayer final : public ILayer
436*c217d954SCole Faust {
437*c217d954SCole Faust public:
438*c217d954SCole Faust     /** Construct a convolution layer.
439*c217d954SCole Faust      *
440*c217d954SCole Faust      * @param[in] conv_width  Convolution width.
441*c217d954SCole Faust      * @param[in] conv_height Convolution height.
442*c217d954SCole Faust      * @param[in] ofm         Output feature map.
443*c217d954SCole Faust      * @param[in] weights     Accessor to get kernel weights from.
444*c217d954SCole Faust      * @param[in] bias        Accessor to get kernel bias from.
445*c217d954SCole Faust      * @param[in] deconv_info Padding and stride information.
446*c217d954SCole Faust      */
DeconvolutionLayer(unsigned int conv_width,unsigned int conv_height,unsigned int ofm,ITensorAccessorUPtr weights,ITensorAccessorUPtr bias,PadStrideInfo deconv_info)447*c217d954SCole Faust     DeconvolutionLayer(unsigned int        conv_width,
448*c217d954SCole Faust                        unsigned int        conv_height,
449*c217d954SCole Faust                        unsigned int        ofm,
450*c217d954SCole Faust                        ITensorAccessorUPtr weights,
451*c217d954SCole Faust                        ITensorAccessorUPtr bias,
452*c217d954SCole Faust                        PadStrideInfo       deconv_info)
453*c217d954SCole Faust         : _conv_width(conv_width),
454*c217d954SCole Faust           _conv_height(conv_height),
455*c217d954SCole Faust           _ofm(ofm),
456*c217d954SCole Faust           _deconv_info(std::move(deconv_info)),
457*c217d954SCole Faust           _weights(std::move(weights)),
458*c217d954SCole Faust           _bias(std::move(bias))
459*c217d954SCole Faust     {
460*c217d954SCole Faust     }
461*c217d954SCole Faust 
create_layer(IStream & s)462*c217d954SCole Faust     NodeID create_layer(IStream &s) override
463*c217d954SCole Faust     {
464*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
465*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
466*c217d954SCole Faust         return GraphBuilder::add_deconvolution_node(s.graph(), common_params, input,
467*c217d954SCole Faust                                                     Size2D(_conv_width, _conv_height), _ofm, _deconv_info,
468*c217d954SCole Faust                                                     std::move(_weights), std::move(_bias));
469*c217d954SCole Faust     }
470*c217d954SCole Faust 
471*c217d954SCole Faust private:
472*c217d954SCole Faust     unsigned int        _conv_width;
473*c217d954SCole Faust     unsigned int        _conv_height;
474*c217d954SCole Faust     unsigned int        _ofm;
475*c217d954SCole Faust     const PadStrideInfo _deconv_info;
476*c217d954SCole Faust     ITensorAccessorUPtr _weights;
477*c217d954SCole Faust     ITensorAccessorUPtr _bias;
478*c217d954SCole Faust };
479*c217d954SCole Faust 
480*c217d954SCole Faust /** Depthwise Convolution Layer */
481*c217d954SCole Faust class DepthwiseConvolutionLayer final : public ILayer
482*c217d954SCole Faust {
483*c217d954SCole Faust public:
484*c217d954SCole Faust     /** Construct a depthwise convolution layer.
485*c217d954SCole Faust      *
486*c217d954SCole Faust      * @param[in] conv_width         Convolution width.
487*c217d954SCole Faust      * @param[in] conv_height        Convolution height.
488*c217d954SCole Faust      * @param[in] weights            Accessor to get kernel weights from.
489*c217d954SCole Faust      * @param[in] bias               Accessor to get kernel bias from.
490*c217d954SCole Faust      * @param[in] conv_info          Padding and stride information.
491*c217d954SCole Faust      * @param[in] depth_multiplier   (Optional) Depth multiplier parameter.
492*c217d954SCole Faust      * @param[in] weights_quant_info (Optional) Quantization info used for weights
493*c217d954SCole Faust      * @param[in] out_quant_info     (Optional) Output quantization info
494*c217d954SCole Faust      */
495*c217d954SCole Faust     DepthwiseConvolutionLayer(unsigned int           conv_width,
496*c217d954SCole Faust                               unsigned int           conv_height,
497*c217d954SCole Faust                               ITensorAccessorUPtr    weights,
498*c217d954SCole Faust                               ITensorAccessorUPtr    bias,
499*c217d954SCole Faust                               PadStrideInfo          conv_info,
500*c217d954SCole Faust                               int                    depth_multiplier   = 1,
501*c217d954SCole Faust                               const QuantizationInfo weights_quant_info = QuantizationInfo(),
502*c217d954SCole Faust                               const QuantizationInfo out_quant_info     = QuantizationInfo())
_conv_width(conv_width)503*c217d954SCole Faust         : _conv_width(conv_width),
504*c217d954SCole Faust           _conv_height(conv_height),
505*c217d954SCole Faust           _conv_info(std::move(conv_info)),
506*c217d954SCole Faust           _weights(std::move(weights)),
507*c217d954SCole Faust           _bias(std::move(bias)),
508*c217d954SCole Faust           _depth_multiplier(depth_multiplier),
509*c217d954SCole Faust           _weights_quant_info(std::move(weights_quant_info)),
510*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
511*c217d954SCole Faust     {
512*c217d954SCole Faust     }
513*c217d954SCole Faust 
create_layer(IStream & s)514*c217d954SCole Faust     NodeID create_layer(IStream &s) override
515*c217d954SCole Faust     {
516*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
517*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
518*c217d954SCole Faust         return GraphBuilder::add_depthwise_convolution_node(s.graph(), common_params,
519*c217d954SCole Faust                                                             input, Size2D(_conv_width, _conv_height), _conv_info, _depth_multiplier,
520*c217d954SCole Faust                                                             s.hints().depthwise_convolution_method_hint,
521*c217d954SCole Faust                                                             std::move(_weights), std::move(_bias), std::move(_weights_quant_info), std::move(_out_quant_info));
522*c217d954SCole Faust     }
523*c217d954SCole Faust 
524*c217d954SCole Faust private:
525*c217d954SCole Faust     unsigned int           _conv_width;
526*c217d954SCole Faust     unsigned int           _conv_height;
527*c217d954SCole Faust     const PadStrideInfo    _conv_info;
528*c217d954SCole Faust     ITensorAccessorUPtr    _weights;
529*c217d954SCole Faust     ITensorAccessorUPtr    _bias;
530*c217d954SCole Faust     int                    _depth_multiplier;
531*c217d954SCole Faust     const QuantizationInfo _weights_quant_info;
532*c217d954SCole Faust     const QuantizationInfo _out_quant_info;
533*c217d954SCole Faust };
534*c217d954SCole Faust 
535*c217d954SCole Faust /** DepthToSpace Layer */
536*c217d954SCole Faust class DepthToSpaceLayer final : public ILayer
537*c217d954SCole Faust {
538*c217d954SCole Faust public:
539*c217d954SCole Faust     /** Construct an DepthToSpace layer.
540*c217d954SCole Faust      *
541*c217d954SCole Faust      * @param[in] block_shape Block size to rearranged
542*c217d954SCole Faust      */
DepthToSpaceLayer(int32_t block_shape)543*c217d954SCole Faust     DepthToSpaceLayer(int32_t block_shape)
544*c217d954SCole Faust         : _block_shape(block_shape)
545*c217d954SCole Faust     {
546*c217d954SCole Faust     }
547*c217d954SCole Faust 
create_layer(IStream & s)548*c217d954SCole Faust     NodeID create_layer(IStream &s) override
549*c217d954SCole Faust     {
550*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
551*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
552*c217d954SCole Faust         return GraphBuilder::add_depth_to_space_node(s.graph(), common_params, input, _block_shape);
553*c217d954SCole Faust     }
554*c217d954SCole Faust 
555*c217d954SCole Faust private:
556*c217d954SCole Faust     int32_t _block_shape;
557*c217d954SCole Faust };
558*c217d954SCole Faust 
559*c217d954SCole Faust /** Dequantization Layer */
560*c217d954SCole Faust class DequantizationLayer final : public ILayer
561*c217d954SCole Faust {
562*c217d954SCole Faust public:
563*c217d954SCole Faust     /** Construct a dequantization layer.
564*c217d954SCole Faust      *
565*c217d954SCole Faust      */
DequantizationLayer()566*c217d954SCole Faust     DequantizationLayer()
567*c217d954SCole Faust     {
568*c217d954SCole Faust     }
569*c217d954SCole Faust 
create_layer(IStream & s)570*c217d954SCole Faust     NodeID create_layer(IStream &s) override
571*c217d954SCole Faust     {
572*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
573*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
574*c217d954SCole Faust         return GraphBuilder::add_dequantization_node(s.graph(), common_params, input);
575*c217d954SCole Faust     }
576*c217d954SCole Faust };
577*c217d954SCole Faust 
578*c217d954SCole Faust /** DetectionOutput Layer */
579*c217d954SCole Faust class DetectionOutputLayer final : public ILayer
580*c217d954SCole Faust {
581*c217d954SCole Faust public:
582*c217d954SCole Faust     /** Construct a detection output layer.
583*c217d954SCole Faust      *
584*c217d954SCole Faust      * @param[in] sub_stream_conf  Confidence graph sub-stream.
585*c217d954SCole Faust      * @param[in] sub_stream_prior PriorBox graph sub-stream.
586*c217d954SCole Faust      * @param[in] detect_info      DetectionOutput parameters.
587*c217d954SCole Faust      */
DetectionOutputLayer(SubStream && sub_stream_conf,SubStream && sub_stream_prior,const DetectionOutputLayerInfo & detect_info)588*c217d954SCole Faust     DetectionOutputLayer(SubStream &&sub_stream_conf, SubStream &&sub_stream_prior, const DetectionOutputLayerInfo &detect_info)
589*c217d954SCole Faust         : _ss_conf(std::move(sub_stream_conf)), _ss_prior(std::move(sub_stream_prior)), _detect_info(detect_info)
590*c217d954SCole Faust     {
591*c217d954SCole Faust     }
592*c217d954SCole Faust 
create_layer(IStream & s)593*c217d954SCole Faust     NodeID create_layer(IStream &s) override
594*c217d954SCole Faust     {
595*c217d954SCole Faust         NodeParams  common_params  = { name(), s.hints().target_hint };
596*c217d954SCole Faust         NodeIdxPair input_loc      = { s.tail_node(), 0 };
597*c217d954SCole Faust         NodeIdxPair input_conf     = { _ss_conf.tail_node(), 0 };
598*c217d954SCole Faust         NodeIdxPair input_priorbox = { _ss_prior.tail_node(), 0 };
599*c217d954SCole Faust         return GraphBuilder::add_detection_output_node(s.graph(), common_params, input_loc, input_conf, input_priorbox, _detect_info);
600*c217d954SCole Faust     }
601*c217d954SCole Faust 
602*c217d954SCole Faust private:
603*c217d954SCole Faust     SubStream                _ss_conf;
604*c217d954SCole Faust     SubStream                _ss_prior;
605*c217d954SCole Faust     DetectionOutputLayerInfo _detect_info;
606*c217d954SCole Faust };
607*c217d954SCole Faust /** DetectionOutputPostProcess Layer */
608*c217d954SCole Faust class DetectionPostProcessLayer final : public ILayer
609*c217d954SCole Faust {
610*c217d954SCole Faust public:
611*c217d954SCole Faust     /** Construct a detection output layer.
612*c217d954SCole Faust      *
613*c217d954SCole Faust      * @param[in] sub_stream_class_prediction Class prediction graph sub-stream.
614*c217d954SCole Faust      * @param[in] detect_info                 DetectionOutput parameters.
615*c217d954SCole Faust      * @param[in] anchors                     Accessor to get anchors tensor data from.
616*c217d954SCole Faust      * @param[in] out_quant_info              (Optional) Output quantization info
617*c217d954SCole Faust      */
618*c217d954SCole Faust     DetectionPostProcessLayer(SubStream &&sub_stream_class_prediction, DetectionPostProcessLayerInfo detect_info, ITensorAccessorUPtr anchors,
619*c217d954SCole Faust                               const QuantizationInfo out_quant_info = QuantizationInfo())
_sub_stream_class_prediction(std::move (sub_stream_class_prediction))620*c217d954SCole Faust         : _sub_stream_class_prediction(std::move(sub_stream_class_prediction)), _detect_info(detect_info), _anchors(std::move(anchors)), _out_quant_info(std::move(out_quant_info))
621*c217d954SCole Faust     {
622*c217d954SCole Faust     }
623*c217d954SCole Faust 
create_layer(IStream & s)624*c217d954SCole Faust     NodeID create_layer(IStream &s) override
625*c217d954SCole Faust     {
626*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_anchors == nullptr);
627*c217d954SCole Faust 
628*c217d954SCole Faust         NodeParams  common_params          = { name(), s.hints().target_hint };
629*c217d954SCole Faust         NodeIdxPair input_box_encoding     = { s.tail_node(), 0 };
630*c217d954SCole Faust         NodeIdxPair input_class_prediction = { _sub_stream_class_prediction.tail_node(), 0 };
631*c217d954SCole Faust         return GraphBuilder::add_detection_post_process_node(s.graph(), common_params, input_box_encoding, input_class_prediction, _detect_info, std::move(_anchors), std::move(_out_quant_info));
632*c217d954SCole Faust     }
633*c217d954SCole Faust 
634*c217d954SCole Faust private:
635*c217d954SCole Faust     SubStream                     _sub_stream_class_prediction;
636*c217d954SCole Faust     DetectionPostProcessLayerInfo _detect_info;
637*c217d954SCole Faust     ITensorAccessorUPtr           _anchors;
638*c217d954SCole Faust     const QuantizationInfo        _out_quant_info;
639*c217d954SCole Faust };
640*c217d954SCole Faust /** Dummy Layer */
641*c217d954SCole Faust class DummyLayer final : public ILayer
642*c217d954SCole Faust {
643*c217d954SCole Faust public:
644*c217d954SCole Faust     /** Construct a dummy layer.
645*c217d954SCole Faust      *
646*c217d954SCole Faust      * @param[in] shape Output shape
647*c217d954SCole Faust      */
DummyLayer(TensorShape shape)648*c217d954SCole Faust     DummyLayer(TensorShape shape)
649*c217d954SCole Faust         : _shape(shape)
650*c217d954SCole Faust     {
651*c217d954SCole Faust     }
652*c217d954SCole Faust 
create_layer(IStream & s)653*c217d954SCole Faust     NodeID create_layer(IStream &s) override
654*c217d954SCole Faust     {
655*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
656*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
657*c217d954SCole Faust         return GraphBuilder::add_dummy_node(s.graph(), common_params, input, _shape);
658*c217d954SCole Faust     }
659*c217d954SCole Faust 
660*c217d954SCole Faust private:
661*c217d954SCole Faust     TensorShape _shape;
662*c217d954SCole Faust };
663*c217d954SCole Faust 
664*c217d954SCole Faust class EltwiseLayer final : public ILayer
665*c217d954SCole Faust {
666*c217d954SCole Faust public:
667*c217d954SCole Faust     /** Construct an element-wise operation layer
668*c217d954SCole Faust      *
669*c217d954SCole Faust      * @param[in] sub_stream0 First graph sub-stream
670*c217d954SCole Faust      * @param[in] sub_stream1 First graph sub-stream
671*c217d954SCole Faust      * @param[in] op          Element-wise operation to perform
672*c217d954SCole Faust      */
EltwiseLayer(SubStream && sub_stream0,SubStream && sub_stream1,EltwiseOperation op)673*c217d954SCole Faust     EltwiseLayer(SubStream &&sub_stream0, SubStream &&sub_stream1, EltwiseOperation op)
674*c217d954SCole Faust         : _ss0(std::move(sub_stream0)), _ss1(std::move(sub_stream1)), _op(op)
675*c217d954SCole Faust     {
676*c217d954SCole Faust     }
677*c217d954SCole Faust 
create_layer(IStream & s)678*c217d954SCole Faust     NodeID create_layer(IStream &s) override
679*c217d954SCole Faust     {
680*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
681*c217d954SCole Faust         NodeIdxPair input0        = { _ss0.tail_node(), 0 };
682*c217d954SCole Faust         NodeIdxPair input1        = { _ss1.tail_node(), 0 };
683*c217d954SCole Faust 
684*c217d954SCole Faust         return GraphBuilder::add_elementwise_node(s.graph(), common_params, input0, input1, _op);
685*c217d954SCole Faust     }
686*c217d954SCole Faust 
687*c217d954SCole Faust private:
688*c217d954SCole Faust     SubStream        _ss0;
689*c217d954SCole Faust     SubStream        _ss1;
690*c217d954SCole Faust     EltwiseOperation _op;
691*c217d954SCole Faust };
692*c217d954SCole Faust /** Flatten Layer */
693*c217d954SCole Faust class FlattenLayer final : public ILayer
694*c217d954SCole Faust {
695*c217d954SCole Faust public:
696*c217d954SCole Faust     /** Construct a flatten layer. */
FlattenLayer()697*c217d954SCole Faust     FlattenLayer()
698*c217d954SCole Faust     {
699*c217d954SCole Faust     }
700*c217d954SCole Faust 
create_layer(IStream & s)701*c217d954SCole Faust     NodeID create_layer(IStream &s) override
702*c217d954SCole Faust     {
703*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
704*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
705*c217d954SCole Faust         return GraphBuilder::add_flatten_node(s.graph(), common_params, input);
706*c217d954SCole Faust     }
707*c217d954SCole Faust };
708*c217d954SCole Faust 
709*c217d954SCole Faust /** Fully Connected Layer */
710*c217d954SCole Faust class FullyConnectedLayer final : public ILayer
711*c217d954SCole Faust {
712*c217d954SCole Faust public:
713*c217d954SCole Faust     /** Construct a fully connected layer.
714*c217d954SCole Faust      *
715*c217d954SCole Faust      * @param[in] num_outputs        Number of outputs.
716*c217d954SCole Faust      * @param[in] weights            Accessor to get weights from.
717*c217d954SCole Faust      * @param[in] bias               Accessor to get bias from.
718*c217d954SCole Faust      * @param[in] fc_info            (Optional) Fully connected layer metadata
719*c217d954SCole Faust      * @param[in] weights_quant_info (Optional) Weights quantization information
720*c217d954SCole Faust      * @param[in] out_quant_info     (Optional) Output quantization info
721*c217d954SCole Faust      */
722*c217d954SCole Faust     FullyConnectedLayer(unsigned int                  num_outputs,
723*c217d954SCole Faust                         ITensorAccessorUPtr           weights,
724*c217d954SCole Faust                         ITensorAccessorUPtr           bias,
725*c217d954SCole Faust                         const FullyConnectedLayerInfo fc_info            = FullyConnectedLayerInfo(),
726*c217d954SCole Faust                         const QuantizationInfo        weights_quant_info = QuantizationInfo(),
727*c217d954SCole Faust                         const QuantizationInfo        out_quant_info     = QuantizationInfo())
_num_outputs(num_outputs)728*c217d954SCole Faust         : _num_outputs(num_outputs),
729*c217d954SCole Faust           _weights(std::move(weights)),
730*c217d954SCole Faust           _bias(std::move(bias)),
731*c217d954SCole Faust           _weights_ss(nullptr),
732*c217d954SCole Faust           _bias_ss(nullptr),
733*c217d954SCole Faust           _fc_info(fc_info),
734*c217d954SCole Faust           _weights_quant_info(std::move(weights_quant_info)),
735*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
736*c217d954SCole Faust     {
737*c217d954SCole Faust     }
738*c217d954SCole Faust 
739*c217d954SCole Faust     /** Construct a fully connected layer.
740*c217d954SCole Faust      *
741*c217d954SCole Faust      * @param[in] num_outputs        Number of outputs.
742*c217d954SCole Faust      * @param[in] sub_stream_weights Graph sub-stream for the weights.
743*c217d954SCole Faust      * @param[in] sub_stream_bias    Graph sub-stream for the bias.
744*c217d954SCole Faust      * @param[in] fc_info            (Optional) Fully connected layer metadata
745*c217d954SCole Faust      * @param[in] weights_quant_info (Optional) Weights quantization information
746*c217d954SCole Faust      * @param[in] out_quant_info     (Optional) Output quantization info
747*c217d954SCole Faust      */
748*c217d954SCole Faust     FullyConnectedLayer(unsigned int                  num_outputs,
749*c217d954SCole Faust                         SubStream                     sub_stream_weights,
750*c217d954SCole Faust                         SubStream                     sub_stream_bias,
751*c217d954SCole Faust                         const FullyConnectedLayerInfo fc_info            = FullyConnectedLayerInfo(),
752*c217d954SCole Faust                         const QuantizationInfo        weights_quant_info = QuantizationInfo(),
753*c217d954SCole Faust                         const QuantizationInfo        out_quant_info     = QuantizationInfo())
_num_outputs(num_outputs)754*c217d954SCole Faust         : _num_outputs(num_outputs),
755*c217d954SCole Faust           _weights(nullptr),
756*c217d954SCole Faust           _bias(nullptr),
757*c217d954SCole Faust           _weights_ss(std::make_unique<SubStream>(std::move(sub_stream_weights))),
758*c217d954SCole Faust           _bias_ss(std::make_unique<SubStream>(std::move(sub_stream_bias))),
759*c217d954SCole Faust           _fc_info(fc_info),
760*c217d954SCole Faust           _weights_quant_info(std::move(weights_quant_info)),
761*c217d954SCole Faust           _out_quant_info(std::move(out_quant_info))
762*c217d954SCole Faust     {
763*c217d954SCole Faust     }
764*c217d954SCole Faust 
765*c217d954SCole Faust     /** Create layer and add to the given stream.
766*c217d954SCole Faust      *
767*c217d954SCole Faust      * @param[in] s Stream to add layer to.
768*c217d954SCole Faust      *
769*c217d954SCole Faust      * @return ID of the created node.
770*c217d954SCole Faust      */
create_layer(IStream & s)771*c217d954SCole Faust     NodeID create_layer(IStream &s) override
772*c217d954SCole Faust     {
773*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
774*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
775*c217d954SCole Faust         if(_weights != nullptr)
776*c217d954SCole Faust         {
777*c217d954SCole Faust             return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
778*c217d954SCole Faust                                                            std::move(_weights), std::move(_bias), _fc_info,
779*c217d954SCole Faust                                                            std::move(_weights_quant_info), std::move(_out_quant_info), s.hints().fast_math_hint);
780*c217d954SCole Faust         }
781*c217d954SCole Faust         else
782*c217d954SCole Faust         {
783*c217d954SCole Faust             ARM_COMPUTE_ERROR_ON(_weights_ss == nullptr);
784*c217d954SCole Faust 
785*c217d954SCole Faust             NodeID bias_nid = (_bias_ss == nullptr) ? EmptyNodeID : _bias_ss->tail_node();
786*c217d954SCole Faust             return GraphBuilder::add_fully_connected_layer(s.graph(), common_params, input, _num_outputs,
787*c217d954SCole Faust                                                            _weights_ss->tail_node(), bias_nid, _fc_info,
788*c217d954SCole Faust                                                            std::move(_out_quant_info), s.hints().fast_math_hint);
789*c217d954SCole Faust         }
790*c217d954SCole Faust     }
791*c217d954SCole Faust 
792*c217d954SCole Faust private:
793*c217d954SCole Faust     unsigned int                  _num_outputs;
794*c217d954SCole Faust     ITensorAccessorUPtr           _weights;
795*c217d954SCole Faust     ITensorAccessorUPtr           _bias;
796*c217d954SCole Faust     std::unique_ptr<SubStream>    _weights_ss;
797*c217d954SCole Faust     std::unique_ptr<SubStream>    _bias_ss;
798*c217d954SCole Faust     const FullyConnectedLayerInfo _fc_info;
799*c217d954SCole Faust     const QuantizationInfo        _weights_quant_info;
800*c217d954SCole Faust     const QuantizationInfo        _out_quant_info;
801*c217d954SCole Faust };
802*c217d954SCole Faust 
803*c217d954SCole Faust /** Generate Proposals Layer */
804*c217d954SCole Faust class GenerateProposalsLayer final : public ILayer
805*c217d954SCole Faust {
806*c217d954SCole Faust public:
807*c217d954SCole Faust     /** Construct a generate proposals layer.
808*c217d954SCole Faust      *
809*c217d954SCole Faust      * @param[in] ss_scores  Graph sub-stream for the scores.
810*c217d954SCole Faust      * @param[in] ss_deltas  Graph sub-stream for the deltas.
811*c217d954SCole Faust      * @param[in] ss_anchors Graph sub-stream for the anchors.
812*c217d954SCole Faust      * @param[in] info       Generate Proposals operation information.
813*c217d954SCole Faust      */
GenerateProposalsLayer(SubStream && ss_scores,SubStream && ss_deltas,SubStream && ss_anchors,GenerateProposalsInfo info)814*c217d954SCole Faust     GenerateProposalsLayer(SubStream &&ss_scores, SubStream &&ss_deltas, SubStream &&ss_anchors, GenerateProposalsInfo info)
815*c217d954SCole Faust         : _ss_scores(std::move(ss_scores)), _ss_deltas(std::move(ss_deltas)), _ss_anchors(std::move(ss_anchors)), _info(info)
816*c217d954SCole Faust     {
817*c217d954SCole Faust     }
818*c217d954SCole Faust 
819*c217d954SCole Faust     /** Create layer and add to the given stream.
820*c217d954SCole Faust      *
821*c217d954SCole Faust      * @param[in] s Stream to add layer to.
822*c217d954SCole Faust      *
823*c217d954SCole Faust      * @return ID of the created node.
824*c217d954SCole Faust      */
create_layer(IStream & s)825*c217d954SCole Faust     NodeID create_layer(IStream &s) override
826*c217d954SCole Faust     {
827*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
828*c217d954SCole Faust         NodeIdxPair scores        = { _ss_scores.tail_node(), 0 };
829*c217d954SCole Faust         NodeIdxPair deltas        = { _ss_deltas.tail_node(), 0 };
830*c217d954SCole Faust         NodeIdxPair anchors       = { _ss_anchors.tail_node(), 0 };
831*c217d954SCole Faust         return GraphBuilder::add_generate_proposals_node(s.graph(), common_params, scores, deltas, anchors, _info);
832*c217d954SCole Faust     }
833*c217d954SCole Faust 
834*c217d954SCole Faust private:
835*c217d954SCole Faust     SubStream             _ss_scores;
836*c217d954SCole Faust     SubStream             _ss_deltas;
837*c217d954SCole Faust     SubStream             _ss_anchors;
838*c217d954SCole Faust     GenerateProposalsInfo _info;
839*c217d954SCole Faust };
840*c217d954SCole Faust 
841*c217d954SCole Faust /** L2 Normalize Layer */
842*c217d954SCole Faust class L2NormalizeLayer final : public ILayer
843*c217d954SCole Faust {
844*c217d954SCole Faust public:
845*c217d954SCole Faust     /** Construct a L2 Normalize layer.
846*c217d954SCole Faust      *
847*c217d954SCole Faust      * @param[in] axis    Axis to perform normalization on
848*c217d954SCole Faust      * @param[in] epsilon Lower bound value for the normalization
849*c217d954SCole Faust      */
L2NormalizeLayer(int axis,float epsilon)850*c217d954SCole Faust     L2NormalizeLayer(int axis, float epsilon)
851*c217d954SCole Faust         : _axis(axis), _epsilon(epsilon)
852*c217d954SCole Faust     {
853*c217d954SCole Faust     }
854*c217d954SCole Faust 
create_layer(IStream & s)855*c217d954SCole Faust     NodeID create_layer(IStream &s) override
856*c217d954SCole Faust     {
857*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
858*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
859*c217d954SCole Faust         return GraphBuilder::add_l2_normalize_node(s.graph(), common_params, input, _axis, _epsilon);
860*c217d954SCole Faust     }
861*c217d954SCole Faust 
862*c217d954SCole Faust private:
863*c217d954SCole Faust     int   _axis;
864*c217d954SCole Faust     float _epsilon;
865*c217d954SCole Faust };
866*c217d954SCole Faust 
867*c217d954SCole Faust /** Normalization Layer */
868*c217d954SCole Faust class NormalizationLayer final : public ILayer
869*c217d954SCole Faust {
870*c217d954SCole Faust public:
871*c217d954SCole Faust     /** Construct a normalization layer.
872*c217d954SCole Faust      *
873*c217d954SCole Faust      * @param[in] norm_info Normalization information.
874*c217d954SCole Faust      */
NormalizationLayer(NormalizationLayerInfo norm_info)875*c217d954SCole Faust     NormalizationLayer(NormalizationLayerInfo norm_info)
876*c217d954SCole Faust         : _norm_info(norm_info)
877*c217d954SCole Faust     {
878*c217d954SCole Faust     }
879*c217d954SCole Faust 
create_layer(IStream & s)880*c217d954SCole Faust     NodeID create_layer(IStream &s) override
881*c217d954SCole Faust     {
882*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
883*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
884*c217d954SCole Faust         return GraphBuilder::add_normalization_node(s.graph(), common_params, input, _norm_info);
885*c217d954SCole Faust     }
886*c217d954SCole Faust 
887*c217d954SCole Faust private:
888*c217d954SCole Faust     NormalizationLayerInfo _norm_info;
889*c217d954SCole Faust };
890*c217d954SCole Faust 
891*c217d954SCole Faust /** Normalize planar YUV Layer */
892*c217d954SCole Faust class NormalizePlanarYUVLayer final : public ILayer
893*c217d954SCole Faust {
894*c217d954SCole Faust public:
895*c217d954SCole Faust     /** Construct a normalize planar YUV layer.
896*c217d954SCole Faust      *
897*c217d954SCole Faust      * @param[in] mean Accessor to get mean tensor data from.
898*c217d954SCole Faust      * @param[in] std  Accessor to get std tensor data from.
899*c217d954SCole Faust      */
NormalizePlanarYUVLayer(ITensorAccessorUPtr mean,ITensorAccessorUPtr std)900*c217d954SCole Faust     NormalizePlanarYUVLayer(ITensorAccessorUPtr mean,
901*c217d954SCole Faust                             ITensorAccessorUPtr std)
902*c217d954SCole Faust         : _mean(std::move(mean)), _std(std::move(std))
903*c217d954SCole Faust     {
904*c217d954SCole Faust     }
905*c217d954SCole Faust 
create_layer(IStream & s)906*c217d954SCole Faust     NodeID create_layer(IStream &s) override
907*c217d954SCole Faust     {
908*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_mean == nullptr);
909*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON(_std == nullptr);
910*c217d954SCole Faust 
911*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
912*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
913*c217d954SCole Faust         return GraphBuilder::add_normalize_planar_yuv_node(s.graph(), common_params, input,
914*c217d954SCole Faust                                                            std::move(_mean), std::move(_std));
915*c217d954SCole Faust     }
916*c217d954SCole Faust 
917*c217d954SCole Faust private:
918*c217d954SCole Faust     ITensorAccessorUPtr _mean;
919*c217d954SCole Faust     ITensorAccessorUPtr _std;
920*c217d954SCole Faust };
921*c217d954SCole Faust 
922*c217d954SCole Faust /** Pad Layer */
923*c217d954SCole Faust class PadLayer final : public ILayer
924*c217d954SCole Faust {
925*c217d954SCole Faust public:
926*c217d954SCole Faust     /** Construct a pad layer.
927*c217d954SCole Faust      *
928*c217d954SCole Faust      * @param[in] padding   The padding for each spatial dimension of the input tensor. The pair padding[i]
929*c217d954SCole Faust      *                      specifies the front and the end padding in the i-th dimension.
930*c217d954SCole Faust      * @param[in] pad_value Padding value to use. Defaults to 0.
931*c217d954SCole Faust      */
932*c217d954SCole Faust     PadLayer(PaddingList padding, PixelValue pad_value = PixelValue())
_padding(padding)933*c217d954SCole Faust         : _padding(padding), _pad_value(pad_value)
934*c217d954SCole Faust     {
935*c217d954SCole Faust     }
936*c217d954SCole Faust 
create_layer(IStream & s)937*c217d954SCole Faust     NodeID create_layer(IStream &s) override
938*c217d954SCole Faust     {
939*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
940*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
941*c217d954SCole Faust         return GraphBuilder::add_pad_node(s.graph(), common_params, input, _padding, _pad_value);
942*c217d954SCole Faust     }
943*c217d954SCole Faust 
944*c217d954SCole Faust private:
945*c217d954SCole Faust     PaddingList _padding;
946*c217d954SCole Faust     PixelValue  _pad_value;
947*c217d954SCole Faust };
948*c217d954SCole Faust 
949*c217d954SCole Faust /** Permute Layer */
950*c217d954SCole Faust class PermuteLayer final : public ILayer
951*c217d954SCole Faust {
952*c217d954SCole Faust public:
953*c217d954SCole Faust     /** Construct a permute layer.
954*c217d954SCole Faust      *
955*c217d954SCole Faust      * @param[in] perm   Permutation vector.
956*c217d954SCole Faust      * @param[in] layout (Optional) Data layout to assign to permuted tensor.
957*c217d954SCole Faust      *                   If UNKNOWN then the input's layout will be used.
958*c217d954SCole Faust      */
959*c217d954SCole Faust     PermuteLayer(PermutationVector perm, DataLayout layout = DataLayout::UNKNOWN)
_perm(perm)960*c217d954SCole Faust         : _perm(perm), _layout(layout)
961*c217d954SCole Faust     {
962*c217d954SCole Faust     }
963*c217d954SCole Faust 
create_layer(IStream & s)964*c217d954SCole Faust     NodeID create_layer(IStream &s) override
965*c217d954SCole Faust     {
966*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
967*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
968*c217d954SCole Faust         return GraphBuilder::add_permute_node(s.graph(), common_params, input, _perm, _layout);
969*c217d954SCole Faust     }
970*c217d954SCole Faust 
971*c217d954SCole Faust private:
972*c217d954SCole Faust     PermutationVector _perm;
973*c217d954SCole Faust     DataLayout        _layout;
974*c217d954SCole Faust };
975*c217d954SCole Faust 
976*c217d954SCole Faust /** Pooling Layer */
977*c217d954SCole Faust class PoolingLayer final : public ILayer
978*c217d954SCole Faust {
979*c217d954SCole Faust public:
980*c217d954SCole Faust     /** Construct a pooling layer.
981*c217d954SCole Faust      *
982*c217d954SCole Faust      * @param[in] pool_info Pooling information.
983*c217d954SCole Faust      */
PoolingLayer(PoolingLayerInfo pool_info)984*c217d954SCole Faust     PoolingLayer(PoolingLayerInfo pool_info)
985*c217d954SCole Faust         : _pool_info(pool_info)
986*c217d954SCole Faust     {
987*c217d954SCole Faust     }
988*c217d954SCole Faust 
create_layer(IStream & s)989*c217d954SCole Faust     NodeID create_layer(IStream &s) override
990*c217d954SCole Faust     {
991*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
992*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
993*c217d954SCole Faust         return GraphBuilder::add_pooling_node(s.graph(), common_params, input, _pool_info);
994*c217d954SCole Faust     }
995*c217d954SCole Faust 
996*c217d954SCole Faust private:
997*c217d954SCole Faust     PoolingLayerInfo _pool_info;
998*c217d954SCole Faust };
999*c217d954SCole Faust 
1000*c217d954SCole Faust /** PRelu Layer */
1001*c217d954SCole Faust class PReluLayer final : public ILayer
1002*c217d954SCole Faust {
1003*c217d954SCole Faust public:
1004*c217d954SCole Faust     /** Construct an PRelu operation layer
1005*c217d954SCole Faust      *
1006*c217d954SCole Faust      * @param[in] sub_stream0 First graph sub-stream
1007*c217d954SCole Faust      * @param[in] sub_stream1 First graph sub-stream
1008*c217d954SCole Faust      */
PReluLayer(SubStream && sub_stream0,SubStream && sub_stream1)1009*c217d954SCole Faust     PReluLayer(SubStream &&sub_stream0, SubStream &&sub_stream1)
1010*c217d954SCole Faust         : _ss0(std::move(sub_stream0)), _ss1(std::move(sub_stream1))
1011*c217d954SCole Faust     {
1012*c217d954SCole Faust     }
1013*c217d954SCole Faust 
create_layer(IStream & s)1014*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1015*c217d954SCole Faust     {
1016*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1017*c217d954SCole Faust         NodeIdxPair input         = { _ss0.tail_node(), 0 };
1018*c217d954SCole Faust         NodeIdxPair alpha         = { _ss1.tail_node(), 0 };
1019*c217d954SCole Faust 
1020*c217d954SCole Faust         return GraphBuilder::add_prelu_node(s.graph(), common_params, input, alpha);
1021*c217d954SCole Faust     }
1022*c217d954SCole Faust 
1023*c217d954SCole Faust private:
1024*c217d954SCole Faust     SubStream _ss0;
1025*c217d954SCole Faust     SubStream _ss1;
1026*c217d954SCole Faust };
1027*c217d954SCole Faust 
1028*c217d954SCole Faust /** Print Layer */
1029*c217d954SCole Faust class PrintLayer final : public ILayer
1030*c217d954SCole Faust {
1031*c217d954SCole Faust public:
1032*c217d954SCole Faust     /** Construct a print layer.
1033*c217d954SCole Faust      *
1034*c217d954SCole Faust      * Example usage to locally dequantize and print a tensor:
1035*c217d954SCole Faust      *
1036*c217d954SCole Faust      * Tensor *output = new Tensor();
1037*c217d954SCole Faust      * const auto transform = [output](ITensor *input)
1038*c217d954SCole Faust      * {
1039*c217d954SCole Faust      *     output->allocator()->init(*input->info());
1040*c217d954SCole Faust      *     output->info()->set_data_type(DataType::F32);
1041*c217d954SCole Faust      *     output->allocator()->allocate();
1042*c217d954SCole Faust      *
1043*c217d954SCole Faust      *     Window win;
1044*c217d954SCole Faust      *     win.use_tensor_dimensions(input->info()->tensor_shape());
1045*c217d954SCole Faust      *     Iterator in(input, win);
1046*c217d954SCole Faust      *     Iterator out(output, win);
1047*c217d954SCole Faust      *     execute_window_loop(win, [&](const Coordinates &)
1048*c217d954SCole Faust      *     {
1049*c217d954SCole Faust      *         *(reinterpret_cast<float *>(out.ptr())) = dequantize_qasymm8(*in.ptr(), input->info()->quantization_info().uniform());
1050*c217d954SCole Faust      *     }, in, out);
1051*c217d954SCole Faust      *
1052*c217d954SCole Faust      *     return output;
1053*c217d954SCole Faust      * };
1054*c217d954SCole Faust      *
1055*c217d954SCole Faust      * graph << InputLayer(input_descriptor.set_quantization_info(in_quant_info), get_input_accessor(common_params, nullptr, false))
1056*c217d954SCole Faust      *       << ...
1057*c217d954SCole Faust      *       << \\ CNN Layers
1058*c217d954SCole Faust      *       << ...
1059*c217d954SCole Faust      *       << PrintLayer(std::cout, IOFormatInfo(), transform)
1060*c217d954SCole Faust      *       << ...
1061*c217d954SCole Faust      *       << OutputLayer(get_output_accessor(common_params, 5));
1062*c217d954SCole Faust      *
1063*c217d954SCole Faust      * @param[in] stream      Output stream.
1064*c217d954SCole Faust      * @param[in] format_info (Optional) Format info.
1065*c217d954SCole Faust      * @param[in] transform   (Optional) Input transform function.
1066*c217d954SCole Faust      */
1067*c217d954SCole Faust     PrintLayer(std::ostream &stream, const IOFormatInfo &format_info = IOFormatInfo(), const std::function<ITensor *(ITensor *)> transform = nullptr)
_stream(stream)1068*c217d954SCole Faust         : _stream(stream), _format_info(format_info), _transform(transform)
1069*c217d954SCole Faust     {
1070*c217d954SCole Faust     }
1071*c217d954SCole Faust 
create_layer(IStream & s)1072*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1073*c217d954SCole Faust     {
1074*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1075*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1076*c217d954SCole Faust         return GraphBuilder::add_print_node(s.graph(), common_params, input, _stream, _format_info, _transform);
1077*c217d954SCole Faust     }
1078*c217d954SCole Faust 
1079*c217d954SCole Faust private:
1080*c217d954SCole Faust     std::ostream                             &_stream;
1081*c217d954SCole Faust     const IOFormatInfo                       &_format_info;
1082*c217d954SCole Faust     const std::function<ITensor *(ITensor *)> _transform;
1083*c217d954SCole Faust };
1084*c217d954SCole Faust 
1085*c217d954SCole Faust /** PriorBox Layer */
1086*c217d954SCole Faust class PriorBoxLayer final : public ILayer
1087*c217d954SCole Faust {
1088*c217d954SCole Faust public:
1089*c217d954SCole Faust     /** Construct a priorbox layer.
1090*c217d954SCole Faust      *
1091*c217d954SCole Faust      * @param[in] sub_stream First graph sub-stream
1092*c217d954SCole Faust      * @param[in] prior_info PriorBox parameters.
1093*c217d954SCole Faust      */
PriorBoxLayer(SubStream && sub_stream,const PriorBoxLayerInfo & prior_info)1094*c217d954SCole Faust     PriorBoxLayer(SubStream &&sub_stream, const PriorBoxLayerInfo &prior_info)
1095*c217d954SCole Faust         : _ss(std::move(sub_stream)), _prior_info(prior_info)
1096*c217d954SCole Faust     {
1097*c217d954SCole Faust     }
1098*c217d954SCole Faust 
create_layer(IStream & s)1099*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1100*c217d954SCole Faust     {
1101*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1102*c217d954SCole Faust         NodeIdxPair input0        = { s.tail_node(), 0 };
1103*c217d954SCole Faust         NodeIdxPair input1        = { _ss.tail_node(), 0 };
1104*c217d954SCole Faust         return GraphBuilder::add_priorbox_node(s.graph(), common_params, input0, input1, _prior_info);
1105*c217d954SCole Faust     }
1106*c217d954SCole Faust 
1107*c217d954SCole Faust private:
1108*c217d954SCole Faust     SubStream         _ss;
1109*c217d954SCole Faust     PriorBoxLayerInfo _prior_info;
1110*c217d954SCole Faust };
1111*c217d954SCole Faust 
1112*c217d954SCole Faust /** Quantization Layer */
1113*c217d954SCole Faust class QuantizationLayer final : public ILayer
1114*c217d954SCole Faust {
1115*c217d954SCole Faust public:
1116*c217d954SCole Faust     /** Construct a quantization layer.
1117*c217d954SCole Faust      *
1118*c217d954SCole Faust      * @param[in] out_quant_info Output tensor quantization info
1119*c217d954SCole Faust      */
QuantizationLayer(QuantizationInfo out_quant_info)1120*c217d954SCole Faust     QuantizationLayer(QuantizationInfo out_quant_info)
1121*c217d954SCole Faust         : _out_quant_info(out_quant_info)
1122*c217d954SCole Faust     {
1123*c217d954SCole Faust     }
1124*c217d954SCole Faust 
create_layer(IStream & s)1125*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1126*c217d954SCole Faust     {
1127*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1128*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1129*c217d954SCole Faust         return GraphBuilder::add_quantization_node(s.graph(), common_params, input, _out_quant_info);
1130*c217d954SCole Faust     }
1131*c217d954SCole Faust 
1132*c217d954SCole Faust private:
1133*c217d954SCole Faust     QuantizationInfo _out_quant_info;
1134*c217d954SCole Faust };
1135*c217d954SCole Faust 
1136*c217d954SCole Faust /** Reduction Layer */
1137*c217d954SCole Faust class ReductionLayer final : public ILayer
1138*c217d954SCole Faust {
1139*c217d954SCole Faust public:
1140*c217d954SCole Faust     /** Construct a reduction layer.
1141*c217d954SCole Faust      *
1142*c217d954SCole Faust      * @param[in] op        Reduction operation
1143*c217d954SCole Faust      * @param[in] axis      Reduction axis
1144*c217d954SCole Faust      * @param[in] keep_dims (Optional) Whether to keep the reduced dimension after the operation. Defaults to true.
1145*c217d954SCole Faust      */
ReductionLayer(ReductionOperation op,unsigned int axis,bool keep_dims)1146*c217d954SCole Faust     ReductionLayer(ReductionOperation op, unsigned int axis, bool keep_dims)
1147*c217d954SCole Faust         : _op(op), _axis(axis), _keep_dims(keep_dims)
1148*c217d954SCole Faust     {
1149*c217d954SCole Faust     }
1150*c217d954SCole Faust 
create_layer(IStream & s)1151*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1152*c217d954SCole Faust     {
1153*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1154*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1155*c217d954SCole Faust         return GraphBuilder::add_reduction_operation_node(s.graph(), common_params, input, _op, _axis, _keep_dims);
1156*c217d954SCole Faust     }
1157*c217d954SCole Faust 
1158*c217d954SCole Faust private:
1159*c217d954SCole Faust     ReductionOperation _op;
1160*c217d954SCole Faust     unsigned int       _axis;
1161*c217d954SCole Faust     bool               _keep_dims;
1162*c217d954SCole Faust };
1163*c217d954SCole Faust 
1164*c217d954SCole Faust /** Reorg Layer */
1165*c217d954SCole Faust class ReorgLayer final : public ILayer
1166*c217d954SCole Faust {
1167*c217d954SCole Faust public:
1168*c217d954SCole Faust     /** Construct a reorg layer.
1169*c217d954SCole Faust      *
1170*c217d954SCole Faust      * @param[in] stride Stride value to use for reorganizing the values in the output tensor.
1171*c217d954SCole Faust      *                   It defines the spatial distance between 2 consecutive pixels in the x and y direction
1172*c217d954SCole Faust      */
ReorgLayer(int stride)1173*c217d954SCole Faust     ReorgLayer(int stride)
1174*c217d954SCole Faust         : _stride(stride)
1175*c217d954SCole Faust     {
1176*c217d954SCole Faust     }
1177*c217d954SCole Faust 
create_layer(IStream & s)1178*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1179*c217d954SCole Faust     {
1180*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1181*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1182*c217d954SCole Faust         return GraphBuilder::add_reorg_node(s.graph(), common_params, input, _stride);
1183*c217d954SCole Faust     }
1184*c217d954SCole Faust 
1185*c217d954SCole Faust private:
1186*c217d954SCole Faust     int _stride;
1187*c217d954SCole Faust };
1188*c217d954SCole Faust 
1189*c217d954SCole Faust /** Reshape Layer */
1190*c217d954SCole Faust class ReshapeLayer final : public ILayer
1191*c217d954SCole Faust {
1192*c217d954SCole Faust public:
1193*c217d954SCole Faust     /** Construct a reshape layer.
1194*c217d954SCole Faust      *
1195*c217d954SCole Faust      * @param[in] shape Target shape.
1196*c217d954SCole Faust      */
ReshapeLayer(TensorShape shape)1197*c217d954SCole Faust     ReshapeLayer(TensorShape shape)
1198*c217d954SCole Faust         : _shape(shape)
1199*c217d954SCole Faust     {
1200*c217d954SCole Faust     }
1201*c217d954SCole Faust 
create_layer(IStream & s)1202*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1203*c217d954SCole Faust     {
1204*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1205*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1206*c217d954SCole Faust         return GraphBuilder::add_reshape_node(s.graph(), common_params, input, _shape);
1207*c217d954SCole Faust     }
1208*c217d954SCole Faust 
1209*c217d954SCole Faust private:
1210*c217d954SCole Faust     TensorShape _shape;
1211*c217d954SCole Faust };
1212*c217d954SCole Faust 
1213*c217d954SCole Faust /** Resize Layer */
1214*c217d954SCole Faust class ResizeLayer final : public ILayer
1215*c217d954SCole Faust {
1216*c217d954SCole Faust public:
ResizeLayer(InterpolationPolicy policy,float width_scale,float height_scale)1217*c217d954SCole Faust     ResizeLayer(InterpolationPolicy policy, float width_scale, float height_scale)
1218*c217d954SCole Faust         : _policy(policy), _width_scale(width_scale), _height_scale(height_scale)
1219*c217d954SCole Faust     {
1220*c217d954SCole Faust     }
1221*c217d954SCole Faust 
create_layer(IStream & s)1222*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1223*c217d954SCole Faust     {
1224*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1225*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1226*c217d954SCole Faust         return GraphBuilder::add_resize_node(s.graph(), common_params, input, _policy, _width_scale, _height_scale);
1227*c217d954SCole Faust     }
1228*c217d954SCole Faust 
1229*c217d954SCole Faust private:
1230*c217d954SCole Faust     InterpolationPolicy _policy;
1231*c217d954SCole Faust     float               _width_scale;
1232*c217d954SCole Faust     float               _height_scale;
1233*c217d954SCole Faust };
1234*c217d954SCole Faust 
1235*c217d954SCole Faust /** ROIAlign Layer */
1236*c217d954SCole Faust class ROIAlignLayer final : public ILayer
1237*c217d954SCole Faust {
1238*c217d954SCole Faust public:
1239*c217d954SCole Faust     /** Construct a RoiAlign layer.
1240*c217d954SCole Faust      *
1241*c217d954SCole Faust      * @param[in] sub_stream_input Graph sub-stream for the input
1242*c217d954SCole Faust      * @param[in] sub_stream_rois  Graph sub-stream for the rois
1243*c217d954SCole Faust      * @param[in] pool_info        Pooling information.
1244*c217d954SCole Faust      */
ROIAlignLayer(SubStream && sub_stream_input,SubStream && sub_stream_rois,ROIPoolingLayerInfo pool_info)1245*c217d954SCole Faust     ROIAlignLayer(SubStream &&sub_stream_input, SubStream &&sub_stream_rois, ROIPoolingLayerInfo pool_info)
1246*c217d954SCole Faust         : _ss_input(sub_stream_input), _ss_rois(sub_stream_rois), _pool_info(pool_info)
1247*c217d954SCole Faust     {
1248*c217d954SCole Faust     }
1249*c217d954SCole Faust 
1250*c217d954SCole Faust     /** Prevent instances of this class from being copy constructed */
1251*c217d954SCole Faust     ROIAlignLayer(const ROIAlignLayer &) = delete;
1252*c217d954SCole Faust     /** Prevent instances of this class from being copied */
1253*c217d954SCole Faust     ROIAlignLayer &operator=(const ROIAlignLayer &) = delete;
1254*c217d954SCole Faust 
create_layer(IStream & s)1255*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1256*c217d954SCole Faust     {
1257*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1258*c217d954SCole Faust         NodeIdxPair input         = { _ss_input.tail_node(), 0 };
1259*c217d954SCole Faust         NodeIdxPair rois          = { _ss_rois.tail_node(), 0 };
1260*c217d954SCole Faust         return GraphBuilder::add_roi_align_node(s.graph(), common_params, input, rois, _pool_info);
1261*c217d954SCole Faust     }
1262*c217d954SCole Faust 
1263*c217d954SCole Faust private:
1264*c217d954SCole Faust     SubStream           _ss_input;
1265*c217d954SCole Faust     SubStream           _ss_rois;
1266*c217d954SCole Faust     ROIPoolingLayerInfo _pool_info;
1267*c217d954SCole Faust };
1268*c217d954SCole Faust 
1269*c217d954SCole Faust /** Scale Layer */
1270*c217d954SCole Faust class ScaleLayer final : public ILayer
1271*c217d954SCole Faust {
1272*c217d954SCole Faust public:
1273*c217d954SCole Faust     /** Construct a scale layer.
1274*c217d954SCole Faust      *
1275*c217d954SCole Faust      * @param[in] mul_w Accessor to get mul weight from.
1276*c217d954SCole Faust      * @param[in] add_w Accessor to get add weight from.
1277*c217d954SCole Faust      */
ScaleLayer(ITensorAccessorUPtr mul_w,ITensorAccessorUPtr add_w)1278*c217d954SCole Faust     ScaleLayer(ITensorAccessorUPtr mul_w,
1279*c217d954SCole Faust                ITensorAccessorUPtr add_w)
1280*c217d954SCole Faust         : _mul_w(std::move(mul_w)), _add_w(std::move(add_w))
1281*c217d954SCole Faust     {
1282*c217d954SCole Faust     }
1283*c217d954SCole Faust 
create_layer(IStream & s)1284*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1285*c217d954SCole Faust     {
1286*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1287*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1288*c217d954SCole Faust         return GraphBuilder::add_scale_layer(s.graph(), common_params, input, std::move(_mul_w), std::move(_add_w));
1289*c217d954SCole Faust     }
1290*c217d954SCole Faust 
1291*c217d954SCole Faust private:
1292*c217d954SCole Faust     ITensorAccessorUPtr _mul_w;
1293*c217d954SCole Faust     ITensorAccessorUPtr _add_w;
1294*c217d954SCole Faust };
1295*c217d954SCole Faust 
1296*c217d954SCole Faust /** Slice Layer */
1297*c217d954SCole Faust class SliceLayer final : public ILayer
1298*c217d954SCole Faust {
1299*c217d954SCole Faust public:
1300*c217d954SCole Faust     /** Construct a slice layer.
1301*c217d954SCole Faust      *
1302*c217d954SCole Faust      * @param[in] starts The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input).
1303*c217d954SCole Faust      * @param[in] ends   The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input).
1304*c217d954SCole Faust      */
SliceLayer(Coordinates & starts,Coordinates & ends)1305*c217d954SCole Faust     SliceLayer(Coordinates &starts, Coordinates &ends)
1306*c217d954SCole Faust         : _starts(starts), _ends(ends)
1307*c217d954SCole Faust     {
1308*c217d954SCole Faust     }
1309*c217d954SCole Faust 
create_layer(IStream & s)1310*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1311*c217d954SCole Faust     {
1312*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1313*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1314*c217d954SCole Faust         return GraphBuilder::add_slice_node(s.graph(), common_params, input, _starts, _ends);
1315*c217d954SCole Faust     }
1316*c217d954SCole Faust 
1317*c217d954SCole Faust private:
1318*c217d954SCole Faust     Coordinates _starts;
1319*c217d954SCole Faust     Coordinates _ends;
1320*c217d954SCole Faust };
1321*c217d954SCole Faust 
1322*c217d954SCole Faust /** Softmax Layer */
1323*c217d954SCole Faust class SoftmaxLayer final : public ILayer
1324*c217d954SCole Faust {
1325*c217d954SCole Faust public:
1326*c217d954SCole Faust     /** Construct a softmax layer.
1327*c217d954SCole Faust      *
1328*c217d954SCole Faust      * @param[in] beta (Optional) Beta value. Default 1.0.
1329*c217d954SCole Faust      */
1330*c217d954SCole Faust     SoftmaxLayer(float beta = 1.0f)
_beta(beta)1331*c217d954SCole Faust         : _beta(beta)
1332*c217d954SCole Faust     {
1333*c217d954SCole Faust     }
1334*c217d954SCole Faust 
create_layer(IStream & s)1335*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1336*c217d954SCole Faust     {
1337*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1338*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1339*c217d954SCole Faust         return GraphBuilder::add_softmax_node(s.graph(), common_params, input, _beta);
1340*c217d954SCole Faust     }
1341*c217d954SCole Faust 
1342*c217d954SCole Faust private:
1343*c217d954SCole Faust     float _beta;
1344*c217d954SCole Faust };
1345*c217d954SCole Faust 
1346*c217d954SCole Faust /** Stack Layer */
1347*c217d954SCole Faust class StackLayer final : public ILayer
1348*c217d954SCole Faust {
1349*c217d954SCole Faust public:
1350*c217d954SCole Faust     /** Construct a concatenation layer
1351*c217d954SCole Faust      *
1352*c217d954SCole Faust      * @param[in] sub_stream1      First graph branch
1353*c217d954SCole Faust      * @param[in] sub_stream2      Second graph branch
1354*c217d954SCole Faust      * @param[in] rest_sub_streams Rest sub-graph branches
1355*c217d954SCole Faust      */
1356*c217d954SCole Faust     template <typename... Ts>
StackLayer(SubStream && sub_stream1,SubStream && sub_stream2,Ts &&...rest_sub_streams)1357*c217d954SCole Faust     StackLayer(SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
1358*c217d954SCole Faust         : _sub_streams(), _axis(0)
1359*c217d954SCole Faust     {
1360*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream1)));
1361*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream2)));
1362*c217d954SCole Faust 
1363*c217d954SCole Faust         utility::for_each([&](SubStream && sub_stream)
1364*c217d954SCole Faust         {
1365*c217d954SCole Faust             _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
1366*c217d954SCole Faust         },
1367*c217d954SCole Faust         std::move(rest_sub_streams)...);
1368*c217d954SCole Faust     }
1369*c217d954SCole Faust     /** Construct a concatenation layer
1370*c217d954SCole Faust      *
1371*c217d954SCole Faust      * @param[in] axis             Stack layer axis along which to stack the inputs
1372*c217d954SCole Faust      * @param[in] sub_stream1      First graph branch
1373*c217d954SCole Faust      * @param[in] sub_stream2      Second graph branch
1374*c217d954SCole Faust      * @param[in] rest_sub_streams Rest sub-graph branches
1375*c217d954SCole Faust      */
1376*c217d954SCole Faust     template <typename... Ts>
StackLayer(int axis,SubStream && sub_stream1,SubStream && sub_stream2,Ts &&...rest_sub_streams)1377*c217d954SCole Faust     StackLayer(int axis, SubStream &&sub_stream1, SubStream &&sub_stream2, Ts &&... rest_sub_streams)
1378*c217d954SCole Faust         : _sub_streams(), _axis(axis)
1379*c217d954SCole Faust     {
1380*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream1)));
1381*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream2)));
1382*c217d954SCole Faust 
1383*c217d954SCole Faust         utility::for_each([&](SubStream && sub_stream)
1384*c217d954SCole Faust         {
1385*c217d954SCole Faust             _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
1386*c217d954SCole Faust         },
1387*c217d954SCole Faust         std::move(rest_sub_streams)...);
1388*c217d954SCole Faust     }
1389*c217d954SCole Faust     /** Construct a concat layer
1390*c217d954SCole Faust      *
1391*c217d954SCole Faust      * @param[in] sub_stream Sub-stream
1392*c217d954SCole Faust      */
1393*c217d954SCole Faust     template <typename... Ts>
StackLayer(SubStream && sub_stream)1394*c217d954SCole Faust     StackLayer(SubStream &&sub_stream)
1395*c217d954SCole Faust         : _sub_streams(), _axis(0)
1396*c217d954SCole Faust     {
1397*c217d954SCole Faust         _sub_streams.push_back(std::make_unique<SubStream>(std::move(sub_stream)));
1398*c217d954SCole Faust     }
create_layer(IStream & s)1399*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1400*c217d954SCole Faust     {
1401*c217d954SCole Faust         NodeID     nid           = EmptyNodeID;
1402*c217d954SCole Faust         NodeParams common_params = { name(), s.hints().target_hint };
1403*c217d954SCole Faust         if(_sub_streams.size() == 1 && _sub_streams.at(0) != nullptr)
1404*c217d954SCole Faust         {
1405*c217d954SCole Faust             nid = _sub_streams[0]->tail_node();
1406*c217d954SCole Faust         }
1407*c217d954SCole Faust         else
1408*c217d954SCole Faust         {
1409*c217d954SCole Faust             // Collect tail nodes and stack
1410*c217d954SCole Faust             std::vector<NodeIdxPair> nodes;
1411*c217d954SCole Faust             for(auto &ss : _sub_streams)
1412*c217d954SCole Faust             {
1413*c217d954SCole Faust                 if(ss && (ss->tail_node() != EmptyNodeID))
1414*c217d954SCole Faust                 {
1415*c217d954SCole Faust                     const auto tail_node = s.graph().node(ss->tail_node());
1416*c217d954SCole Faust                     if(tail_node != nullptr && tail_node->type() != NodeType::Output)
1417*c217d954SCole Faust                     {
1418*c217d954SCole Faust                         nodes.push_back({ ss->tail_node(), 0 });
1419*c217d954SCole Faust                     }
1420*c217d954SCole Faust                 }
1421*c217d954SCole Faust             }
1422*c217d954SCole Faust             nid = GraphBuilder::add_stack_node(s.graph(), common_params, nodes, _axis);
1423*c217d954SCole Faust         }
1424*c217d954SCole Faust         return nid;
1425*c217d954SCole Faust     }
1426*c217d954SCole Faust 
1427*c217d954SCole Faust private:
1428*c217d954SCole Faust     std::vector<std::unique_ptr<SubStream>> _sub_streams;
1429*c217d954SCole Faust     int                                     _axis;
1430*c217d954SCole Faust };
1431*c217d954SCole Faust 
1432*c217d954SCole Faust /** StridedSlice Layer */
1433*c217d954SCole Faust class StridedSliceLayer final : public ILayer
1434*c217d954SCole Faust {
1435*c217d954SCole Faust public:
1436*c217d954SCole Faust     /** Construct a strided slice layer.
1437*c217d954SCole Faust      *
1438*c217d954SCole Faust      * @param[in] starts             The starts of the dimensions of the input tensor to be sliced. The length must be of rank(input).
1439*c217d954SCole Faust      * @param[in] ends               The ends of the dimensions of the input tensor to be sliced. The length must be of rank(input).
1440*c217d954SCole Faust      * @param[in] strides            The strides of the dimensions of the input tensor to be sliced. The length must be of rank(input).
1441*c217d954SCole Faust      * @param[in] strided_slice_info Contains masks for the starts, ends and strides
1442*c217d954SCole Faust      */
StridedSliceLayer(Coordinates & starts,Coordinates & ends,BiStrides & strides,StridedSliceLayerInfo strided_slice_info)1443*c217d954SCole Faust     StridedSliceLayer(Coordinates &starts, Coordinates &ends, BiStrides &strides, StridedSliceLayerInfo strided_slice_info)
1444*c217d954SCole Faust         : _starts(starts), _ends(ends), _strides(strides), _info(strided_slice_info)
1445*c217d954SCole Faust     {
1446*c217d954SCole Faust     }
1447*c217d954SCole Faust 
create_layer(IStream & s)1448*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1449*c217d954SCole Faust     {
1450*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1451*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1452*c217d954SCole Faust         return GraphBuilder::add_strided_slice_node(s.graph(), common_params, input, _starts, _ends, _strides, _info);
1453*c217d954SCole Faust     }
1454*c217d954SCole Faust 
1455*c217d954SCole Faust private:
1456*c217d954SCole Faust     Coordinates           _starts;
1457*c217d954SCole Faust     Coordinates           _ends;
1458*c217d954SCole Faust     BiStrides             _strides;
1459*c217d954SCole Faust     StridedSliceLayerInfo _info;
1460*c217d954SCole Faust };
1461*c217d954SCole Faust 
1462*c217d954SCole Faust /** YOLO Layer */
1463*c217d954SCole Faust class YOLOLayer final : public ILayer
1464*c217d954SCole Faust {
1465*c217d954SCole Faust public:
1466*c217d954SCole Faust     /** Construct a YOLO layer.
1467*c217d954SCole Faust      *
1468*c217d954SCole Faust      * @param[in] act_info Activation info
1469*c217d954SCole Faust      */
YOLOLayer(ActivationLayerInfo act_info)1470*c217d954SCole Faust     YOLOLayer(ActivationLayerInfo act_info)
1471*c217d954SCole Faust         : _act_info(act_info)
1472*c217d954SCole Faust     {
1473*c217d954SCole Faust     }
1474*c217d954SCole Faust 
create_layer(IStream & s)1475*c217d954SCole Faust     NodeID create_layer(IStream &s) override
1476*c217d954SCole Faust     {
1477*c217d954SCole Faust         NodeParams  common_params = { name(), s.hints().target_hint };
1478*c217d954SCole Faust         NodeIdxPair input         = { s.tail_node(), 0 };
1479*c217d954SCole Faust         return GraphBuilder::add_yolo_node(s.graph(), common_params, input, _act_info);
1480*c217d954SCole Faust     }
1481*c217d954SCole Faust 
1482*c217d954SCole Faust private:
1483*c217d954SCole Faust     ActivationLayerInfo _act_info;
1484*c217d954SCole Faust };
1485*c217d954SCole Faust } // namespace frontend
1486*c217d954SCole Faust } // namespace graph
1487*c217d954SCole Faust } // namespace arm_compute
1488*c217d954SCole Faust #endif /* ARM_COMPUTE_GRAPH_LAYERS_H */
1489