xref: /aosp_15_r20/external/armnn/include/armnn/INetwork.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendOptions.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Deprecated.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/DescriptorsFwd.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IStrategy.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/NetworkFwd.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Optional.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TensorFwd.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <memory>
18*89c4ff92SAndroid Build Coastguard Worker #include <vector>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker namespace armnn
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker /// @brief An input connection slot for a layer.
23*89c4ff92SAndroid Build Coastguard Worker /// The input slot can be connected to an output slot of the preceding layer in the graph.
24*89c4ff92SAndroid Build Coastguard Worker /// Only one connection to the input slot is allowed.
25*89c4ff92SAndroid Build Coastguard Worker class IInputSlot
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker public:
28*89c4ff92SAndroid Build Coastguard Worker     virtual const IOutputSlot* GetConnection() const = 0;
29*89c4ff92SAndroid Build Coastguard Worker     virtual IOutputSlot* GetConnection() = 0;
30*89c4ff92SAndroid Build Coastguard Worker     virtual const IConnectableLayer& GetOwningIConnectableLayer() const = 0;
31*89c4ff92SAndroid Build Coastguard Worker     virtual IConnectableLayer& GetOwningIConnectableLayer() = 0;
32*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetSlotIndex() const = 0;
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker protected:
35*89c4ff92SAndroid Build Coastguard Worker    /// Not user deletable.
~IInputSlot()36*89c4ff92SAndroid Build Coastguard Worker     ~IInputSlot() {}
37*89c4ff92SAndroid Build Coastguard Worker };
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker /// @brief An output connection slot for a layer.
40*89c4ff92SAndroid Build Coastguard Worker /// The output slot may be connected to 1 or more input slots of subsequent layers in the graph.
41*89c4ff92SAndroid Build Coastguard Worker class IOutputSlot
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker public:
44*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetNumConnections() const = 0;
45*89c4ff92SAndroid Build Coastguard Worker     virtual const IInputSlot* GetConnection(unsigned int index) const = 0;
46*89c4ff92SAndroid Build Coastguard Worker     virtual IInputSlot* GetConnection(unsigned int outputindex) = 0;
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     virtual void SetTensorInfo(const TensorInfo& tensorInfo) = 0;
49*89c4ff92SAndroid Build Coastguard Worker     virtual const TensorInfo& GetTensorInfo() const = 0;
50*89c4ff92SAndroid Build Coastguard Worker     virtual bool IsTensorInfoSet() const = 0;
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     virtual int Connect(IInputSlot& destination) = 0;
53*89c4ff92SAndroid Build Coastguard Worker     virtual void Disconnect(IInputSlot& slot) = 0;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int CalculateIndexOnOwner() const = 0;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     virtual LayerGuid GetOwningLayerGuid() const = 0;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     virtual const IConnectableLayer& GetOwningIConnectableLayer() const = 0;
60*89c4ff92SAndroid Build Coastguard Worker     virtual IConnectableLayer& GetOwningIConnectableLayer() = 0;
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker protected:
63*89c4ff92SAndroid Build Coastguard Worker     /// Not user deletable.
~IOutputSlot()64*89c4ff92SAndroid Build Coastguard Worker     ~IOutputSlot() {}
65*89c4ff92SAndroid Build Coastguard Worker };
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker /// @brief Interface for a layer that is connectable to other layers via InputSlots and OutputSlots.
68*89c4ff92SAndroid Build Coastguard Worker class IConnectableLayer
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker public:
71*89c4ff92SAndroid Build Coastguard Worker     /// Returns the name of the layer
72*89c4ff92SAndroid Build Coastguard Worker     virtual const char* GetName() const = 0;
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     /// Returns the number of connectable input slots
75*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetNumInputSlots() const = 0;
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     /// Returns the number of connectable output slots
78*89c4ff92SAndroid Build Coastguard Worker     virtual unsigned int GetNumOutputSlots() const = 0;
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     /// Get a const input slot handle by slot index
81*89c4ff92SAndroid Build Coastguard Worker     virtual const IInputSlot& GetInputSlot(unsigned int index) const = 0;
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     /// Get the input slot handle by slot index
84*89c4ff92SAndroid Build Coastguard Worker     virtual IInputSlot& GetInputSlot(unsigned int index) = 0;
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     /// Get the const output slot handle by slot index
87*89c4ff92SAndroid Build Coastguard Worker     virtual const IOutputSlot& GetOutputSlot(unsigned int index) const = 0;
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     /// Get the output slot handle by slot index
90*89c4ff92SAndroid Build Coastguard Worker     virtual IOutputSlot& GetOutputSlot(unsigned int index) = 0;
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker     /// Infer the shape of the output(s) based on the provided input shape(s)
93*89c4ff92SAndroid Build Coastguard Worker     virtual std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const = 0;
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     /// Returns the unique id of the layer
96*89c4ff92SAndroid Build Coastguard Worker     virtual LayerGuid GetGuid() const = 0;
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     /// Apply a visitor to this layer
99*89c4ff92SAndroid Build Coastguard Worker     virtual void ExecuteStrategy(IStrategy& strategy) const = 0;
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker     /// Provide a hint for the optimizer as to which backend to prefer for this layer.
102*89c4ff92SAndroid Build Coastguard Worker     /// By providing a BackendSelectionHint there is no guarantee the input backend supports that layer.
103*89c4ff92SAndroid Build Coastguard Worker     /// If IsLayerSupported() returns false with the backend hint, we default to calling IsLayerSupported()
104*89c4ff92SAndroid Build Coastguard Worker     /// on the BackendPreferences vector. Use SetBackendId() if we can guarantee a backend supports that
105*89c4ff92SAndroid Build Coastguard Worker     /// layer (IsLayerSupported returns true for a specific backend).
106*89c4ff92SAndroid Build Coastguard Worker     virtual void BackendSelectionHint(Optional<BackendId> backend) = 0;
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker     /// Returns the armnn::LayerType of this layer
109*89c4ff92SAndroid Build Coastguard Worker     virtual LayerType GetType() const = 0;
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     /// If the layer has a descriptor return it.
112*89c4ff92SAndroid Build Coastguard Worker     /// The base descriptor can then be cast to the correct descriptor class.
113*89c4ff92SAndroid Build Coastguard Worker     /// If the layer has no associated descriptor a struct of type NullDescriptor will be returned.
114*89c4ff92SAndroid Build Coastguard Worker     /// Note: NullDescriptors can be detected because they return true when
115*89c4ff92SAndroid Build Coastguard Worker     /// the BaseDescriptor IsNull function is invoked.
116*89c4ff92SAndroid Build Coastguard Worker     virtual const BaseDescriptor& GetParameters() const = 0;
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     /// Set the backend of the IConnectableLayer.
119*89c4ff92SAndroid Build Coastguard Worker     /// By using SetBackendId() we guarantee that the input backend supports that
120*89c4ff92SAndroid Build Coastguard Worker     /// layer (IsLayerSupported returns true for a specific backend). If there is
121*89c4ff92SAndroid Build Coastguard Worker     /// no guarantee the input backend supports that layer use BackendSelectionHint().
122*89c4ff92SAndroid Build Coastguard Worker     virtual void SetBackendId(const BackendId& id) = 0;
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     using ConstantTensors = std::vector<std::reference_wrapper<std::shared_ptr<ConstTensorHandle>>>;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     // Returns ConstantTensors of this Layer if it has any, otherwise returns empty vector.
127*89c4ff92SAndroid Build Coastguard Worker     virtual ConstantTensors GetConstantTensorsByRef() = 0;
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker     using ImmutableConstantTensors = std::vector<std::reference_wrapper<const std::shared_ptr<ConstTensorHandle>>>;
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker     // Returns ConstantTensors of this Layer if it has any, otherwise returns empty vector.
132*89c4ff92SAndroid Build Coastguard Worker     virtual ImmutableConstantTensors GetConstantTensorsByRef() const = 0;
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker protected:
135*89c4ff92SAndroid Build Coastguard Worker       /// Objects are not deletable via the handle
~IConnectableLayer()136*89c4ff92SAndroid Build Coastguard Worker     ~IConnectableLayer() {}
137*89c4ff92SAndroid Build Coastguard Worker };
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker struct OptimizerOptions
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02")
OptimizerOptionsarmnn::OptimizerOptions142*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptions()
143*89c4ff92SAndroid Build Coastguard Worker             : m_ReduceFp32ToFp16(false)
144*89c4ff92SAndroid Build Coastguard Worker             , m_Debug(false)
145*89c4ff92SAndroid Build Coastguard Worker             , m_DebugToFile(false)
146*89c4ff92SAndroid Build Coastguard Worker             , m_ReduceFp32ToBf16(false)
147*89c4ff92SAndroid Build Coastguard Worker             , m_shapeInferenceMethod(armnn::ShapeInferenceMethod::ValidateOnly)
148*89c4ff92SAndroid Build Coastguard Worker             , m_ImportEnabled(false)
149*89c4ff92SAndroid Build Coastguard Worker             , m_ModelOptions()
150*89c4ff92SAndroid Build Coastguard Worker             , m_ProfilingEnabled(false)
151*89c4ff92SAndroid Build Coastguard Worker             , m_ExportEnabled(false)
152*89c4ff92SAndroid Build Coastguard Worker             , m_AllowExpandedDims(false)
153*89c4ff92SAndroid Build Coastguard Worker     {}
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02")
OptimizerOptionsarmnn::OptimizerOptions156*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled,
157*89c4ff92SAndroid Build Coastguard Worker                      ModelOptions modelOptions = {}, bool exportEnabled = false, bool debugToFile = false)
158*89c4ff92SAndroid Build Coastguard Worker             : m_ReduceFp32ToFp16(reduceFp32ToFp16)
159*89c4ff92SAndroid Build Coastguard Worker             , m_Debug(debug)
160*89c4ff92SAndroid Build Coastguard Worker             , m_DebugToFile(debugToFile)
161*89c4ff92SAndroid Build Coastguard Worker             , m_ReduceFp32ToBf16(reduceFp32ToBf16)
162*89c4ff92SAndroid Build Coastguard Worker             , m_shapeInferenceMethod(armnn::ShapeInferenceMethod::ValidateOnly)
163*89c4ff92SAndroid Build Coastguard Worker             , m_ImportEnabled(importEnabled)
164*89c4ff92SAndroid Build Coastguard Worker             , m_ModelOptions(modelOptions)
165*89c4ff92SAndroid Build Coastguard Worker             , m_ProfilingEnabled(false)
166*89c4ff92SAndroid Build Coastguard Worker             , m_ExportEnabled(exportEnabled)
167*89c4ff92SAndroid Build Coastguard Worker             , m_AllowExpandedDims(false)
168*89c4ff92SAndroid Build Coastguard Worker     {
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02")
OptimizerOptionsarmnn::OptimizerOptions172*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false,
173*89c4ff92SAndroid Build Coastguard Worker                      ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly,
174*89c4ff92SAndroid Build Coastguard Worker                      bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false,
175*89c4ff92SAndroid Build Coastguard Worker                      bool debugToFile = false, bool allowExpandedDims = false)
176*89c4ff92SAndroid Build Coastguard Worker             : m_ReduceFp32ToFp16(reduceFp32ToFp16)
177*89c4ff92SAndroid Build Coastguard Worker             , m_Debug(debug)
178*89c4ff92SAndroid Build Coastguard Worker             , m_DebugToFile(debugToFile)
179*89c4ff92SAndroid Build Coastguard Worker             , m_ReduceFp32ToBf16(reduceFp32ToBf16)
180*89c4ff92SAndroid Build Coastguard Worker             , m_shapeInferenceMethod(shapeInferenceMethod)
181*89c4ff92SAndroid Build Coastguard Worker             , m_ImportEnabled(importEnabled)
182*89c4ff92SAndroid Build Coastguard Worker             , m_ModelOptions(modelOptions)
183*89c4ff92SAndroid Build Coastguard Worker             , m_ProfilingEnabled(false)
184*89c4ff92SAndroid Build Coastguard Worker             , m_ExportEnabled(exportEnabled)
185*89c4ff92SAndroid Build Coastguard Worker             , m_AllowExpandedDims(allowExpandedDims)
186*89c4ff92SAndroid Build Coastguard Worker     {
187*89c4ff92SAndroid Build Coastguard Worker     }
188*89c4ff92SAndroid Build Coastguard Worker 
ToStringarmnn::OptimizerOptions189*89c4ff92SAndroid Build Coastguard Worker     const std::string ToString() const
190*89c4ff92SAndroid Build Coastguard Worker     {
191*89c4ff92SAndroid Build Coastguard Worker         std::stringstream stream;
192*89c4ff92SAndroid Build Coastguard Worker         stream << "OptimizerOptions: \n";
193*89c4ff92SAndroid Build Coastguard Worker         stream << "\tReduceFp32ToFp16: " << m_ReduceFp32ToFp16 << "\n";
194*89c4ff92SAndroid Build Coastguard Worker         stream << "\tReduceFp32ToBf16: " << m_ReduceFp32ToBf16 << "\n";
195*89c4ff92SAndroid Build Coastguard Worker         stream << "\tDebug: " << m_Debug << "\n";
196*89c4ff92SAndroid Build Coastguard Worker         stream << "\tDebug to file: " << m_DebugToFile << "\n";
197*89c4ff92SAndroid Build Coastguard Worker         stream << "\tShapeInferenceMethod: " <<
198*89c4ff92SAndroid Build Coastguard Worker                (m_shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly
199*89c4ff92SAndroid Build Coastguard Worker                ? "ValidateOnly" : "InferAndValidate") << "\n";
200*89c4ff92SAndroid Build Coastguard Worker         stream << "\tImportEnabled: " << m_ImportEnabled << "\n";
201*89c4ff92SAndroid Build Coastguard Worker         stream << "\tExportEnabled: " << m_ExportEnabled << "\n";
202*89c4ff92SAndroid Build Coastguard Worker         stream << "\tProfilingEnabled: " << m_ProfilingEnabled << "\n";
203*89c4ff92SAndroid Build Coastguard Worker         stream << "\tAllowExpandedDims: " << m_AllowExpandedDims << "\n";
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker         stream << "\tModelOptions: \n";
206*89c4ff92SAndroid Build Coastguard Worker         for (auto optionsGroup : m_ModelOptions)
207*89c4ff92SAndroid Build Coastguard Worker         {
208*89c4ff92SAndroid Build Coastguard Worker             for (size_t i=0; i < optionsGroup.GetOptionCount(); i++)
209*89c4ff92SAndroid Build Coastguard Worker             {
210*89c4ff92SAndroid Build Coastguard Worker                 const armnn::BackendOptions::BackendOption option = optionsGroup.GetOption(i);
211*89c4ff92SAndroid Build Coastguard Worker                 stream << "\t\tBackend: "  << optionsGroup.GetBackendId() << "\n"
212*89c4ff92SAndroid Build Coastguard Worker                        << "\t\t\tOption: " << option.GetName() << "\n"
213*89c4ff92SAndroid Build Coastguard Worker                        << "\t\t\tValue: "  << std::string(option.GetValue().ToString()) << "\n";
214*89c4ff92SAndroid Build Coastguard Worker             }
215*89c4ff92SAndroid Build Coastguard Worker         }
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker         return stream.str();
218*89c4ff92SAndroid Build Coastguard Worker     }
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker     /// Reduces all Fp32 operators in the model to Fp16 for faster processing.
221*89c4ff92SAndroid Build Coastguard Worker     /// @Note This feature works best if all operators of the model are in Fp32. ArmNN will add conversion layers
222*89c4ff92SAndroid Build Coastguard Worker     ///       between layers that weren't in Fp32 in the first place or if the operator is not supported in Fp16.
223*89c4ff92SAndroid Build Coastguard Worker     ///       The overhead of these conversions can lead to a slower overall performance if too many conversions are
224*89c4ff92SAndroid Build Coastguard Worker     ///       required.
225*89c4ff92SAndroid Build Coastguard Worker     bool m_ReduceFp32ToFp16;
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker     /// Add debug data for easier troubleshooting
228*89c4ff92SAndroid Build Coastguard Worker     bool m_Debug;
229*89c4ff92SAndroid Build Coastguard Worker 
230*89c4ff92SAndroid Build Coastguard Worker     /// Pass debug data to separate output files for easier troubleshooting
231*89c4ff92SAndroid Build Coastguard Worker     bool m_DebugToFile;
232*89c4ff92SAndroid Build Coastguard Worker 
233*89c4ff92SAndroid Build Coastguard Worker     /// @Note This feature has been replaced by enabling Fast Math in compute library backend options.
234*89c4ff92SAndroid Build Coastguard Worker     /// This is currently a placeholder option
235*89c4ff92SAndroid Build Coastguard Worker     bool m_ReduceFp32ToBf16;
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     /// Infer output size when not available
238*89c4ff92SAndroid Build Coastguard Worker     ShapeInferenceMethod m_shapeInferenceMethod;
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker     /// Enable Import
241*89c4ff92SAndroid Build Coastguard Worker     bool m_ImportEnabled;
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker     /// Enable Model Options
244*89c4ff92SAndroid Build Coastguard Worker     ModelOptions m_ModelOptions;
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker     /// Enable profiling dump of the optimizer phase
247*89c4ff92SAndroid Build Coastguard Worker     bool m_ProfilingEnabled;
248*89c4ff92SAndroid Build Coastguard Worker 
249*89c4ff92SAndroid Build Coastguard Worker     /// Enable Export
250*89c4ff92SAndroid Build Coastguard Worker     bool m_ExportEnabled;
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker     /// When calculating tensor sizes, dimensions of size == 1 will be ignored
253*89c4ff92SAndroid Build Coastguard Worker     bool m_AllowExpandedDims;
254*89c4ff92SAndroid Build Coastguard Worker };
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker /// ArmNN performs an optimization on each model/network before it gets loaded for execution. OptimizerOptions provides
257*89c4ff92SAndroid Build Coastguard Worker /// a set of features that allows the user to customize this optimization on a per model basis.
258*89c4ff92SAndroid Build Coastguard Worker struct OptimizerOptionsOpaqueImpl;
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker class OptimizerOptionsOpaque
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker public:
263*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque();
264*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque(const OptimizerOptionsOpaque& other);
265*89c4ff92SAndroid Build Coastguard Worker     ~OptimizerOptionsOpaque();
266*89c4ff92SAndroid Build Coastguard Worker 
267*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque(const OptimizerOptions& OptimizerStruct);
268*89c4ff92SAndroid Build Coastguard Worker 
269*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque& operator=(OptimizerOptionsOpaque other);
270*89c4ff92SAndroid Build Coastguard Worker 
271*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled,
272*89c4ff92SAndroid Build Coastguard Worker                            ModelOptions modelOptions = {}, bool exportEnabled = false, bool debugToFile = false);
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false,
275*89c4ff92SAndroid Build Coastguard Worker                            ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly,
276*89c4ff92SAndroid Build Coastguard Worker                            bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false,
277*89c4ff92SAndroid Build Coastguard Worker                            bool debugToFile = false, bool allowExpandedDims = false);
278*89c4ff92SAndroid Build Coastguard Worker 
279*89c4ff92SAndroid Build Coastguard Worker     const std::string ToString() const;
280*89c4ff92SAndroid Build Coastguard Worker 
281*89c4ff92SAndroid Build Coastguard Worker     bool GetProfilingEnabled() const;
282*89c4ff92SAndroid Build Coastguard Worker 
283*89c4ff92SAndroid Build Coastguard Worker     bool GetImportEnabled() const;
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     bool GetExportEnabled() const;
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker     bool GetReduceFp32ToFp16() const;
288*89c4ff92SAndroid Build Coastguard Worker 
289*89c4ff92SAndroid Build Coastguard Worker     bool GetReduceFp32ToBf16() const;
290*89c4ff92SAndroid Build Coastguard Worker 
291*89c4ff92SAndroid Build Coastguard Worker     bool GetDebugEnabled() const;
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker     bool GetDebugToFileEnabled() const;
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker     bool GetAllowExpandedDims() const;
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker     armnn::ModelOptions GetModelOptions() const;
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker     armnn::ShapeInferenceMethod GetShapeInferenceMethod() const;
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     void SetImportEnabled(bool ImportState);
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker     void SetExportEnabled(bool ExportState);
304*89c4ff92SAndroid Build Coastguard Worker 
305*89c4ff92SAndroid Build Coastguard Worker     void SetProfilingEnabled(bool ProfilingState);
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     void SetDebugEnabled(bool DebugState);
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker     void SetDebugToFileEnabled(bool DebugFileState);
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     void SetReduceFp32ToFp16(bool ReduceFp32ToFp16State);
312*89c4ff92SAndroid Build Coastguard Worker 
313*89c4ff92SAndroid Build Coastguard Worker     void SetShapeInferenceMethod(armnn::ShapeInferenceMethod ShapeInferenceMethodType);
314*89c4ff92SAndroid Build Coastguard Worker 
315*89c4ff92SAndroid Build Coastguard Worker     void AddModelOption(armnn::BackendOptions);
316*89c4ff92SAndroid Build Coastguard Worker 
317*89c4ff92SAndroid Build Coastguard Worker     void SetAllowExpandedDims(bool ExpandedDimsAllowed);
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker private:
320*89c4ff92SAndroid Build Coastguard Worker 
321*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<armnn::OptimizerOptionsOpaqueImpl> p_OptimizerOptionsImpl;
322*89c4ff92SAndroid Build Coastguard Worker 
323*89c4ff92SAndroid Build Coastguard Worker };
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker class IWorkloadFactory;
326*89c4ff92SAndroid Build Coastguard Worker class NetworkImpl;
327*89c4ff92SAndroid Build Coastguard Worker using INetworkPtr = std::unique_ptr<INetwork, void(*)(INetwork* network)>;
328*89c4ff92SAndroid Build Coastguard Worker using IOptimizedNetworkPtr = std::unique_ptr<IOptimizedNetwork, void(*)(IOptimizedNetwork* network)>;
329*89c4ff92SAndroid Build Coastguard Worker 
330*89c4ff92SAndroid Build Coastguard Worker using CompiledBlobDeleter = std::function<void(const void*)>;
331*89c4ff92SAndroid Build Coastguard Worker using CompiledBlobPtr = std::unique_ptr<void, CompiledBlobDeleter>;
332*89c4ff92SAndroid Build Coastguard Worker 
333*89c4ff92SAndroid Build Coastguard Worker /// Main network class which provides the interface for building up a neural network.
334*89c4ff92SAndroid Build Coastguard Worker /// This object is subsequently required by the IRuntime::Load() method.
335*89c4ff92SAndroid Build Coastguard Worker class INetwork
336*89c4ff92SAndroid Build Coastguard Worker {
337*89c4ff92SAndroid Build Coastguard Worker public:
338*89c4ff92SAndroid Build Coastguard Worker     static INetwork* CreateRaw(const NetworkOptions& networkOptions = {});
339*89c4ff92SAndroid Build Coastguard Worker     static INetworkPtr Create(const NetworkOptions& networkOptions = {});
340*89c4ff92SAndroid Build Coastguard Worker     static void Destroy(INetwork* network);
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker     Status PrintGraph();
343*89c4ff92SAndroid Build Coastguard Worker 
344*89c4ff92SAndroid Build Coastguard Worker     /// Adds an input layer to the network.
    /// @param id - User generated id to uniquely identify a particular input. The same id needs to be specified
    /// when passing the inputs to the IRuntime::EnqueueWorkload() function.
347*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
348*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
349*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr);
350*89c4ff92SAndroid Build Coastguard Worker 
351*89c4ff92SAndroid Build Coastguard Worker     /// Adds an ArgMinMax layer to the network.
    /// @param desc - Parameters for the ArgMinMax operation.
