1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "LayerWithParameters.hpp" 8 9 namespace armnn 10 { 11 12 class ScopedTensorHandle; 13 14 /// This layer represents a batch normalization operation. 15 class BatchNormalizationLayer : public LayerWithParameters<BatchNormalizationDescriptor> 16 { 17 public: 18 /// A unique pointer to store Mean values 19 std::shared_ptr<ConstTensorHandle> m_Mean; 20 /// A unique pointer to store Variance values 21 std::shared_ptr<ConstTensorHandle> m_Variance; 22 /// A unique pointer to store Beta values 23 std::shared_ptr<ConstTensorHandle> m_Beta; 24 /// A unique pointer to store Gamma values 25 std::shared_ptr<ConstTensorHandle> m_Gamma; 26 27 /// Makes a workload for the BatchNormalization type. 28 /// @param [in] graph The graph where this layer can be found. 29 /// @param [in] factory The workload factory which will create the workload. 30 /// @return A pointer to the created workload, or nullptr if not created. 31 virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override; 32 33 /// Creates a dynamically-allocated copy of this layer. 34 /// @param [in] graph The graph into which this layer is being cloned. 35 BatchNormalizationLayer* Clone(Graph& graph) const override; 36 37 /// Check if the input tensor shape(s) 38 /// will lead to a valid configuration of @ref BatchNormalizationLayer. 39 /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated. 40 void ValidateTensorShapesFromInputs() override; 41 42 void ExecuteStrategy(IStrategy& strategy) const override; 43 44 protected: 45 /// Constructor to create a BatchNormalizationLayer. 46 /// @param [in] param BatchNormalizationDescriptor to configure the batch normalization operation. 47 /// @param [in] name Optional name for the layer. 48 BatchNormalizationLayer(const BatchNormalizationDescriptor& param, const char* name); 49 50 /// Default destructor 51 ~BatchNormalizationLayer() = default; 52 53 /// Retrieve the handles to the constant values stored by the layer. 54 /// @return A vector of the constant tensors stored by this layer. 55 ImmutableConstantTensors GetConstantTensorsByRef() const override; 56 }; 57 58 } // namespace 59