//
// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/BackendOptions.hpp>
#include <armnn/Deprecated.hpp>
#include <armnn/DescriptorsFwd.hpp>
#include <armnn/IStrategy.hpp>
#include <armnn/NetworkFwd.hpp>
#include <armnn/Optional.hpp>
#include <armnn/TensorFwd.hpp>
#include <armnn/Logging.hpp>
#include <armnn/backends/TensorHandle.hpp>

#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace armnn
{
/// @brief An input connection slot for a layer.
/// The input slot can be connected to an output slot of the preceding layer in the graph.
/// Only one connection to the input slot is allowed.
25*89c4ff92SAndroid Build Coastguard Worker class IInputSlot 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker public: 28*89c4ff92SAndroid Build Coastguard Worker virtual const IOutputSlot* GetConnection() const = 0; 29*89c4ff92SAndroid Build Coastguard Worker virtual IOutputSlot* GetConnection() = 0; 30*89c4ff92SAndroid Build Coastguard Worker virtual const IConnectableLayer& GetOwningIConnectableLayer() const = 0; 31*89c4ff92SAndroid Build Coastguard Worker virtual IConnectableLayer& GetOwningIConnectableLayer() = 0; 32*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetSlotIndex() const = 0; 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker protected: 35*89c4ff92SAndroid Build Coastguard Worker /// Not user deletable. ~IInputSlot()36*89c4ff92SAndroid Build Coastguard Worker ~IInputSlot() {} 37*89c4ff92SAndroid Build Coastguard Worker }; 38*89c4ff92SAndroid Build Coastguard Worker 39*89c4ff92SAndroid Build Coastguard Worker /// @brief An output connection slot for a layer. 40*89c4ff92SAndroid Build Coastguard Worker /// The output slot may be connected to 1 or more input slots of subsequent layers in the graph. 
41*89c4ff92SAndroid Build Coastguard Worker class IOutputSlot 42*89c4ff92SAndroid Build Coastguard Worker { 43*89c4ff92SAndroid Build Coastguard Worker public: 44*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetNumConnections() const = 0; 45*89c4ff92SAndroid Build Coastguard Worker virtual const IInputSlot* GetConnection(unsigned int index) const = 0; 46*89c4ff92SAndroid Build Coastguard Worker virtual IInputSlot* GetConnection(unsigned int outputindex) = 0; 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker virtual void SetTensorInfo(const TensorInfo& tensorInfo) = 0; 49*89c4ff92SAndroid Build Coastguard Worker virtual const TensorInfo& GetTensorInfo() const = 0; 50*89c4ff92SAndroid Build Coastguard Worker virtual bool IsTensorInfoSet() const = 0; 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker virtual int Connect(IInputSlot& destination) = 0; 53*89c4ff92SAndroid Build Coastguard Worker virtual void Disconnect(IInputSlot& slot) = 0; 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int CalculateIndexOnOwner() const = 0; 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker virtual LayerGuid GetOwningLayerGuid() const = 0; 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker virtual const IConnectableLayer& GetOwningIConnectableLayer() const = 0; 60*89c4ff92SAndroid Build Coastguard Worker virtual IConnectableLayer& GetOwningIConnectableLayer() = 0; 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker protected: 63*89c4ff92SAndroid Build Coastguard Worker /// Not user deletable. 
~IOutputSlot()64*89c4ff92SAndroid Build Coastguard Worker ~IOutputSlot() {} 65*89c4ff92SAndroid Build Coastguard Worker }; 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker /// @brief Interface for a layer that is connectable to other layers via InputSlots and OutputSlots. 68*89c4ff92SAndroid Build Coastguard Worker class IConnectableLayer 69*89c4ff92SAndroid Build Coastguard Worker { 70*89c4ff92SAndroid Build Coastguard Worker public: 71*89c4ff92SAndroid Build Coastguard Worker /// Returns the name of the layer 72*89c4ff92SAndroid Build Coastguard Worker virtual const char* GetName() const = 0; 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker /// Returns the number of connectable input slots 75*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetNumInputSlots() const = 0; 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker /// Returns the number of connectable output slots 78*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetNumOutputSlots() const = 0; 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker /// Get a const input slot handle by slot index 81*89c4ff92SAndroid Build Coastguard Worker virtual const IInputSlot& GetInputSlot(unsigned int index) const = 0; 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker /// Get the input slot handle by slot index 84*89c4ff92SAndroid Build Coastguard Worker virtual IInputSlot& GetInputSlot(unsigned int index) = 0; 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker /// Get the const output slot handle by slot index 87*89c4ff92SAndroid Build Coastguard Worker virtual const IOutputSlot& GetOutputSlot(unsigned int index) const = 0; 88*89c4ff92SAndroid Build Coastguard Worker 89*89c4ff92SAndroid Build Coastguard Worker /// Get the output slot handle by slot index 
90*89c4ff92SAndroid Build Coastguard Worker virtual IOutputSlot& GetOutputSlot(unsigned int index) = 0; 91*89c4ff92SAndroid Build Coastguard Worker 92*89c4ff92SAndroid Build Coastguard Worker /// Infer the shape of the output(s) based on the provided input shape(s) 93*89c4ff92SAndroid Build Coastguard Worker virtual std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const = 0; 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker /// Returns the unique id of the layer 96*89c4ff92SAndroid Build Coastguard Worker virtual LayerGuid GetGuid() const = 0; 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker /// Apply a visitor to this layer 99*89c4ff92SAndroid Build Coastguard Worker virtual void ExecuteStrategy(IStrategy& strategy) const = 0; 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker /// Provide a hint for the optimizer as to which backend to prefer for this layer. 102*89c4ff92SAndroid Build Coastguard Worker /// By providing a BackendSelectionHint there is no guarantee the input backend supports that layer. 103*89c4ff92SAndroid Build Coastguard Worker /// If IsLayerSupported() returns false with the backend hint, we default to calling IsLayerSupported() 104*89c4ff92SAndroid Build Coastguard Worker /// on the BackendPreferences vector. Use SetBackendId() if we can guarantee a backend supports that 105*89c4ff92SAndroid Build Coastguard Worker /// layer (IsLayerSupported returns true for a specific backend). 
106*89c4ff92SAndroid Build Coastguard Worker virtual void BackendSelectionHint(Optional<BackendId> backend) = 0; 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker /// Returns the armnn::LayerType of this layer 109*89c4ff92SAndroid Build Coastguard Worker virtual LayerType GetType() const = 0; 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker /// If the layer has a descriptor return it. 112*89c4ff92SAndroid Build Coastguard Worker /// The base descriptor can then be cast to the correct descriptor class. 113*89c4ff92SAndroid Build Coastguard Worker /// If the layer has no associated descriptor a struct of type NullDescriptor will be returned. 114*89c4ff92SAndroid Build Coastguard Worker /// Note: NullDescriptors can be detected because they return true when 115*89c4ff92SAndroid Build Coastguard Worker /// the BaseDescriptor IsNull function is invoked. 116*89c4ff92SAndroid Build Coastguard Worker virtual const BaseDescriptor& GetParameters() const = 0; 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker /// Set the backend of the IConnectableLayer. 119*89c4ff92SAndroid Build Coastguard Worker /// By using SetBackendId() we guarantee that the input backend supports that 120*89c4ff92SAndroid Build Coastguard Worker /// layer (IsLayerSupported returns true for a specific backend). If there is 121*89c4ff92SAndroid Build Coastguard Worker /// no guarantee the input backend supports that layer use BackendSelectionHint(). 
122*89c4ff92SAndroid Build Coastguard Worker virtual void SetBackendId(const BackendId& id) = 0; 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker using ConstantTensors = std::vector<std::reference_wrapper<std::shared_ptr<ConstTensorHandle>>>; 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker // Returns ConstantTensors of this Layer if it has any, otherwise returns empty vector. 127*89c4ff92SAndroid Build Coastguard Worker virtual ConstantTensors GetConstantTensorsByRef() = 0; 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker using ImmutableConstantTensors = std::vector<std::reference_wrapper<const std::shared_ptr<ConstTensorHandle>>>; 130*89c4ff92SAndroid Build Coastguard Worker 131*89c4ff92SAndroid Build Coastguard Worker // Returns ConstantTensors of this Layer if it has any, otherwise returns empty vector. 132*89c4ff92SAndroid Build Coastguard Worker virtual ImmutableConstantTensors GetConstantTensorsByRef() const = 0; 133*89c4ff92SAndroid Build Coastguard Worker 134*89c4ff92SAndroid Build Coastguard Worker protected: 135*89c4ff92SAndroid Build Coastguard Worker /// Objects are not deletable via the handle ~IConnectableLayer()136*89c4ff92SAndroid Build Coastguard Worker ~IConnectableLayer() {} 137*89c4ff92SAndroid Build Coastguard Worker }; 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker struct OptimizerOptions 140*89c4ff92SAndroid Build Coastguard Worker { 141*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02") OptimizerOptionsarmnn::OptimizerOptions142*89c4ff92SAndroid Build Coastguard Worker OptimizerOptions() 143*89c4ff92SAndroid Build Coastguard Worker : m_ReduceFp32ToFp16(false) 144*89c4ff92SAndroid Build Coastguard Worker , m_Debug(false) 145*89c4ff92SAndroid Build Coastguard Worker , m_DebugToFile(false) 
146*89c4ff92SAndroid Build Coastguard Worker , m_ReduceFp32ToBf16(false) 147*89c4ff92SAndroid Build Coastguard Worker , m_shapeInferenceMethod(armnn::ShapeInferenceMethod::ValidateOnly) 148*89c4ff92SAndroid Build Coastguard Worker , m_ImportEnabled(false) 149*89c4ff92SAndroid Build Coastguard Worker , m_ModelOptions() 150*89c4ff92SAndroid Build Coastguard Worker , m_ProfilingEnabled(false) 151*89c4ff92SAndroid Build Coastguard Worker , m_ExportEnabled(false) 152*89c4ff92SAndroid Build Coastguard Worker , m_AllowExpandedDims(false) 153*89c4ff92SAndroid Build Coastguard Worker {} 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02") OptimizerOptionsarmnn::OptimizerOptions156*89c4ff92SAndroid Build Coastguard Worker OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled, 157*89c4ff92SAndroid Build Coastguard Worker ModelOptions modelOptions = {}, bool exportEnabled = false, bool debugToFile = false) 158*89c4ff92SAndroid Build Coastguard Worker : m_ReduceFp32ToFp16(reduceFp32ToFp16) 159*89c4ff92SAndroid Build Coastguard Worker , m_Debug(debug) 160*89c4ff92SAndroid Build Coastguard Worker , m_DebugToFile(debugToFile) 161*89c4ff92SAndroid Build Coastguard Worker , m_ReduceFp32ToBf16(reduceFp32ToBf16) 162*89c4ff92SAndroid Build Coastguard Worker , m_shapeInferenceMethod(armnn::ShapeInferenceMethod::ValidateOnly) 163*89c4ff92SAndroid Build Coastguard Worker , m_ImportEnabled(importEnabled) 164*89c4ff92SAndroid Build Coastguard Worker , m_ModelOptions(modelOptions) 165*89c4ff92SAndroid Build Coastguard Worker , m_ProfilingEnabled(false) 166*89c4ff92SAndroid Build Coastguard Worker , m_ExportEnabled(exportEnabled) 167*89c4ff92SAndroid Build Coastguard Worker , m_AllowExpandedDims(false) 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker } 170*89c4ff92SAndroid Build 
Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use ABI stable OptimizerOptionsOpaque instead.", "24.02") OptimizerOptionsarmnn::OptimizerOptions172*89c4ff92SAndroid Build Coastguard Worker OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false, 173*89c4ff92SAndroid Build Coastguard Worker ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly, 174*89c4ff92SAndroid Build Coastguard Worker bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false, 175*89c4ff92SAndroid Build Coastguard Worker bool debugToFile = false, bool allowExpandedDims = false) 176*89c4ff92SAndroid Build Coastguard Worker : m_ReduceFp32ToFp16(reduceFp32ToFp16) 177*89c4ff92SAndroid Build Coastguard Worker , m_Debug(debug) 178*89c4ff92SAndroid Build Coastguard Worker , m_DebugToFile(debugToFile) 179*89c4ff92SAndroid Build Coastguard Worker , m_ReduceFp32ToBf16(reduceFp32ToBf16) 180*89c4ff92SAndroid Build Coastguard Worker , m_shapeInferenceMethod(shapeInferenceMethod) 181*89c4ff92SAndroid Build Coastguard Worker , m_ImportEnabled(importEnabled) 182*89c4ff92SAndroid Build Coastguard Worker , m_ModelOptions(modelOptions) 183*89c4ff92SAndroid Build Coastguard Worker , m_ProfilingEnabled(false) 184*89c4ff92SAndroid Build Coastguard Worker , m_ExportEnabled(exportEnabled) 185*89c4ff92SAndroid Build Coastguard Worker , m_AllowExpandedDims(allowExpandedDims) 186*89c4ff92SAndroid Build Coastguard Worker { 187*89c4ff92SAndroid Build Coastguard Worker } 188*89c4ff92SAndroid Build Coastguard Worker ToStringarmnn::OptimizerOptions189*89c4ff92SAndroid Build Coastguard Worker const std::string ToString() const 190*89c4ff92SAndroid Build Coastguard Worker { 191*89c4ff92SAndroid Build Coastguard Worker std::stringstream stream; 192*89c4ff92SAndroid Build Coastguard Worker stream << "OptimizerOptions: \n"; 193*89c4ff92SAndroid Build Coastguard Worker stream << 
"\tReduceFp32ToFp16: " << m_ReduceFp32ToFp16 << "\n"; 194*89c4ff92SAndroid Build Coastguard Worker stream << "\tReduceFp32ToBf16: " << m_ReduceFp32ToBf16 << "\n"; 195*89c4ff92SAndroid Build Coastguard Worker stream << "\tDebug: " << m_Debug << "\n"; 196*89c4ff92SAndroid Build Coastguard Worker stream << "\tDebug to file: " << m_DebugToFile << "\n"; 197*89c4ff92SAndroid Build Coastguard Worker stream << "\tShapeInferenceMethod: " << 198*89c4ff92SAndroid Build Coastguard Worker (m_shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly 199*89c4ff92SAndroid Build Coastguard Worker ? "ValidateOnly" : "InferAndValidate") << "\n"; 200*89c4ff92SAndroid Build Coastguard Worker stream << "\tImportEnabled: " << m_ImportEnabled << "\n"; 201*89c4ff92SAndroid Build Coastguard Worker stream << "\tExportEnabled: " << m_ExportEnabled << "\n"; 202*89c4ff92SAndroid Build Coastguard Worker stream << "\tProfilingEnabled: " << m_ProfilingEnabled << "\n"; 203*89c4ff92SAndroid Build Coastguard Worker stream << "\tAllowExpandedDims: " << m_AllowExpandedDims << "\n"; 204*89c4ff92SAndroid Build Coastguard Worker 205*89c4ff92SAndroid Build Coastguard Worker stream << "\tModelOptions: \n"; 206*89c4ff92SAndroid Build Coastguard Worker for (auto optionsGroup : m_ModelOptions) 207*89c4ff92SAndroid Build Coastguard Worker { 208*89c4ff92SAndroid Build Coastguard Worker for (size_t i=0; i < optionsGroup.GetOptionCount(); i++) 209*89c4ff92SAndroid Build Coastguard Worker { 210*89c4ff92SAndroid Build Coastguard Worker const armnn::BackendOptions::BackendOption option = optionsGroup.GetOption(i); 211*89c4ff92SAndroid Build Coastguard Worker stream << "\t\tBackend: " << optionsGroup.GetBackendId() << "\n" 212*89c4ff92SAndroid Build Coastguard Worker << "\t\t\tOption: " << option.GetName() << "\n" 213*89c4ff92SAndroid Build Coastguard Worker << "\t\t\tValue: " << std::string(option.GetValue().ToString()) << "\n"; 214*89c4ff92SAndroid Build Coastguard Worker } 215*89c4ff92SAndroid Build Coastguard 
Worker } 216*89c4ff92SAndroid Build Coastguard Worker 217*89c4ff92SAndroid Build Coastguard Worker return stream.str(); 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker 220*89c4ff92SAndroid Build Coastguard Worker /// Reduces all Fp32 operators in the model to Fp16 for faster processing. 221*89c4ff92SAndroid Build Coastguard Worker /// @Note This feature works best if all operators of the model are in Fp32. ArmNN will add conversion layers 222*89c4ff92SAndroid Build Coastguard Worker /// between layers that weren't in Fp32 in the first place or if the operator is not supported in Fp16. 223*89c4ff92SAndroid Build Coastguard Worker /// The overhead of these conversions can lead to a slower overall performance if too many conversions are 224*89c4ff92SAndroid Build Coastguard Worker /// required. 225*89c4ff92SAndroid Build Coastguard Worker bool m_ReduceFp32ToFp16; 226*89c4ff92SAndroid Build Coastguard Worker 227*89c4ff92SAndroid Build Coastguard Worker /// Add debug data for easier troubleshooting 228*89c4ff92SAndroid Build Coastguard Worker bool m_Debug; 229*89c4ff92SAndroid Build Coastguard Worker 230*89c4ff92SAndroid Build Coastguard Worker /// Pass debug data to separate output files for easier troubleshooting 231*89c4ff92SAndroid Build Coastguard Worker bool m_DebugToFile; 232*89c4ff92SAndroid Build Coastguard Worker 233*89c4ff92SAndroid Build Coastguard Worker /// @Note This feature has been replaced by enabling Fast Math in compute library backend options. 
234*89c4ff92SAndroid Build Coastguard Worker /// This is currently a placeholder option 235*89c4ff92SAndroid Build Coastguard Worker bool m_ReduceFp32ToBf16; 236*89c4ff92SAndroid Build Coastguard Worker 237*89c4ff92SAndroid Build Coastguard Worker /// Infer output size when not available 238*89c4ff92SAndroid Build Coastguard Worker ShapeInferenceMethod m_shapeInferenceMethod; 239*89c4ff92SAndroid Build Coastguard Worker 240*89c4ff92SAndroid Build Coastguard Worker /// Enable Import 241*89c4ff92SAndroid Build Coastguard Worker bool m_ImportEnabled; 242*89c4ff92SAndroid Build Coastguard Worker 243*89c4ff92SAndroid Build Coastguard Worker /// Enable Model Options 244*89c4ff92SAndroid Build Coastguard Worker ModelOptions m_ModelOptions; 245*89c4ff92SAndroid Build Coastguard Worker 246*89c4ff92SAndroid Build Coastguard Worker /// Enable profiling dump of the optimizer phase 247*89c4ff92SAndroid Build Coastguard Worker bool m_ProfilingEnabled; 248*89c4ff92SAndroid Build Coastguard Worker 249*89c4ff92SAndroid Build Coastguard Worker /// Enable Export 250*89c4ff92SAndroid Build Coastguard Worker bool m_ExportEnabled; 251*89c4ff92SAndroid Build Coastguard Worker 252*89c4ff92SAndroid Build Coastguard Worker /// When calculating tensor sizes, dimensions of size == 1 will be ignored 253*89c4ff92SAndroid Build Coastguard Worker bool m_AllowExpandedDims; 254*89c4ff92SAndroid Build Coastguard Worker }; 255*89c4ff92SAndroid Build Coastguard Worker 256*89c4ff92SAndroid Build Coastguard Worker /// ArmNN performs an optimization on each model/network before it gets loaded for execution. OptimizerOptions provides 257*89c4ff92SAndroid Build Coastguard Worker /// a set of features that allows the user to customize this optimization on a per model basis. 
258*89c4ff92SAndroid Build Coastguard Worker struct OptimizerOptionsOpaqueImpl; 259*89c4ff92SAndroid Build Coastguard Worker 260*89c4ff92SAndroid Build Coastguard Worker class OptimizerOptionsOpaque 261*89c4ff92SAndroid Build Coastguard Worker { 262*89c4ff92SAndroid Build Coastguard Worker public: 263*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque(); 264*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque(const OptimizerOptionsOpaque& other); 265*89c4ff92SAndroid Build Coastguard Worker ~OptimizerOptionsOpaque(); 266*89c4ff92SAndroid Build Coastguard Worker 267*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque(const OptimizerOptions& OptimizerStruct); 268*89c4ff92SAndroid Build Coastguard Worker 269*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque& operator=(OptimizerOptionsOpaque other); 270*89c4ff92SAndroid Build Coastguard Worker 271*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled, 272*89c4ff92SAndroid Build Coastguard Worker ModelOptions modelOptions = {}, bool exportEnabled = false, bool debugToFile = false); 273*89c4ff92SAndroid Build Coastguard Worker 274*89c4ff92SAndroid Build Coastguard Worker OptimizerOptionsOpaque(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false, 275*89c4ff92SAndroid Build Coastguard Worker ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly, 276*89c4ff92SAndroid Build Coastguard Worker bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false, 277*89c4ff92SAndroid Build Coastguard Worker bool debugToFile = false, bool allowExpandedDims = false); 278*89c4ff92SAndroid Build Coastguard Worker 279*89c4ff92SAndroid Build Coastguard Worker const std::string ToString() const; 280*89c4ff92SAndroid Build Coastguard Worker 281*89c4ff92SAndroid Build Coastguard Worker bool GetProfilingEnabled() const; 282*89c4ff92SAndroid 
Build Coastguard Worker 283*89c4ff92SAndroid Build Coastguard Worker bool GetImportEnabled() const; 284*89c4ff92SAndroid Build Coastguard Worker 285*89c4ff92SAndroid Build Coastguard Worker bool GetExportEnabled() const; 286*89c4ff92SAndroid Build Coastguard Worker 287*89c4ff92SAndroid Build Coastguard Worker bool GetReduceFp32ToFp16() const; 288*89c4ff92SAndroid Build Coastguard Worker 289*89c4ff92SAndroid Build Coastguard Worker bool GetReduceFp32ToBf16() const; 290*89c4ff92SAndroid Build Coastguard Worker 291*89c4ff92SAndroid Build Coastguard Worker bool GetDebugEnabled() const; 292*89c4ff92SAndroid Build Coastguard Worker 293*89c4ff92SAndroid Build Coastguard Worker bool GetDebugToFileEnabled() const; 294*89c4ff92SAndroid Build Coastguard Worker 295*89c4ff92SAndroid Build Coastguard Worker bool GetAllowExpandedDims() const; 296*89c4ff92SAndroid Build Coastguard Worker 297*89c4ff92SAndroid Build Coastguard Worker armnn::ModelOptions GetModelOptions() const; 298*89c4ff92SAndroid Build Coastguard Worker 299*89c4ff92SAndroid Build Coastguard Worker armnn::ShapeInferenceMethod GetShapeInferenceMethod() const; 300*89c4ff92SAndroid Build Coastguard Worker 301*89c4ff92SAndroid Build Coastguard Worker void SetImportEnabled(bool ImportState); 302*89c4ff92SAndroid Build Coastguard Worker 303*89c4ff92SAndroid Build Coastguard Worker void SetExportEnabled(bool ExportState); 304*89c4ff92SAndroid Build Coastguard Worker 305*89c4ff92SAndroid Build Coastguard Worker void SetProfilingEnabled(bool ProfilingState); 306*89c4ff92SAndroid Build Coastguard Worker 307*89c4ff92SAndroid Build Coastguard Worker void SetDebugEnabled(bool DebugState); 308*89c4ff92SAndroid Build Coastguard Worker 309*89c4ff92SAndroid Build Coastguard Worker void SetDebugToFileEnabled(bool DebugFileState); 310*89c4ff92SAndroid Build Coastguard Worker 311*89c4ff92SAndroid Build Coastguard Worker void SetReduceFp32ToFp16(bool ReduceFp32ToFp16State); 312*89c4ff92SAndroid Build Coastguard Worker 
313*89c4ff92SAndroid Build Coastguard Worker void SetShapeInferenceMethod(armnn::ShapeInferenceMethod ShapeInferenceMethodType); 314*89c4ff92SAndroid Build Coastguard Worker 315*89c4ff92SAndroid Build Coastguard Worker void AddModelOption(armnn::BackendOptions); 316*89c4ff92SAndroid Build Coastguard Worker 317*89c4ff92SAndroid Build Coastguard Worker void SetAllowExpandedDims(bool ExpandedDimsAllowed); 318*89c4ff92SAndroid Build Coastguard Worker 319*89c4ff92SAndroid Build Coastguard Worker private: 320*89c4ff92SAndroid Build Coastguard Worker 321*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<armnn::OptimizerOptionsOpaqueImpl> p_OptimizerOptionsImpl; 322*89c4ff92SAndroid Build Coastguard Worker 323*89c4ff92SAndroid Build Coastguard Worker }; 324*89c4ff92SAndroid Build Coastguard Worker 325*89c4ff92SAndroid Build Coastguard Worker class IWorkloadFactory; 326*89c4ff92SAndroid Build Coastguard Worker class NetworkImpl; 327*89c4ff92SAndroid Build Coastguard Worker using INetworkPtr = std::unique_ptr<INetwork, void(*)(INetwork* network)>; 328*89c4ff92SAndroid Build Coastguard Worker using IOptimizedNetworkPtr = std::unique_ptr<IOptimizedNetwork, void(*)(IOptimizedNetwork* network)>; 329*89c4ff92SAndroid Build Coastguard Worker 330*89c4ff92SAndroid Build Coastguard Worker using CompiledBlobDeleter = std::function<void(const void*)>; 331*89c4ff92SAndroid Build Coastguard Worker using CompiledBlobPtr = std::unique_ptr<void, CompiledBlobDeleter>; 332*89c4ff92SAndroid Build Coastguard Worker 333*89c4ff92SAndroid Build Coastguard Worker /// Main network class which provides the interface for building up a neural network. 334*89c4ff92SAndroid Build Coastguard Worker /// This object is subsequently required by the IRuntime::Load() method. 
335*89c4ff92SAndroid Build Coastguard Worker class INetwork 336*89c4ff92SAndroid Build Coastguard Worker { 337*89c4ff92SAndroid Build Coastguard Worker public: 338*89c4ff92SAndroid Build Coastguard Worker static INetwork* CreateRaw(const NetworkOptions& networkOptions = {}); 339*89c4ff92SAndroid Build Coastguard Worker static INetworkPtr Create(const NetworkOptions& networkOptions = {}); 340*89c4ff92SAndroid Build Coastguard Worker static void Destroy(INetwork* network); 341*89c4ff92SAndroid Build Coastguard Worker 342*89c4ff92SAndroid Build Coastguard Worker Status PrintGraph(); 343*89c4ff92SAndroid Build Coastguard Worker 344*89c4ff92SAndroid Build Coastguard Worker /// Adds an input layer to the network. 345*89c4ff92SAndroid Build Coastguard Worker /// @param id - User generated id to uniquely identify a particular input. The same id needs to be specified. 346*89c4ff92SAndroid Build Coastguard Worker /// when passing the inputs to the IRuntime::EnqueueWorkload() function. 347*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 348*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 349*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr); 350*89c4ff92SAndroid Build Coastguard Worker 351*89c4ff92SAndroid Build Coastguard Worker /// Adds an ArgMinMax layer to the network. 352*89c4ff92SAndroid Build Coastguard Worker /// @param desc - Parameters for the L2 normalization operation. 353*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 354*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 
355*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc, 356*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 357*89c4ff92SAndroid Build Coastguard Worker 358*89c4ff92SAndroid Build Coastguard Worker /// Adds a cast layer to the network. 359*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 360*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 361*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddCastLayer(const char* name = nullptr); 362*89c4ff92SAndroid Build Coastguard Worker 363*89c4ff92SAndroid Build Coastguard Worker /// Add a Comparison layer to the network. 364*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 365*89c4ff92SAndroid Build Coastguard Worker /// @param desc - Descriptor for the comparison operation. 366*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 367*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, 368*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 369*89c4ff92SAndroid Build Coastguard Worker 370*89c4ff92SAndroid Build Coastguard Worker /// Adds a concatenation layer to the network. 371*89c4ff92SAndroid Build Coastguard Worker /// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation 372*89c4ff92SAndroid Build Coastguard Worker /// process. Number of Views must be equal to the number of inputs, and their order 373*89c4ff92SAndroid Build Coastguard Worker /// must match - e.g. first view corresponds to the first input, second view to the 374*89c4ff92SAndroid Build Coastguard Worker /// second input, etc.... 375*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 
376*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 377*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor, 378*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 379*89c4ff92SAndroid Build Coastguard Worker 380*89c4ff92SAndroid Build Coastguard Worker /// Adds a 2D convolution layer to the network. 381*89c4ff92SAndroid Build Coastguard Worker /// @param convolution2dDescriptor - Description of the 2D convolution layer. 382*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 383*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 384*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor, 385*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 386*89c4ff92SAndroid Build Coastguard Worker 387*89c4ff92SAndroid Build Coastguard Worker /// Adds a 3D convolution layer to the network. 388*89c4ff92SAndroid Build Coastguard Worker /// @param convolution3dDescriptor - Description of the 3D convolution layer. 389*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 390*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 391*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddConvolution3dLayer(const Convolution3dDescriptor& convolution3dDescriptor, 392*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 393*89c4ff92SAndroid Build Coastguard Worker 394*89c4ff92SAndroid Build Coastguard Worker /// Adds a depth to space layer to the network. 395*89c4ff92SAndroid Build Coastguard Worker /// @param depthToSpaceDescriptor - Parameters for the depth to space operation. 396*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 
397*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 398*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddDepthToSpaceLayer(const DepthToSpaceDescriptor& depthToSpaceDescriptor, 399*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 400*89c4ff92SAndroid Build Coastguard Worker 401*89c4ff92SAndroid Build Coastguard Worker /// Adds a 2D depthwise convolution layer to the network. 402*89c4ff92SAndroid Build Coastguard Worker /// @param convolution2dDescriptor - Description of the 2D depthwise convolution layer. 403*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 404*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 405*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddDepthwiseConvolution2dLayer(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, 406*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 407*89c4ff92SAndroid Build Coastguard Worker 408*89c4ff92SAndroid Build Coastguard Worker /// Adds a Dequantize layer to the network; @a name is an optional name for the layer. 409*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 410*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddDequantizeLayer(const char* name = nullptr); 411*89c4ff92SAndroid Build Coastguard Worker 412*89c4ff92SAndroid Build Coastguard Worker /// Adds a Detection PostProcess layer to the network. 413*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Description of the Detection PostProcess layer. 414*89c4ff92SAndroid Build Coastguard Worker /// @param anchors - Tensor for anchors. 415*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 416*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
417*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddDetectionPostProcessLayer( 418*89c4ff92SAndroid Build Coastguard Worker const DetectionPostProcessDescriptor& descriptor, 419*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& anchors, 420*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 421*89c4ff92SAndroid Build Coastguard Worker 422*89c4ff92SAndroid Build Coastguard Worker /// Adds an ElementwiseBinary layer to the network. 423*89c4ff92SAndroid Build Coastguard Worker /// @param elementwiseUnaryDescriptor - Descriptor for the elementwiseBinary operations. NOTE(review): the parameter name looks copied from the unary overload - confirm before renaming. 424*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 425*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 426*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddElementwiseBinaryLayer(const ElementwiseBinaryDescriptor& elementwiseUnaryDescriptor, 427*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 428*89c4ff92SAndroid Build Coastguard Worker 429*89c4ff92SAndroid Build Coastguard Worker /// Adds an ElementwiseUnary layer to the network. 430*89c4ff92SAndroid Build Coastguard Worker /// @param elementwiseUnaryDescriptor - Descriptor for the elementwiseUnary operations. 431*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 432*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 433*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor, 434*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 435*89c4ff92SAndroid Build Coastguard Worker 436*89c4ff92SAndroid Build Coastguard Worker /// Adds a Fill layer to the network. 437*89c4ff92SAndroid Build Coastguard Worker /// @param fillDescriptor - Descriptor for the fill operation. 438*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer.
439*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 440*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddFillLayer(const FillDescriptor& fillDescriptor, 441*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 442*89c4ff92SAndroid Build Coastguard Worker 443*89c4ff92SAndroid Build Coastguard Worker 444*89c4ff92SAndroid Build Coastguard Worker /// Adds a fully connected layer to the network; @a name is an optional name for the layer. 445*89c4ff92SAndroid Build Coastguard Worker /// @param fullyConnectedDescriptor - Description of the fully connected layer. 446*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 447*89c4ff92SAndroid Build Coastguard Worker /// 448*89c4ff92SAndroid Build Coastguard Worker /// @note Weights and biases are passed in as inputs. If they are constant tensors you can simply store 449*89c4ff92SAndroid Build Coastguard Worker /// them in a ConstantLayer as seen below. A full example can be found in samples/SimpleSample.cpp. 450*89c4ff92SAndroid Build Coastguard Worker /// 451*89c4ff92SAndroid Build Coastguard Worker /// @code 452*89c4ff92SAndroid Build Coastguard Worker /// // Make sure the IsConstant flag is set on the weightsInfo before passing it to the ConstTensor.
453*89c4ff92SAndroid Build Coastguard Worker /// ConstTensor weights(weightsInfo, weightsData); 454*89c4ff92SAndroid Build Coastguard Worker /// 455*89c4ff92SAndroid Build Coastguard Worker /// // Constant layer that now holds weights data for FullyConnected 456*89c4ff92SAndroid Build Coastguard Worker /// IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "weights"); 457*89c4ff92SAndroid Build Coastguard Worker /// 458*89c4ff92SAndroid Build Coastguard Worker /// FullyConnectedDescriptor fullyConnectedDesc; 459*89c4ff92SAndroid Build Coastguard Worker /// IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc, 460*89c4ff92SAndroid Build Coastguard Worker /// "fully connected"); 461*89c4ff92SAndroid Build Coastguard Worker /// IConnectableLayer* InputLayer = myNetwork->AddInputLayer(0); 462*89c4ff92SAndroid Build Coastguard Worker /// InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0)); 463*89c4ff92SAndroid Build Coastguard Worker /// constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1)); 464*89c4ff92SAndroid Build Coastguard Worker /// @endcode 465*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddFullyConnectedLayer(const FullyConnectedDescriptor& fullyConnectedDescriptor, 466*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 467*89c4ff92SAndroid Build Coastguard Worker 468*89c4ff92SAndroid Build Coastguard Worker /// Adds a permute layer to the network. 469*89c4ff92SAndroid Build Coastguard Worker /// @param permuteDescriptor - PermuteDescriptor to configure the permute. 470*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 471*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
472*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPermuteLayer(const PermuteDescriptor& permuteDescriptor, 473*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 474*89c4ff92SAndroid Build Coastguard Worker 475*89c4ff92SAndroid Build Coastguard Worker /// Adds a batch to space ND layer to the network. 476*89c4ff92SAndroid Build Coastguard Worker /// @param batchToSpaceNdDescriptor - Description of the layer. 477*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 478*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 479*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, 480*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 481*89c4ff92SAndroid Build Coastguard Worker 482*89c4ff92SAndroid Build Coastguard Worker /// Adds a 2D pooling layer to the network. 483*89c4ff92SAndroid Build Coastguard Worker /// @param pooling2dDescriptor - Pooling2dDescriptor to configure the pooling. 484*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 485*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 486*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPooling2dLayer(const Pooling2dDescriptor& pooling2dDescriptor, 487*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 488*89c4ff92SAndroid Build Coastguard Worker 489*89c4ff92SAndroid Build Coastguard Worker /// Adds a 3D pooling layer to the network. 490*89c4ff92SAndroid Build Coastguard Worker /// @param pooling3dDescriptor - Pooling3dDescriptor to configure the pooling. 491*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 492*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
493*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPooling3dLayer(const Pooling3dDescriptor& pooling3dDescriptor, 494*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 495*89c4ff92SAndroid Build Coastguard Worker 496*89c4ff92SAndroid Build Coastguard Worker /// Adds a Precompiled layer to the network. 497*89c4ff92SAndroid Build Coastguard Worker /// This method is intended for backend users. @a name is an optional name for the layer. 498*89c4ff92SAndroid Build Coastguard Worker /// @param preCompiledDescriptor - PreCompiledDescriptor contains parameters for the Precompiled layer. 499*89c4ff92SAndroid Build Coastguard Worker /// @param compiledBlobPtr - CompiledBlobPtr pre-compiled object set for the Precompiled layer. 500*89c4ff92SAndroid Build Coastguard Worker /// @param backend - optional BackendId set for the Precompiled layer. 501*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 502*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPrecompiledLayer(const PreCompiledDescriptor& preCompiledDescriptor, 503*89c4ff92SAndroid Build Coastguard Worker CompiledBlobPtr compiledBlobPtr, 504*89c4ff92SAndroid Build Coastguard Worker const Optional<BackendId>& backend, 505*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 506*89c4ff92SAndroid Build Coastguard Worker 507*89c4ff92SAndroid Build Coastguard Worker /// Adds an activation layer to the network. 508*89c4ff92SAndroid Build Coastguard Worker /// @param activationDescriptor - ActivationDescriptor to configure the activation. 509*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 510*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
511*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddActivationLayer(const ActivationDescriptor& activationDescriptor, 512*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 513*89c4ff92SAndroid Build Coastguard Worker 514*89c4ff92SAndroid Build Coastguard Worker /// Adds a normalization layer to the network. 515*89c4ff92SAndroid Build Coastguard Worker /// @param normalizationDescriptor - NormalizationDescriptor to configure the normalization. 516*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 517*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 518*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddNormalizationLayer(const NormalizationDescriptor& normalizationDescriptor, 519*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 520*89c4ff92SAndroid Build Coastguard Worker 521*89c4ff92SAndroid Build Coastguard Worker /// Adds a slice layer to the network. 522*89c4ff92SAndroid Build Coastguard Worker /// @param sliceDescriptor - SliceDescriptor to configure the slice operation. 523*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 524*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 525*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSliceLayer(const SliceDescriptor& sliceDescriptor, const char* name = nullptr); 526*89c4ff92SAndroid Build Coastguard Worker 527*89c4ff92SAndroid Build Coastguard Worker /// Adds a softmax layer to the network. 528*89c4ff92SAndroid Build Coastguard Worker /// If the data type is QAsymm8, then the output quantization parameters 529*89c4ff92SAndroid Build Coastguard Worker /// must have a scale of 1/256 and an offset of 0. 530*89c4ff92SAndroid Build Coastguard Worker /// @param softmaxDescriptor - SoftmaxDescriptor to configure the softmax.
531*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 532*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 533*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSoftmaxLayer(const SoftmaxDescriptor& softmaxDescriptor, 534*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 535*89c4ff92SAndroid Build Coastguard Worker 536*89c4ff92SAndroid Build Coastguard Worker /// Adds a splitter layer to the network. 537*89c4ff92SAndroid Build Coastguard Worker /// @param splitterDescriptor - ViewsDescriptor to configure the splitting process. 538*89c4ff92SAndroid Build Coastguard Worker /// Number of Views must be equal to the number of outputs, 539*89c4ff92SAndroid Build Coastguard Worker /// and their order must match - e.g. first view corresponds to 540*89c4ff92SAndroid Build Coastguard Worker /// the first output, second view to the second output, etc.... 541*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 542*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 543*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor, 544*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 545*89c4ff92SAndroid Build Coastguard Worker 546*89c4ff92SAndroid Build Coastguard Worker /// Adds a merge layer to the network. 547*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 548*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 549*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddMergeLayer(const char* name = nullptr); 550*89c4ff92SAndroid Build Coastguard Worker 551*89c4ff92SAndroid Build Coastguard Worker /// Adds an addition layer to the network.
552*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 553*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 554*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 555*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddAdditionLayer(const char* name = nullptr); 556*89c4ff92SAndroid Build Coastguard Worker 557*89c4ff92SAndroid Build Coastguard Worker /// Adds a multiplication layer to the network. 558*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 559*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 560*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 561*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddMultiplicationLayer(const char* name = nullptr); 562*89c4ff92SAndroid Build Coastguard Worker 563*89c4ff92SAndroid Build Coastguard Worker /// Adds a batch normalization layer to the network, described by @a desc. 564*89c4ff92SAndroid Build Coastguard Worker /// @param mean - Pre-calculated mean for each channel. 565*89c4ff92SAndroid Build Coastguard Worker /// @param variance - Pre-calculated variance for each channel. 566*89c4ff92SAndroid Build Coastguard Worker /// @param beta - Per-channel additive factor. 567*89c4ff92SAndroid Build Coastguard Worker /// @param gamma - Per-channel multiplicative factor. 568*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 569*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
570*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc, 571*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& mean, 572*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& variance, 573*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& beta, 574*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& gamma, 575*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 576*89c4ff92SAndroid Build Coastguard Worker 577*89c4ff92SAndroid Build Coastguard Worker /// Adds a rank layer to the network. 578*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 579*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 580*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddRankLayer(const char* name = nullptr); 581*89c4ff92SAndroid Build Coastguard Worker 582*89c4ff92SAndroid Build Coastguard Worker /// Adds a resize layer to the network. 583*89c4ff92SAndroid Build Coastguard Worker /// @param resizeDescriptor - Parameters for the resize operation. 584*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 585*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 586*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddResizeLayer(const ResizeDescriptor& resizeDescriptor, 587*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 588*89c4ff92SAndroid Build Coastguard Worker 589*89c4ff92SAndroid Build Coastguard Worker /// Adds a reduce layer to the network. 590*89c4ff92SAndroid Build Coastguard Worker /// @param reduceDescriptor - Parameters for the reduce operation. 591*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 592*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
593*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddReduceLayer(const ReduceDescriptor& reduceDescriptor, 594*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 595*89c4ff92SAndroid Build Coastguard Worker 596*89c4ff92SAndroid Build Coastguard Worker /// Adds an instance normalization layer to the network. 597*89c4ff92SAndroid Build Coastguard Worker /// @param desc - Parameters for the instance normalization operation. 598*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 599*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 600*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddInstanceNormalizationLayer(const InstanceNormalizationDescriptor& desc, 601*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 602*89c4ff92SAndroid Build Coastguard Worker 603*89c4ff92SAndroid Build Coastguard Worker /// Adds an L2 normalization layer to the network. 604*89c4ff92SAndroid Build Coastguard Worker /// Normalization is performed along dimension 1, but requires a 4d input. 605*89c4ff92SAndroid Build Coastguard Worker /// @param desc - Parameters for the L2 normalization operation. 606*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 607*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 608*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddL2NormalizationLayer(const L2NormalizationDescriptor& desc, 609*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 610*89c4ff92SAndroid Build Coastguard Worker 611*89c4ff92SAndroid Build Coastguard Worker /// Adds a log softmax layer to the network. 612*89c4ff92SAndroid Build Coastguard Worker /// @param logSoftmaxDescriptor - LogSoftmaxDescriptor to configure the log softmax. 613*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer.
614*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 615*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddLogSoftmaxLayer(const LogSoftmaxDescriptor& logSoftmaxDescriptor, 616*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 617*89c4ff92SAndroid Build Coastguard Worker 618*89c4ff92SAndroid Build Coastguard Worker /// Adds a layer with no inputs and a single output, which always corresponds to 619*89c4ff92SAndroid Build Coastguard Worker /// the passed in constant tensor. 620*89c4ff92SAndroid Build Coastguard Worker /// @param input - Tensor to be provided as the only output of the layer. The layer will maintain 621*89c4ff92SAndroid Build Coastguard Worker /// its own copy of the tensor data, meaning the memory referenced by @a input can 622*89c4ff92SAndroid Build Coastguard Worker /// be freed or reused after this function is called. 623*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 624*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 625*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddConstantLayer(const ConstTensor& input, 626*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 627*89c4ff92SAndroid Build Coastguard Worker 628*89c4ff92SAndroid Build Coastguard Worker /// Adds a reshape layer to the network. 629*89c4ff92SAndroid Build Coastguard Worker /// @param reshapeDescriptor - Parameters for the reshape operation. 630*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 631*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
632*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddReshapeLayer(const ReshapeDescriptor& reshapeDescriptor, 633*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 634*89c4ff92SAndroid Build Coastguard Worker 635*89c4ff92SAndroid Build Coastguard Worker /// Adds a shape layer to the network. 636*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 637*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 638*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddShapeLayer(const char* name = nullptr); 639*89c4ff92SAndroid Build Coastguard Worker 640*89c4ff92SAndroid Build Coastguard Worker /// Adds a space to batch layer to the network. 641*89c4ff92SAndroid Build Coastguard Worker /// @param spaceToBatchNdDescriptor - Parameters for the space to batch operation. 642*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 643*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 644*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSpaceToBatchNdLayer(const SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor, 645*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 646*89c4ff92SAndroid Build Coastguard Worker 647*89c4ff92SAndroid Build Coastguard Worker /// Adds a space to depth layer to the network. 648*89c4ff92SAndroid Build Coastguard Worker /// @param spaceToDepthDescriptor - Parameters for the space to depth operation. 649*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 650*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
651*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSpaceToDepthLayer(const SpaceToDepthDescriptor& spaceToDepthDescriptor, 652*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 653*89c4ff92SAndroid Build Coastguard Worker 654*89c4ff92SAndroid Build Coastguard Worker /// Adds a floor layer to the network. 655*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 656*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 657*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddFloorLayer(const char* name = nullptr); 658*89c4ff92SAndroid Build Coastguard Worker 659*89c4ff92SAndroid Build Coastguard Worker /// Adds an output layer to the network. 660*89c4ff92SAndroid Build Coastguard Worker /// @param id - User generated id to uniquely identify a particular output. The same id needs to be specified 661*89c4ff92SAndroid Build Coastguard Worker /// when passing the outputs to the IRuntime::EnqueueWorkload() function. 662*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 663*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 664*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr); 665*89c4ff92SAndroid Build Coastguard Worker 666*89c4ff92SAndroid Build Coastguard Worker /// Adds an Lstm layer to the network. 667*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Parameters for the Lstm operation. 668*89c4ff92SAndroid Build Coastguard Worker /// @param params - Weights and biases for the LSTM cell. 669*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 670*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
671*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor, 672*89c4ff92SAndroid Build Coastguard Worker const LstmInputParams& params, 673*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 674*89c4ff92SAndroid Build Coastguard Worker 675*89c4ff92SAndroid Build Coastguard Worker /// Adds a division layer to the network. 676*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 677*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 678*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 679*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddDivisionLayer(const char* name = nullptr); 680*89c4ff92SAndroid Build Coastguard Worker 681*89c4ff92SAndroid Build Coastguard Worker /// Adds a subtraction layer to the network. 682*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 683*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 684*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 685*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSubtractionLayer(const char* name = nullptr); 686*89c4ff92SAndroid Build Coastguard Worker 687*89c4ff92SAndroid Build Coastguard Worker /// Adds a Maximum layer to the network. 688*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 689*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
690*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 691*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddMaximumLayer(const char* name = nullptr); 692*89c4ff92SAndroid Build Coastguard Worker 693*89c4ff92SAndroid Build Coastguard Worker /// Adds a Mean layer to the network. 694*89c4ff92SAndroid Build Coastguard Worker /// @param meanDescriptor - Parameters for the mean operation. 695*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 696*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 697*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddMeanLayer(const MeanDescriptor& meanDescriptor, const char* name = nullptr); 698*89c4ff92SAndroid Build Coastguard Worker 699*89c4ff92SAndroid Build Coastguard Worker /// Adds a pad layer to the network. 700*89c4ff92SAndroid Build Coastguard Worker /// @param padDescriptor - Descriptor for the pad operation, presumably carrying the paddings: an n by 2 tensor, where n is the rank of the input tensor, 701*89c4ff92SAndroid Build Coastguard Worker /// such that paddings[i,0] indicates the amount of padding to add in front of dimension i, and 702*89c4ff92SAndroid Build Coastguard Worker /// paddings[i,1] indicates the amount of padding to add after the end of dimension i. 703*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 704*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 705*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPadLayer(const PadDescriptor& padDescriptor, 706*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 707*89c4ff92SAndroid Build Coastguard Worker 708*89c4ff92SAndroid Build Coastguard Worker /// Adds a quantize layer to the network. 709*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer.
710*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 711*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddQuantizeLayer(const char* name = nullptr); 712*89c4ff92SAndroid Build Coastguard Worker 713*89c4ff92SAndroid Build Coastguard Worker /// Adds a strided slice layer to the network. 714*89c4ff92SAndroid Build Coastguard Worker /// @param stridedSliceDescriptor - Parameters for the strided slice operation. 715*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 716*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 717*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddStridedSliceLayer(const StridedSliceDescriptor& stridedSliceDescriptor, 718*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 719*89c4ff92SAndroid Build Coastguard Worker 720*89c4ff92SAndroid Build Coastguard Worker /// Adds a Minimum layer to the network. 721*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 722*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 723*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE("Use AddElementwiseBinaryLayer instead", "24.02") 724*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddMinimumLayer(const char* name = nullptr); 725*89c4ff92SAndroid Build Coastguard Worker 726*89c4ff92SAndroid Build Coastguard Worker /// Adds a Gather layer to the network. 727*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Description of the gather layer. 728*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 729*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer.
730*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor, 731*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 732*89c4ff92SAndroid Build Coastguard Worker 733*89c4ff92SAndroid Build Coastguard Worker /// Adds a GatherNd layer to the network. 734*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 735*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 736*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddGatherNdLayer(const char* name = nullptr); 737*89c4ff92SAndroid Build Coastguard Worker 738*89c4ff92SAndroid Build Coastguard Worker /// Adds a switch layer to the network. 739*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 740*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 741*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddSwitchLayer(const char* name = nullptr); 742*89c4ff92SAndroid Build Coastguard Worker 743*89c4ff92SAndroid Build Coastguard Worker /// Adds a PReLU layer to the network. 744*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 745*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 746*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddPreluLayer(const char* name = nullptr); 747*89c4ff92SAndroid Build Coastguard Worker 748*89c4ff92SAndroid Build Coastguard Worker /// Adds a 2D transpose convolution layer to the network. 749*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Description of the 2D transpose convolution layer. 750*89c4ff92SAndroid Build Coastguard Worker /// @param weights - Tensor for the weights data. 751*89c4ff92SAndroid Build Coastguard Worker /// @param biases - Optional tensor for the bias data.
752*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 753*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 754*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddTransposeConvolution2dLayer(const TransposeConvolution2dDescriptor& descriptor, 755*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& weights, 756*89c4ff92SAndroid Build Coastguard Worker const Optional<ConstTensor>& biases, 757*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 758*89c4ff92SAndroid Build Coastguard Worker 759*89c4ff92SAndroid Build Coastguard Worker /// Adds a transpose layer to the network. 760*89c4ff92SAndroid Build Coastguard Worker /// @param transposeDescriptor - TransposeDescriptor to configure the transpose. 761*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 762*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 763*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddTransposeLayer(const TransposeDescriptor& transposeDescriptor, 764*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 765*89c4ff92SAndroid Build Coastguard Worker 766*89c4ff92SAndroid Build Coastguard Worker /// Adds a stack layer to the network. 767*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Description of the stack layer. 768*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 769*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 770*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddStackLayer(const StackDescriptor& descriptor, 771*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 772*89c4ff92SAndroid Build Coastguard Worker 773*89c4ff92SAndroid Build Coastguard Worker /// Add a stand-in layer for a type unknown to the Arm NN framework. 
774*89c4ff92SAndroid Build Coastguard Worker /// Note: Due to the nature of this layer, no validation can be performed by the framework. 775*89c4ff92SAndroid Build Coastguard Worker /// Furthermore, Any model containing this layer cannot make use of dynamic tensors since the 776*89c4ff92SAndroid Build Coastguard Worker /// tensor sizes cannot be inferred. 777*89c4ff92SAndroid Build Coastguard Worker /// @descriptor - Descriptor for the StandIn layer. 778*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 779*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddStandInLayer(const StandInDescriptor& descriptor, 780*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 781*89c4ff92SAndroid Build Coastguard Worker 782*89c4ff92SAndroid Build Coastguard Worker /// Add a QuantizedLstm layer to the network 783*89c4ff92SAndroid Build Coastguard Worker /// @param params - The weights and biases for the Quantized LSTM cell 784*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer 785*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 786*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params, 787*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 788*89c4ff92SAndroid Build Coastguard Worker 789*89c4ff92SAndroid Build Coastguard Worker /// Add a QLstm layer to the network 790*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Parameters for the QLstm operation 791*89c4ff92SAndroid Build Coastguard Worker /// @param params - Weights and biases for the layer 792*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer 793*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 
794*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor, 795*89c4ff92SAndroid Build Coastguard Worker const LstmInputParams& params, 796*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 797*89c4ff92SAndroid Build Coastguard Worker 798*89c4ff92SAndroid Build Coastguard Worker /// Adds a Logical Binary layer to the network. 799*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Description of the Logical Binary layer. 800*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer. 801*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 802*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& descriptor, 803*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 804*89c4ff92SAndroid Build Coastguard Worker 805*89c4ff92SAndroid Build Coastguard Worker /// Add a UnidirectionalSequenceLstm layer to the network 806*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Parameters for the UnidirectionalSequenceLstm operation 807*89c4ff92SAndroid Build Coastguard Worker /// @param params - Weights and biases for the UnidirectionalSequenceLstm 808*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer 809*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer. 
810*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor, 811*89c4ff92SAndroid Build Coastguard Worker const LstmInputParams& params, 812*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 813*89c4ff92SAndroid Build Coastguard Worker 814*89c4ff92SAndroid Build Coastguard Worker /// Add a ChannelShuffle layer to the network 815*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Parameters for the ChannelShuffle operation 816*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer 817*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer 818*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor, 819*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 820*89c4ff92SAndroid Build Coastguard Worker 821*89c4ff92SAndroid Build Coastguard Worker /// Add a BatchMatMul layer to the network 822*89c4ff92SAndroid Build Coastguard Worker /// @param descriptor - Parameters for the BatchMatMul operation 823*89c4ff92SAndroid Build Coastguard Worker /// @param name - Optional name for the layer 824*89c4ff92SAndroid Build Coastguard Worker /// @return - Interface for configuring the layer 825*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& descriptor, 826*89c4ff92SAndroid Build Coastguard Worker const char* name = nullptr); 827*89c4ff92SAndroid Build Coastguard Worker 828*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(IStrategy& strategy) const; 829*89c4ff92SAndroid Build Coastguard Worker 830*89c4ff92SAndroid Build Coastguard Worker protected: 831*89c4ff92SAndroid Build Coastguard Worker ~INetwork(); 832*89c4ff92SAndroid Build Coastguard Worker 833*89c4ff92SAndroid Build Coastguard Worker friend void 
VisitLayersTopologically(const INetwork* inputNetwork, IStrategy& strategy); 834*89c4ff92SAndroid Build Coastguard Worker friend class TestConnectionPreservation; 835*89c4ff92SAndroid Build Coastguard Worker friend TensorInfo GetInputTensorInfo(const INetwork* network); 836*89c4ff92SAndroid Build Coastguard Worker friend IOptimizedNetworkPtr Optimize(const INetwork& network, 837*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 838*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 839*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptions& options, 840*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages); 841*89c4ff92SAndroid Build Coastguard Worker friend IOptimizedNetworkPtr Optimize(const INetwork& network, 842*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 843*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 844*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptionsOpaque& options, 845*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages); 846*89c4ff92SAndroid Build Coastguard Worker 847*89c4ff92SAndroid Build Coastguard Worker INetwork(NetworkOptions networkOptions = {}); 848*89c4ff92SAndroid Build Coastguard Worker 849*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<NetworkImpl> pNetworkImpl; 850*89c4ff92SAndroid Build Coastguard Worker }; 851*89c4ff92SAndroid Build Coastguard Worker 852*89c4ff92SAndroid Build Coastguard Worker namespace experimental 853*89c4ff92SAndroid Build Coastguard Worker { 854*89c4ff92SAndroid Build Coastguard Worker class AsyncNetworkImpl; 855*89c4ff92SAndroid Build Coastguard Worker class WorkingMemHandle; 856*89c4ff92SAndroid Build Coastguard Worker } 857*89c4ff92SAndroid Build Coastguard Worker 858*89c4ff92SAndroid Build Coastguard Worker struct BackendSettings; 859*89c4ff92SAndroid Build Coastguard Worker 
struct OptimizationResult; 860*89c4ff92SAndroid Build Coastguard Worker class OptimizedNetworkImpl; 861*89c4ff92SAndroid Build Coastguard Worker class IProfiler; 862*89c4ff92SAndroid Build Coastguard Worker class IOptimizedNetwork 863*89c4ff92SAndroid Build Coastguard Worker { 864*89c4ff92SAndroid Build Coastguard Worker public: 865*89c4ff92SAndroid Build Coastguard Worker static void Destroy(IOptimizedNetwork* network); 866*89c4ff92SAndroid Build Coastguard Worker 867*89c4ff92SAndroid Build Coastguard Worker Status PrintGraph(); 868*89c4ff92SAndroid Build Coastguard Worker Status SerializeToDot(std::ostream& stream) const; 869*89c4ff92SAndroid Build Coastguard Worker 870*89c4ff92SAndroid Build Coastguard Worker arm::pipe::ProfilingGuid GetGuid() const; 871*89c4ff92SAndroid Build Coastguard Worker 872*89c4ff92SAndroid Build Coastguard Worker size_t GetNumInputs() const; 873*89c4ff92SAndroid Build Coastguard Worker size_t GetNumOutputs() const; 874*89c4ff92SAndroid Build Coastguard Worker 875*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(IStrategy& strategy) const; 876*89c4ff92SAndroid Build Coastguard Worker 877*89c4ff92SAndroid Build Coastguard Worker /// Creates a copy of the IOptimizedNetwork. The IOptimizedNetwork will not be reoptimized, 878*89c4ff92SAndroid Build Coastguard Worker /// the provided ModelOptions will only be used when creating a LoadedNetwork. 
879*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetwork(const IOptimizedNetwork& other, const ModelOptions& modelOptions); 880*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetwork(std::unique_ptr<Graph> graph); 881*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetwork(std::unique_ptr<OptimizedNetworkImpl> impl); 882*89c4ff92SAndroid Build Coastguard Worker ~IOptimizedNetwork(); 883*89c4ff92SAndroid Build Coastguard Worker 884*89c4ff92SAndroid Build Coastguard Worker const std::shared_ptr<IProfiler>& GetProfiler() const; 885*89c4ff92SAndroid Build Coastguard Worker 886*89c4ff92SAndroid Build Coastguard Worker protected: 887*89c4ff92SAndroid Build Coastguard Worker friend class LoadedNetwork; 888*89c4ff92SAndroid Build Coastguard Worker 889*89c4ff92SAndroid Build Coastguard Worker friend class experimental::AsyncNetworkImpl; 890*89c4ff92SAndroid Build Coastguard Worker friend class experimental::WorkingMemHandle; 891*89c4ff92SAndroid Build Coastguard Worker 892*89c4ff92SAndroid Build Coastguard Worker friend Graph& GetGraphForTesting(IOptimizedNetwork* optNetPtr); 893*89c4ff92SAndroid Build Coastguard Worker friend ModelOptions& GetModelOptionsForTesting(IOptimizedNetwork* optNetPtr); 894*89c4ff92SAndroid Build Coastguard Worker friend IOptimizedNetworkPtr Optimize(const INetwork& inNetwork, 895*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 896*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 897*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptionsOpaque& options, 898*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages); 899*89c4ff92SAndroid Build Coastguard Worker friend IOptimizedNetworkPtr Optimize(const Graph& inGraph, 900*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 901*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 902*89c4ff92SAndroid Build Coastguard Worker 
const OptimizerOptionsOpaque& options, 903*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages); 904*89c4ff92SAndroid Build Coastguard Worker 905*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetwork(std::unique_ptr<Graph> graph, const ModelOptions& modelOptions); 906*89c4ff92SAndroid Build Coastguard Worker 907*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<OptimizedNetworkImpl> pOptimizedNetworkImpl; 908*89c4ff92SAndroid Build Coastguard Worker }; 909*89c4ff92SAndroid Build Coastguard Worker 910*89c4ff92SAndroid Build Coastguard Worker /// Create an optimized version of the network 911*89c4ff92SAndroid Build Coastguard Worker /// @param network INetwork description of the network to be optimized. 912*89c4ff92SAndroid Build Coastguard Worker /// @param backendPreferences The choice of the backend ordered by user preferences. 913*89c4ff92SAndroid Build Coastguard Worker /// @param deviceSpec DeviceSpec object as queried from the runtime. See IRuntime::GetDeviceSpec() 914*89c4ff92SAndroid Build Coastguard Worker /// @param messages If there are failures or warnings a string describing same will be added to the vector 915*89c4ff92SAndroid Build Coastguard Worker /// @param options OptimizerOptions object with optimizer configuration options 916*89c4ff92SAndroid Build Coastguard Worker /// @return An IOptimizedNetworkPtr interface to the optimized network, throws an exception derived from 917*89c4ff92SAndroid Build Coastguard Worker /// armnn::Exception if process fails. 
918*89c4ff92SAndroid Build Coastguard Worker 919*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const INetwork& network, 920*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 921*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 922*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptionsOpaque& options = OptimizerOptionsOpaque(), 923*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages = EmptyOptional()); 924*89c4ff92SAndroid Build Coastguard Worker 925*89c4ff92SAndroid Build Coastguard Worker /// Create an optimized version of the network 926*89c4ff92SAndroid Build Coastguard Worker /// @param inGraph Graph to be optimized. 927*89c4ff92SAndroid Build Coastguard Worker /// @param backendPreferences The choice of the backend ordered by user preferences. 928*89c4ff92SAndroid Build Coastguard Worker /// @param deviceSpec DeviceSpec object as queried from the runtime. See IRuntime::GetDeviceSpec() 929*89c4ff92SAndroid Build Coastguard Worker /// @param messages If there are failures or warnings a string describing same will be added to the vector 930*89c4ff92SAndroid Build Coastguard Worker /// @param options OptimizerOptions object with optimizer configuration options 931*89c4ff92SAndroid Build Coastguard Worker /// @return An IOptimizedNetworkPtr interface to the optimized network, throws an exception derived from 932*89c4ff92SAndroid Build Coastguard Worker /// armnn::Exception if process fails. 
933*89c4ff92SAndroid Build Coastguard Worker 934*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const Graph& inGraph, 935*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 936*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 937*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptionsOpaque& options, 938*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages = EmptyOptional()); 939*89c4ff92SAndroid Build Coastguard Worker 940*89c4ff92SAndroid Build Coastguard Worker /// Accept legacy OptimizerOptions 941*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const Graph& inGraph, 942*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 943*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 944*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptions& options, 945*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages = EmptyOptional()); 946*89c4ff92SAndroid Build Coastguard Worker 947*89c4ff92SAndroid Build Coastguard Worker /// Accept legacy OptimizerOptions 948*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr Optimize(const INetwork& network, 949*89c4ff92SAndroid Build Coastguard Worker const std::vector<BackendId>& backendPreferences, 950*89c4ff92SAndroid Build Coastguard Worker const IDeviceSpec& deviceSpec, 951*89c4ff92SAndroid Build Coastguard Worker const OptimizerOptions& options, 952*89c4ff92SAndroid Build Coastguard Worker Optional<std::vector<std::string>&> messages = EmptyOptional()); 953*89c4ff92SAndroid Build Coastguard Worker 954*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn 955