353*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
354*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
355*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
356*89c4ff92SAndroid Build Coastguard Worker                                          const char* name = nullptr);
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker     /// Adds a cast layer to the network.
359*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
360*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
361*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddCastLayer(const char* name = nullptr);
362*89c4ff92SAndroid Build Coastguard Worker 
363*89c4ff92SAndroid Build Coastguard Worker     /// Add a Comparison layer to the network.
364*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
    /// @param comparisonDescriptor - Descriptor for the comparison operation.
366*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
367*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
368*89c4ff92SAndroid Build Coastguard Worker                                           const char* name = nullptr);
369*89c4ff92SAndroid Build Coastguard Worker 
370*89c4ff92SAndroid Build Coastguard Worker     /// Adds a concatenation layer to the network.
371*89c4ff92SAndroid Build Coastguard Worker     /// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation
372*89c4ff92SAndroid Build Coastguard Worker     ///                           process. The number of views must be equal to the number of inputs, and their
373*89c4ff92SAndroid Build Coastguard Worker     ///                           order must match - e.g. the first view corresponds to the first input, the
374*89c4ff92SAndroid Build Coastguard Worker     ///                           second view to the second input, etc.
375*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
376*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
377*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
378*89c4ff92SAndroid Build Coastguard Worker                                       const char* name = nullptr);
379*89c4ff92SAndroid Build Coastguard Worker 
380*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 2D convolution layer to the network.
381*89c4ff92SAndroid Build Coastguard Worker     /// @param convolution2dDescriptor - Convolution2dDescriptor describing the 2D convolution layer.
382*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
383*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
384*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
385*89c4ff92SAndroid Build Coastguard Worker                                              const char* name = nullptr);
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 3D convolution layer to the network.
388*89c4ff92SAndroid Build Coastguard Worker     /// @param convolution3dDescriptor - Convolution3dDescriptor describing the 3D convolution layer.
389*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
390*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
391*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddConvolution3dLayer(const Convolution3dDescriptor& convolution3dDescriptor,
392*89c4ff92SAndroid Build Coastguard Worker                                              const char* name = nullptr);
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker     /// Adds a depth to space layer to the network.
395*89c4ff92SAndroid Build Coastguard Worker     /// @param depthToSpaceDescriptor - DepthToSpaceDescriptor holding the parameters for the depth to space operation.
396*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
397*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
398*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddDepthToSpaceLayer(const DepthToSpaceDescriptor& depthToSpaceDescriptor,
399*89c4ff92SAndroid Build Coastguard Worker                                             const char* name = nullptr);
400*89c4ff92SAndroid Build Coastguard Worker 
401*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 2D depthwise convolution layer to the network.
402*89c4ff92SAndroid Build Coastguard Worker     /// @param convolution2dDescriptor - DepthwiseConvolution2dDescriptor describing the 2D depthwise convolution layer.
403*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
404*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
405*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddDepthwiseConvolution2dLayer(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
406*89c4ff92SAndroid Build Coastguard Worker                                                       const char* name = nullptr);
407*89c4ff92SAndroid Build Coastguard Worker 
408*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Dequantize layer to the network, optionally named by @p name (defaults to nullptr).
409*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
410*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddDequantizeLayer(const char* name = nullptr);
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Detection PostProcess layer to the network.
413*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - DetectionPostProcessDescriptor describing the Detection PostProcess layer.
414*89c4ff92SAndroid Build Coastguard Worker     /// @param anchors - Constant tensor holding the anchors.
415*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
416*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
417*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddDetectionPostProcessLayer(
418*89c4ff92SAndroid Build Coastguard Worker         const DetectionPostProcessDescriptor& descriptor,
419*89c4ff92SAndroid Build Coastguard Worker         const ConstTensor& anchors,
420*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
421*89c4ff92SAndroid Build Coastguard Worker 
422*89c4ff92SAndroid Build Coastguard Worker     /// Adds an ElementwiseBinary layer to the network.
423*89c4ff92SAndroid Build Coastguard Worker     /// @param elementwiseBinaryDescriptor - Descriptor for the elementwiseBinary operations.
424*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
425*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
426*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddElementwiseBinaryLayer(const ElementwiseBinaryDescriptor& elementwiseBinaryDescriptor,
427*89c4ff92SAndroid Build Coastguard Worker                                                  const char* name = nullptr);
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker     /// Adds an ElementwiseUnary layer to the network.
430*89c4ff92SAndroid Build Coastguard Worker     /// @param elementwiseUnaryDescriptor - Descriptor for the elementwiseUnary operations.
431*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
432*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
433*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor,
434*89c4ff92SAndroid Build Coastguard Worker                                                 const char* name = nullptr);
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Fill layer to the network.
437*89c4ff92SAndroid Build Coastguard Worker     /// @param fillDescriptor - Descriptor for the fill operation.
438*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
439*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
440*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor,
441*89c4ff92SAndroid Build Coastguard Worker                                     const char* name = nullptr);
442*89c4ff92SAndroid Build Coastguard Worker 
443*89c4ff92SAndroid Build Coastguard Worker 
444*89c4ff92SAndroid Build Coastguard Worker     /// Adds a fully connected layer to the network.
445*89c4ff92SAndroid Build Coastguard Worker     /// @param fullyConnectedDescriptor - Description of the fully connected layer.
446*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
447*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
448*89c4ff92SAndroid Build Coastguard Worker     /// @note Weights and biases are passed in as inputs. If they are constant tensors you can simply store
449*89c4ff92SAndroid Build Coastguard Worker     ///       them in a ConstantLayer as seen below. A full example can be found in samples/SimpleSample.cpp.
450*89c4ff92SAndroid Build Coastguard Worker     ///
451*89c4ff92SAndroid Build Coastguard Worker     /// @code
452*89c4ff92SAndroid Build Coastguard Worker     /// // Make sure the IsConstant flag is set on the weightsInfo before passing it to the ConstTensor.
453*89c4ff92SAndroid Build Coastguard Worker     /// ConstTensor weights(weightsInfo, weightsData);
454*89c4ff92SAndroid Build Coastguard Worker     ///
455*89c4ff92SAndroid Build Coastguard Worker     /// // Constant layer that now holds weights data for FullyConnected
456*89c4ff92SAndroid Build Coastguard Worker     /// IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "weights");
457*89c4ff92SAndroid Build Coastguard Worker     ///
458*89c4ff92SAndroid Build Coastguard Worker     /// FullyConnectedDescriptor fullyConnectedDesc;
459*89c4ff92SAndroid Build Coastguard Worker     /// IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
460*89c4ff92SAndroid Build Coastguard Worker     ///                                                                                  "fully connected");
461*89c4ff92SAndroid Build Coastguard Worker     /// IConnectableLayer* InputLayer = myNetwork->AddInputLayer(0);
462*89c4ff92SAndroid Build Coastguard Worker     /// InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
463*89c4ff92SAndroid Build Coastguard Worker     /// constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
464*89c4ff92SAndroid Build Coastguard Worker     /// @endcode
465*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor,
466*89c4ff92SAndroid Build Coastguard Worker                                               const char* name = nullptr);
467*89c4ff92SAndroid Build Coastguard Worker 
468*89c4ff92SAndroid Build Coastguard Worker     /// Adds a permute layer to the network.
469*89c4ff92SAndroid Build Coastguard Worker     /// @param permuteDescriptor - PermuteDescriptor to configure the permute operation.
470*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
471*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
472*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor,
473*89c4ff92SAndroid Build Coastguard Worker                                        const char* name = nullptr);
474*89c4ff92SAndroid Build Coastguard Worker 
475*89c4ff92SAndroid Build Coastguard Worker     /// Adds a batch to space ND layer to the network.
476*89c4ff92SAndroid Build Coastguard Worker     /// @param batchToSpaceNdDescriptor - BatchToSpaceNdDescriptor describing the layer.
477*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
478*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
479*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
480*89c4ff92SAndroid Build Coastguard Worker                                               const char* name = nullptr);
481*89c4ff92SAndroid Build Coastguard Worker 
482*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 2D pooling layer to the network.
483*89c4ff92SAndroid Build Coastguard Worker     /// @param pooling2dDescriptor - Pooling2dDescriptor to configure the 2D pooling.
484*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
485*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
486*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor,
487*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
488*89c4ff92SAndroid Build Coastguard Worker 
489*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 3D pooling layer to the network.
490*89c4ff92SAndroid Build Coastguard Worker     /// @param pooling3dDescriptor - Pooling3dDescriptor to configure the 3D pooling.
491*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
492*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
493*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPooling3dLayer(const Pooling3dDescriptor& pooling3dDescriptor,
494*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
495*89c4ff92SAndroid Build Coastguard Worker 
496*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Precompiled layer to the network. This method is intended for use by backend users.
497*89c4ff92SAndroid Build Coastguard Worker     /// @param preCompiledDescriptor - PreCompiledDescriptor containing the parameters for the Precompiled layer.
498*89c4ff92SAndroid Build Coastguard Worker     /// @param compiledBlobPtr - CompiledBlobPtr pre-compiled object set for the Precompiled layer.
499*89c4ff92SAndroid Build Coastguard Worker     /// @param backend - Optional<BackendId> set for the Precompiled layer (no default; must be passed explicitly).
500*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
501*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
502*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPrecompiledLayer(const PreCompiledDescriptor& preCompiledDescriptor,
503*89c4ff92SAndroid Build Coastguard Worker                                            CompiledBlobPtr compiledBlobPtr,
504*89c4ff92SAndroid Build Coastguard Worker                                            const Optional<BackendId>& backend,
505*89c4ff92SAndroid Build Coastguard Worker                                            const char* name = nullptr);
506*89c4ff92SAndroid Build Coastguard Worker 
507*89c4ff92SAndroid Build Coastguard Worker     /// Adds an activation layer to the network.
508*89c4ff92SAndroid Build Coastguard Worker     /// @param activationDescriptor - ActivationDescriptor to configure the activation operation.
509*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
510*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
511*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor,
512*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker     /// Adds a normalization layer to the network.
515*89c4ff92SAndroid Build Coastguard Worker     /// @param normalizationDescriptor - NormalizationDescriptor to configure the normalization operation.
516*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
517*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
518*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor,
519*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
520*89c4ff92SAndroid Build Coastguard Worker 
521*89c4ff92SAndroid Build Coastguard Worker     /// Adds a slice layer to the network.
522*89c4ff92SAndroid Build Coastguard Worker     /// @param sliceDescriptor - SliceDescriptor to configure the slice operation.
523*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
524*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
525*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr);
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     /// Adds a softmax layer to the network.
528*89c4ff92SAndroid Build Coastguard Worker     /// If the data type is QAsymm8, then the output quantization parameters
529*89c4ff92SAndroid Build Coastguard Worker     /// must have a scale of 1/256 and an offset of 0.
530*89c4ff92SAndroid Build Coastguard Worker     /// @param softmaxDescriptor - SoftmaxDescriptor to configure the softmax operation.
531*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
532*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
533*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor,
534*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
535*89c4ff92SAndroid Build Coastguard Worker 
536*89c4ff92SAndroid Build Coastguard Worker     /// Adds a splitter layer to the network.
537*89c4ff92SAndroid Build Coastguard Worker     /// @param splitterDescriptor - ViewsDescriptor to configure the splitting process.
538*89c4ff92SAndroid Build Coastguard Worker     ///                             The number of views must be equal to the number of outputs,
539*89c4ff92SAndroid Build Coastguard Worker     ///                             and their order must match - e.g. the first view corresponds to
540*89c4ff92SAndroid Build Coastguard Worker     ///                             the first output, the second view to the second output, etc.
541*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
542*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
543*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor,
544*89c4ff92SAndroid Build Coastguard Worker                                         const char* name = nullptr);
545*89c4ff92SAndroid Build Coastguard Worker 
546*89c4ff92SAndroid Build Coastguard Worker     /// Adds a merge layer to the network.
547*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
548*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
549*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddMergeLayer(const char* name = nullptr);
550*89c4ff92SAndroid Build Coastguard Worker 
551*89c4ff92SAndroid Build Coastguard Worker     /// Adds an addition layer to the network (deprecated, removal date 24.02; use AddElementwiseBinaryLayer).
552*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
553*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
554*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
555*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddAdditionLayer(const char* name = nullptr);
556*89c4ff92SAndroid Build Coastguard Worker 
557*89c4ff92SAndroid Build Coastguard Worker     /// Adds a multiplication layer to the network (deprecated, removal date 24.02; use AddElementwiseBinaryLayer).
558*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
559*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
560*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
561*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr);
562*89c4ff92SAndroid Build Coastguard Worker 
563*89c4ff92SAndroid Build Coastguard Worker     /// Adds a batch normalization layer to the network; @p desc is its BatchNormalizationDescriptor.
564*89c4ff92SAndroid Build Coastguard Worker     /// @param mean - Pre-calculated mean for each channel.
565*89c4ff92SAndroid Build Coastguard Worker     /// @param variance - Pre-calculated variance for each channel.
566*89c4ff92SAndroid Build Coastguard Worker     /// @param beta - Per-channel additive factor.
567*89c4ff92SAndroid Build Coastguard Worker     /// @param gamma - Per-channel multiplicative factor.
568*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
569*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
570*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
571*89c4ff92SAndroid Build Coastguard Worker         const ConstTensor& mean,
572*89c4ff92SAndroid Build Coastguard Worker         const ConstTensor& variance,
573*89c4ff92SAndroid Build Coastguard Worker         const ConstTensor& beta,
574*89c4ff92SAndroid Build Coastguard Worker         const ConstTensor& gamma,
575*89c4ff92SAndroid Build Coastguard Worker         const char* name = nullptr);
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker     /// Adds a rank layer to the network.
578*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
579*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
580*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddRankLayer(const char* name = nullptr);
581*89c4ff92SAndroid Build Coastguard Worker 
582*89c4ff92SAndroid Build Coastguard Worker     /// Adds a resize layer to the network.
583*89c4ff92SAndroid Build Coastguard Worker     /// @param resizeDescriptor - ResizeDescriptor holding the parameters for the resize operation.
584*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
585*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
586*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor,
587*89c4ff92SAndroid Build Coastguard Worker                                       const char* name = nullptr);
588*89c4ff92SAndroid Build Coastguard Worker 
589*89c4ff92SAndroid Build Coastguard Worker     /// Adds a reduce layer to the network.
590*89c4ff92SAndroid Build Coastguard Worker     /// @param reduceDescriptor - ReduceDescriptor holding the parameters for the reduce operation.
591*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
592*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
593*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddReduceLayer(const ReduceDescriptor& reduceDescriptor,
594*89c4ff92SAndroid Build Coastguard Worker                                       const char* name = nullptr);
595*89c4ff92SAndroid Build Coastguard Worker 
596*89c4ff92SAndroid Build Coastguard Worker     /// Adds an instance normalization layer to the network.
597*89c4ff92SAndroid Build Coastguard Worker     /// @param desc - InstanceNormalizationDescriptor holding the parameters for the instance normalization operation.
598*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
599*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
600*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc,
601*89c4ff92SAndroid Build Coastguard Worker                                                      const char* name = nullptr);
602*89c4ff92SAndroid Build Coastguard Worker 
603*89c4ff92SAndroid Build Coastguard Worker     /// Adds an L2 normalization layer to the network.
604*89c4ff92SAndroid Build Coastguard Worker     /// Normalization is performed along dimension 1, but requires a 4D input.
605*89c4ff92SAndroid Build Coastguard Worker     /// @param desc - L2NormalizationDescriptor holding the parameters for the L2 normalization operation.
606*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
607*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
608*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc,
609*89c4ff92SAndroid Build Coastguard Worker                                                const char* name = nullptr);
610*89c4ff92SAndroid Build Coastguard Worker 
611*89c4ff92SAndroid Build Coastguard Worker     /// Adds a log softmax layer to the network.
612*89c4ff92SAndroid Build Coastguard Worker     /// @param logSoftmaxDescriptor - LogSoftmaxDescriptor to configure the log softmax operation.
613*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
614*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
615*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor,
616*89c4ff92SAndroid Build Coastguard Worker                                           const char* name = nullptr);
617*89c4ff92SAndroid Build Coastguard Worker 
618*89c4ff92SAndroid Build Coastguard Worker     /// Adds a layer with no inputs and a single output, which always corresponds to
619*89c4ff92SAndroid Build Coastguard Worker     /// the passed-in constant tensor.
620*89c4ff92SAndroid Build Coastguard Worker     /// @param input - Tensor to be provided as the only output of the layer. The layer will maintain
621*89c4ff92SAndroid Build Coastguard Worker     ///                its own copy of the tensor data, meaning the memory referenced by @a input can
622*89c4ff92SAndroid Build Coastguard Worker     ///                be freed or reused after this function is called.
623*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
624*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
625*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddConstantLayer(const ConstTensor& input,
626*89c4ff92SAndroid Build Coastguard Worker                                         const char* name = nullptr);
627*89c4ff92SAndroid Build Coastguard Worker 
628*89c4ff92SAndroid Build Coastguard Worker     /// Adds a reshape layer to the network.
629*89c4ff92SAndroid Build Coastguard Worker     /// @param reshapeDescriptor - ReshapeDescriptor holding the parameters for the reshape operation.
630*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
631*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
632*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor,
633*89c4ff92SAndroid Build Coastguard Worker                                        const char* name = nullptr);
634*89c4ff92SAndroid Build Coastguard Worker 
635*89c4ff92SAndroid Build Coastguard Worker     /// Adds a shape layer to the network.
636*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
637*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
638*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddShapeLayer(const char* name = nullptr);
639*89c4ff92SAndroid Build Coastguard Worker 
640*89c4ff92SAndroid Build Coastguard Worker     /// Adds a space to batch ND layer to the network.
641*89c4ff92SAndroid Build Coastguard Worker     /// @param spaceToBatchNdDescriptor - SpaceToBatchNdDescriptor holding the parameters for the space to batch operation.
642*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
643*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
644*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
645*89c4ff92SAndroid Build Coastguard Worker                                               const char* name = nullptr);
646*89c4ff92SAndroid Build Coastguard Worker 
647*89c4ff92SAndroid Build Coastguard Worker     /// Adds a space to depth layer to the network.
648*89c4ff92SAndroid Build Coastguard Worker     /// @param spaceToDepthDescriptor - SpaceToDepthDescriptor holding the parameters for the space to depth operation.
649*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
650*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
651*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor,
652*89c4ff92SAndroid Build Coastguard Worker                                             const char* name = nullptr);
653*89c4ff92SAndroid Build Coastguard Worker 
654*89c4ff92SAndroid Build Coastguard Worker     /// Adds a floor layer to the network.
655*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
656*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
657*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddFloorLayer(const char* name = nullptr);
658*89c4ff92SAndroid Build Coastguard Worker 
659*89c4ff92SAndroid Build Coastguard Worker     /// Adds an output layer to the network.
660*89c4ff92SAndroid Build Coastguard Worker     /// @param id - User generated id to uniquely identify a particular output. The same id needs to be specified
661*89c4ff92SAndroid Build Coastguard Worker     ///             when passing the outputs to the IRuntime::EnqueueWorkload() function.
662*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
663*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
664*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr);
665*89c4ff92SAndroid Build Coastguard Worker 
666*89c4ff92SAndroid Build Coastguard Worker     /// Adds an LSTM layer to the network.
667*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - LstmDescriptor holding the parameters for the LSTM operation.
668*89c4ff92SAndroid Build Coastguard Worker     /// @param params - Weights and biases for the LSTM cell.
669*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
670*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
671*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
672*89c4ff92SAndroid Build Coastguard Worker                                     const LstmInputParams& params,
673*89c4ff92SAndroid Build Coastguard Worker                                     const char* name = nullptr);
674*89c4ff92SAndroid Build Coastguard Worker 
675*89c4ff92SAndroid Build Coastguard Worker     /// Adds a division layer to the network (deprecated, removal date 24.02; use AddElementwiseBinaryLayer).
676*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
677*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
678*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
679*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddDivisionLayer(const char* name = nullptr);
680*89c4ff92SAndroid Build Coastguard Worker 
681*89c4ff92SAndroid Build Coastguard Worker     /// Adds a subtraction layer to the network (deprecated, removal date 24.02; use AddElementwiseBinaryLayer).
682*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
683*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
684*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
685*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSubtractionLayer(const char* name = nullptr);
686*89c4ff92SAndroid Build Coastguard Worker 
687*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Maximum layer to the network (deprecated, removal date 24.02; use AddElementwiseBinaryLayer).
688*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer (defaults to nullptr).
689*89c4ff92SAndroid Build Coastguard Worker     /// @return - Pointer to the IConnectableLayer interface for configuring the layer.
690*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
691*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddMaximumLayer(const char* name = nullptr);
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker     /// Adds a mean layer to the network.
694*89c4ff92SAndroid Build Coastguard Worker     /// @param meanDescriptor - Parameters for the mean operation.
695*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
696*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
697*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr);
698*89c4ff92SAndroid Build Coastguard Worker 
699*89c4ff92SAndroid Build Coastguard Worker     /// Adds a pad layer to the network.
700*89c4ff92SAndroid Build Coastguard Worker     /// @param padDescriptor - Descriptor for the pad operation. Its paddings are an n by 2 tensor, where n is the
701*89c4ff92SAndroid Build Coastguard Worker     ///                        rank of the input tensor, such that paddings[i,0] indicates the amount of padding to
702*89c4ff92SAndroid Build Coastguard Worker     ///                        add in front of dimension i, and paddings[i,1] the amount to add after the end of it.
703*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
704*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
705*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor,
706*89c4ff92SAndroid Build Coastguard Worker                                            const char* name = nullptr);
707*89c4ff92SAndroid Build Coastguard Worker 
708*89c4ff92SAndroid Build Coastguard Worker     /// Adds a quantize layer to the network.
709*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
710*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
711*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddQuantizeLayer(const char* name = nullptr);
712*89c4ff92SAndroid Build Coastguard Worker 
713*89c4ff92SAndroid Build Coastguard Worker     /// Adds a strided slice layer to the network.
714*89c4ff92SAndroid Build Coastguard Worker     /// @param stridedSliceDescriptor - Parameters for the strided slice operation.
715*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
716*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
717*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor,
718*89c4ff92SAndroid Build Coastguard Worker                                                     const char* name = nullptr);
719*89c4ff92SAndroid Build Coastguard Worker 
720*89c4ff92SAndroid Build Coastguard Worker     /// Adds a minimum layer to the network. Deprecated: use AddElementwiseBinaryLayer (removal date "24.02").
721*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
722*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
723*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02")
724*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddMinimumLayer(const char* name = nullptr);
725*89c4ff92SAndroid Build Coastguard Worker 
726*89c4ff92SAndroid Build Coastguard Worker     /// Adds a gather layer to the network.
727*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Description of the gather layer.
728*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
729*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
730*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor,
731*89c4ff92SAndroid Build Coastguard Worker                                               const char* name = nullptr);
732*89c4ff92SAndroid Build Coastguard Worker 
733*89c4ff92SAndroid Build Coastguard Worker     /// Adds a GatherNd layer to the network. Unlike AddGatherLayer, this overload takes no descriptor.
734*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
735*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
736*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddGatherNdLayer(const char* name = nullptr);
737*89c4ff92SAndroid Build Coastguard Worker 
738*89c4ff92SAndroid Build Coastguard Worker     /// Adds a switch layer to the network.
739*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
740*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
741*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddSwitchLayer(const char* name = nullptr);
742*89c4ff92SAndroid Build Coastguard Worker 
743*89c4ff92SAndroid Build Coastguard Worker     /// Adds a PReLU layer to the network.
744*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
745*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
746*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddPreluLayer(const char* name = nullptr);
747*89c4ff92SAndroid Build Coastguard Worker 
748*89c4ff92SAndroid Build Coastguard Worker     /// Adds a 2D transpose convolution layer to the network.
749*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Description of the 2D transpose convolution layer.
750*89c4ff92SAndroid Build Coastguard Worker     /// @param weights - Tensor for the weights data.
751*89c4ff92SAndroid Build Coastguard Worker     /// @param biases - Optional tensor for the bias data (an Optional<ConstTensor>; the argument has no default).
752*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
753*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
754*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor,
755*89c4ff92SAndroid Build Coastguard Worker                                                               const ConstTensor& weights,
756*89c4ff92SAndroid Build Coastguard Worker                                                               const Optional<ConstTensor>& biases,
757*89c4ff92SAndroid Build Coastguard Worker                                                               const char* name = nullptr);
758*89c4ff92SAndroid Build Coastguard Worker 
759*89c4ff92SAndroid Build Coastguard Worker     /// Adds a transpose layer to the network.
760*89c4ff92SAndroid Build Coastguard Worker     /// @param transposeDescriptor - TransposeDescriptor to configure the transpose.
761*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
762*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
763*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor,
764*89c4ff92SAndroid Build Coastguard Worker                                                  const char* name = nullptr);
765*89c4ff92SAndroid Build Coastguard Worker 
766*89c4ff92SAndroid Build Coastguard Worker     /// Adds a stack layer to the network.
767*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Description of the stack layer.
768*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
769*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
770*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddStackLayer(const StackDescriptor& descriptor,
771*89c4ff92SAndroid Build Coastguard Worker                                              const char* name = nullptr);
772*89c4ff92SAndroid Build Coastguard Worker 
773*89c4ff92SAndroid Build Coastguard Worker     /// Adds a stand-in layer for a type unknown to the Arm NN framework.
774*89c4ff92SAndroid Build Coastguard Worker     /// Note: Due to the nature of this layer, no validation can be performed by the framework. Furthermore, any
775*89c4ff92SAndroid Build Coastguard Worker     /// model containing this layer cannot make use of dynamic tensors since the tensor sizes cannot be inferred.
776*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Descriptor for the StandIn layer.
777*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
778*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
779*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddStandInLayer(const StandInDescriptor& descriptor,
780*89c4ff92SAndroid Build Coastguard Worker                                                const char* name = nullptr);
781*89c4ff92SAndroid Build Coastguard Worker 
782*89c4ff92SAndroid Build Coastguard Worker     /// Adds a QuantizedLstm layer to the network. Unlike AddQLstmLayer, no descriptor is taken, only params.
783*89c4ff92SAndroid Build Coastguard Worker     /// @param params - The weights and biases for the Quantized LSTM cell.
784*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
785*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
786*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
787*89c4ff92SAndroid Build Coastguard Worker                                                      const char* name = nullptr);
788*89c4ff92SAndroid Build Coastguard Worker 
789*89c4ff92SAndroid Build Coastguard Worker     /// Adds a QLstm layer to the network.
790*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Parameters for the QLstm operation.
791*89c4ff92SAndroid Build Coastguard Worker     /// @param params - Weights and biases for the layer (an LstmInputParams, the same type AddLstmLayer takes).
792*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
793*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
794*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor,
795*89c4ff92SAndroid Build Coastguard Worker                                              const LstmInputParams& params,
796*89c4ff92SAndroid Build Coastguard Worker                                              const char* name = nullptr);
797*89c4ff92SAndroid Build Coastguard Worker 
798*89c4ff92SAndroid Build Coastguard Worker     /// Adds a Logical Binary layer to the network.
799*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Description of the Logical Binary layer.
800*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer; defaults to nullptr.
801*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
802*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& descriptor,
803*89c4ff92SAndroid Build Coastguard Worker                                                      const char* name = nullptr);
804*89c4ff92SAndroid Build Coastguard Worker 
805*89c4ff92SAndroid Build Coastguard Worker     /// Adds a UnidirectionalSequenceLstm layer to the network.
806*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Parameters for the UnidirectionalSequenceLstm operation.
807*89c4ff92SAndroid Build Coastguard Worker     /// @param params - Weights and biases for the UnidirectionalSequenceLstm.
808*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
809*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
810*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor,
811*89c4ff92SAndroid Build Coastguard Worker                                                           const LstmInputParams& params,
812*89c4ff92SAndroid Build Coastguard Worker                                                           const char* name = nullptr);
813*89c4ff92SAndroid Build Coastguard Worker 
814*89c4ff92SAndroid Build Coastguard Worker     /// Adds a ChannelShuffle layer to the network.
815*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Parameters for the ChannelShuffle operation.
816*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
817*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
818*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor,
819*89c4ff92SAndroid Build Coastguard Worker                                               const char* name = nullptr);
820*89c4ff92SAndroid Build Coastguard Worker 
821*89c4ff92SAndroid Build Coastguard Worker     /// Adds a BatchMatMul layer to the network.
822*89c4ff92SAndroid Build Coastguard Worker     /// @param descriptor - Parameters for the BatchMatMul operation.
823*89c4ff92SAndroid Build Coastguard Worker     /// @param name - Optional name for the layer.
824*89c4ff92SAndroid Build Coastguard Worker     /// @return - Interface for configuring the layer.
825*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& descriptor,
826*89c4ff92SAndroid Build Coastguard Worker                                            const char* name = nullptr);
827*89c4ff92SAndroid Build Coastguard Worker 
828*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(IStrategy& strategy) const; // NOTE(review): presumably applies `strategy` to the layers - confirm.
829*89c4ff92SAndroid Build Coastguard Worker 
830*89c4ff92SAndroid Build Coastguard Worker protected:
831*89c4ff92SAndroid Build Coastguard Worker     ~INetwork(); // Protected: only members, friends and derived classes may invoke the destructor.
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker     friend void VisitLayersTopologically(const INetwork* inputNetwork, IStrategy& strategy);
834*89c4ff92SAndroid Build Coastguard Worker     friend class TestConnectionPreservation; // NOTE(review): test-only friend judging by its name - confirm.
835*89c4ff92SAndroid Build Coastguard Worker     friend TensorInfo GetInputTensorInfo(const INetwork* network);
836*89c4ff92SAndroid Build Coastguard Worker     friend IOptimizedNetworkPtr Optimize(const INetwork& network, // Overload taking the legacy OptimizerOptions.
837*89c4ff92SAndroid Build Coastguard Worker                                          const std::vector<BackendId>& backendPreferences,
838*89c4ff92SAndroid Build Coastguard Worker                                          const IDeviceSpec& deviceSpec,
839*89c4ff92SAndroid Build Coastguard Worker                                          const OptimizerOptions& options,
840*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::vector<std::string>&> messages);
841*89c4ff92SAndroid Build Coastguard Worker     friend IOptimizedNetworkPtr Optimize(const INetwork& network, // Overload taking OptimizerOptionsOpaque.
842*89c4ff92SAndroid Build Coastguard Worker                                          const std::vector<BackendId>& backendPreferences,
843*89c4ff92SAndroid Build Coastguard Worker                                          const IDeviceSpec& deviceSpec,
844*89c4ff92SAndroid Build Coastguard Worker                                          const OptimizerOptionsOpaque& options,
845*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::vector<std::string>&> messages);
846*89c4ff92SAndroid Build Coastguard Worker 
847*89c4ff92SAndroid Build Coastguard Worker     INetwork(NetworkOptions networkOptions = {}); // Protected constructor; networkOptions defaults to `{}`.
848*89c4ff92SAndroid Build Coastguard Worker 
849*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<NetworkImpl> pNetworkImpl; // Uniquely-owned NetworkImpl object.
850*89c4ff92SAndroid Build Coastguard Worker };
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker namespace experimental
853*89c4ff92SAndroid Build Coastguard Worker {
854*89c4ff92SAndroid Build Coastguard Worker class AsyncNetworkImpl; // Forward declaration; named as a friend of IOptimizedNetwork.
855*89c4ff92SAndroid Build Coastguard Worker class WorkingMemHandle; // Forward declaration; named as a friend of IOptimizedNetwork.
856*89c4ff92SAndroid Build Coastguard Worker } // namespace experimental
857*89c4ff92SAndroid Build Coastguard Worker 
858*89c4ff92SAndroid Build Coastguard Worker struct BackendSettings;     // Forward declaration.
859*89c4ff92SAndroid Build Coastguard Worker struct OptimizationResult;  // Forward declaration.
860*89c4ff92SAndroid Build Coastguard Worker class OptimizedNetworkImpl; // Forward declaration; IOptimizedNetwork holds one through a std::unique_ptr.
861*89c4ff92SAndroid Build Coastguard Worker class IProfiler;            // Forward declaration; IOptimizedNetwork::GetProfiler returns a shared_ptr to it.
862*89c4ff92SAndroid Build Coastguard Worker class IOptimizedNetwork
863*89c4ff92SAndroid Build Coastguard Worker {
864*89c4ff92SAndroid Build Coastguard Worker public:
865*89c4ff92SAndroid Build Coastguard Worker     static void Destroy(IOptimizedNetwork* network); // NOTE(review): presumably releases `network` - confirm.
866*89c4ff92SAndroid Build Coastguard Worker 
867*89c4ff92SAndroid Build Coastguard Worker     Status PrintGraph(); // Non-const, unlike SerializeToDot.
868*89c4ff92SAndroid Build Coastguard Worker     Status SerializeToDot(std::ostream& stream) const; // Takes the caller-supplied stream by non-const reference.
869*89c4ff92SAndroid Build Coastguard Worker 
870*89c4ff92SAndroid Build Coastguard Worker     arm::pipe::ProfilingGuid GetGuid() const;
871*89c4ff92SAndroid Build Coastguard Worker 
872*89c4ff92SAndroid Build Coastguard Worker     size_t GetNumInputs() const;
873*89c4ff92SAndroid Build Coastguard Worker     size_t GetNumOutputs() const;
874*89c4ff92SAndroid Build Coastguard Worker 
875*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(IStrategy& strategy) const; // NOTE(review): presumably applies `strategy` to the layers - confirm.
876*89c4ff92SAndroid Build Coastguard Worker 
877*89c4ff92SAndroid Build Coastguard Worker     /// Creates a copy of the IOptimizedNetwork. The IOptimizedNetwork will not be reoptimized;
878*89c4ff92SAndroid Build Coastguard Worker     /// the provided ModelOptions will only be used when creating a LoadedNetwork.
879*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetwork(const IOptimizedNetwork& other, const ModelOptions& modelOptions);
880*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetwork(std::unique_ptr<Graph> graph); // NOTE(review): not `explicit`; implicit conversion is allowed.
881*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetwork(std::unique_ptr<OptimizedNetworkImpl> impl); // NOTE(review): not `explicit` either.
882*89c4ff92SAndroid Build Coastguard Worker     ~IOptimizedNetwork(); // Public and non-virtual as declared here.
883*89c4ff92SAndroid Build Coastguard Worker 
884*89c4ff92SAndroid Build Coastguard Worker     const std::shared_ptr<IProfiler>& GetProfiler() const; // Returns a reference to the shared_ptr, not a copy.
885*89c4ff92SAndroid Build Coastguard Worker 
886*89c4ff92SAndroid Build Coastguard Worker protected:
887*89c4ff92SAndroid Build Coastguard Worker     friend class LoadedNetwork;
888*89c4ff92SAndroid Build Coastguard Worker 
889*89c4ff92SAndroid Build Coastguard Worker     friend class experimental::AsyncNetworkImpl;
890*89c4ff92SAndroid Build Coastguard Worker     friend class experimental::WorkingMemHandle;
891*89c4ff92SAndroid Build Coastguard Worker 
892*89c4ff92SAndroid Build Coastguard Worker     friend Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr);
893*89c4ff92SAndroid Build Coastguard Worker     friend ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr);
894*89c4ff92SAndroid Build Coastguard Worker     friend IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, // INetwork overload (OptimizerOptionsOpaque).
895*89c4ff92SAndroid Build Coastguard Worker                                          const std::vector<BackendId>& backendPreferences,
896*89c4ff92SAndroid Build Coastguard Worker                                          const IDeviceSpec& deviceSpec,
897*89c4ff92SAndroid Build Coastguard Worker                                          const OptimizerOptionsOpaque& options,
898*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::vector<std::string>&> messages);
899*89c4ff92SAndroid Build Coastguard Worker     friend IOptimizedNetworkPtr Optimize(const Graph& inGraph, // Graph overload (OptimizerOptionsOpaque).
900*89c4ff92SAndroid Build Coastguard Worker                                          const std::vector<BackendId>& backendPreferences,
901*89c4ff92SAndroid Build Coastguard Worker                                          const IDeviceSpec& deviceSpec,
902*89c4ff92SAndroid Build Coastguard Worker                                          const OptimizerOptionsOpaque& options,
903*89c4ff92SAndroid Build Coastguard Worker                                          Optional<std::vector<std::string>&> messages);
904*89c4ff92SAndroid Build Coastguard Worker 
905*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetwork(std::unique_ptr<Graph> graph, const ModelOptions& modelOptions); // Protected overload.
906*89c4ff92SAndroid Build Coastguard Worker 
907*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<OptimizedNetworkImpl> pOptimizedNetworkImpl; // Uniquely-owned OptimizedNetworkImpl object.
908*89c4ff92SAndroid Build Coastguard Worker };
909*89c4ff92SAndroid Build Coastguard Worker 
910*89c4ff92SAndroid Build Coastguard Worker /// Create an optimized version of the network.
911*89c4ff92SAndroid Build Coastguard Worker /// @param network INetwork description of the network to be optimized.
912*89c4ff92SAndroid Build Coastguard Worker /// @param backendPreferences The choice of the backend ordered by user preferences.
913*89c4ff92SAndroid Build Coastguard Worker /// @param deviceSpec DeviceSpec object as queried from the runtime. See IRuntime::GetDeviceSpec()
914*89c4ff92SAndroid Build Coastguard Worker /// @param options OptimizerOptionsOpaque object with optimizer configuration options; default-constructed if omitted.
915*89c4ff92SAndroid Build Coastguard Worker /// @param messages If there are failures or warnings, strings describing them are added to the vector. Optional.
916*89c4ff92SAndroid Build Coastguard Worker /// @return An IOptimizedNetworkPtr interface to the optimized network; throws an exception derived from
917*89c4ff92SAndroid Build Coastguard Worker /// armnn::Exception if the process fails.
918*89c4ff92SAndroid Build Coastguard Worker 
919*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const INetwork& network,
920*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<BackendId>& backendPreferences,
921*89c4ff92SAndroid Build Coastguard Worker                               const IDeviceSpec& deviceSpec,
922*89c4ff92SAndroid Build Coastguard Worker                               const OptimizerOptionsOpaque& options = OptimizerOptionsOpaque(),
923*89c4ff92SAndroid Build Coastguard Worker                               Optional<std::vector<std::string>&> messages = EmptyOptional());
924*89c4ff92SAndroid Build Coastguard Worker 
925*89c4ff92SAndroid Build Coastguard Worker /// Create an optimized version of the network from a Graph.
926*89c4ff92SAndroid Build Coastguard Worker /// @param inGraph Graph to be optimized.
927*89c4ff92SAndroid Build Coastguard Worker /// @param backendPreferences The choice of the backend ordered by user preferences.
928*89c4ff92SAndroid Build Coastguard Worker /// @param deviceSpec DeviceSpec object as queried from the runtime. See IRuntime::GetDeviceSpec()
929*89c4ff92SAndroid Build Coastguard Worker /// @param options OptimizerOptionsOpaque object with optimizer configuration options (no default in this overload).
930*89c4ff92SAndroid Build Coastguard Worker /// @param messages If there are failures or warnings, strings describing them are added to the vector. Optional.
931*89c4ff92SAndroid Build Coastguard Worker /// @return An IOptimizedNetworkPtr interface to the optimized network; throws an exception derived from
932*89c4ff92SAndroid Build Coastguard Worker /// armnn::Exception if the process fails.
933*89c4ff92SAndroid Build Coastguard Worker 
934*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const Graph& inGraph,
935*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<BackendId>& backendPreferences,
936*89c4ff92SAndroid Build Coastguard Worker                               const IDeviceSpec& deviceSpec,
937*89c4ff92SAndroid Build Coastguard Worker                               const OptimizerOptionsOpaque& options,
938*89c4ff92SAndroid Build Coastguard Worker                               Optional<std::vector<std::string>&> messages = EmptyOptional());
939*89c4ff92SAndroid Build Coastguard Worker 
940*89c4ff92SAndroid Build Coastguard Worker /// Graph overload of Optimize that accepts the legacy OptimizerOptions type; `messages` is optional.
941*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const Graph& inGraph,
942*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<BackendId>& backendPreferences,
943*89c4ff92SAndroid Build Coastguard Worker                               const IDeviceSpec& deviceSpec,
944*89c4ff92SAndroid Build Coastguard Worker                               const OptimizerOptions& options,
945*89c4ff92SAndroid Build Coastguard Worker                               Optional<std::vector<std::string>&> messages = EmptyOptional());
946*89c4ff92SAndroid Build Coastguard Worker 
947*89c4ff92SAndroid Build Coastguard Worker /// INetwork overload of Optimize that accepts the legacy OptimizerOptions type; `messages` is optional.
948*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const INetwork& network,
949*89c4ff92SAndroid Build Coastguard Worker                               const std::vector<BackendId>& backendPreferences,
950*89c4ff92SAndroid Build Coastguard Worker                               const IDeviceSpec& deviceSpec,
951*89c4ff92SAndroid Build Coastguard Worker                               const OptimizerOptions& options,
952*89c4ff92SAndroid Build Coastguard Worker                               Optional<std::vector<std::string>&> messages = EmptyOptional());
953*89c4ff92SAndroid Build Coastguard Worker 
954*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn
955