1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/TypesUtils.hpp> 9 10 #include "RefBaseWorkload.hpp" 11 #include <armnn/backends/WorkloadData.hpp> 12 13 #include "Encoders.hpp" 14 #include "Decoders.hpp" 15 16 namespace armnn 17 { 18 19 class RefUnidirectionalSequenceLstmWorkload : public RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor> 20 { 21 public: 22 explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, 23 const WorkloadInfo& info); 24 25 void Execute() const override; 26 void ExecuteAsync(ExecutionData& executionData) override; 27 28 29 private: 30 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; 31 std::unique_ptr<ScopedTensorHandle> m_InputToInputWeightsTensor; 32 std::unique_ptr<ScopedTensorHandle> m_InputToForgetWeightsTensor; 33 std::unique_ptr<ScopedTensorHandle> m_InputToCellWeightsTensor; 34 std::unique_ptr<ScopedTensorHandle> m_InputToOutputWeightsTensor; 35 std::unique_ptr<ScopedTensorHandle> m_RecurrentToInputWeightsTensor; 36 std::unique_ptr<ScopedTensorHandle> m_RecurrentToForgetWeightsTensor; 37 std::unique_ptr<ScopedTensorHandle> m_RecurrentToCellWeightsTensor; 38 std::unique_ptr<ScopedTensorHandle> m_RecurrentToOutputWeightsTensor; 39 std::unique_ptr<ScopedTensorHandle> m_CellToInputWeightsTensor; 40 std::unique_ptr<ScopedTensorHandle> m_CellToForgetWeightsTensor; 41 std::unique_ptr<ScopedTensorHandle> m_CellToOutputWeightsTensor; 42 std::unique_ptr<ScopedTensorHandle> m_InputGateBiasTensor; 43 std::unique_ptr<ScopedTensorHandle> m_ForgetGateBiasTensor; 44 std::unique_ptr<ScopedTensorHandle> m_CellBiasTensor; 45 std::unique_ptr<ScopedTensorHandle> m_OutputGateBiasTensor; 46 std::unique_ptr<ScopedTensorHandle> m_ProjectionWeightsTensor; 47 std::unique_ptr<ScopedTensorHandle> m_ProjectionBiasTensor; 48 std::unique_ptr<ScopedTensorHandle> m_InputLayerNormWeights; 49 std::unique_ptr<ScopedTensorHandle> m_ForgetLayerNormWeights; 50 std::unique_ptr<ScopedTensorHandle> m_CellLayerNormWeights; 51 std::unique_ptr<ScopedTensorHandle> m_OutputLayerNormWeights; 52 53 float m_LayerNormEpsilon = static_cast<float>(1e-8); 54 }; 55 56 } //namespace armnn 57