1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Tensor.hpp> 9 #include <armnn/Descriptors.hpp> 10 11 #include "ClBaseWorkload.hpp" 12 13 #include <arm_compute/runtime/CL/functions/CLConvolutionLayer.h> 14 #include <arm_compute/runtime/MemoryManagerOnDemand.h> 15 16 #include <cl/ICLTensorProxy.hpp> 17 18 #include <memory> 19 20 namespace armnn 21 { 22 23 arm_compute::Status ClConvolution2dWorkloadValidate(const TensorInfo& input, 24 const TensorInfo& output, 25 const Convolution2dDescriptor& descriptor, 26 const TensorInfo& weights, 27 const Optional<TensorInfo>& biases, 28 bool isFastMathEnabled = false, 29 const ActivationDescriptor* activationDescriptor = nullptr); 30 31 class ClConvolution2dWorkload : public ClBaseWorkload<Convolution2dQueueDescriptor> 32 { 33 public: 34 ClConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor, 35 const WorkloadInfo& info, 36 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager, 37 const arm_compute::CLCompileContext& clCompileContext, 38 const bool isFastMathEnabled = false); 39 void Execute() const override; 40 41 arm_compute::ConvolutionMethod GetConvolutionMethod() const; 42 SupportsTensorHandleReplacement() const43 bool SupportsTensorHandleReplacement() const override 44 { 45 // NCHW DataLayout on ACL still uses paddding for alignment on the Conv2d workload so importing is unreliable. 46 if (m_Data.m_Parameters.m_DataLayout == DataLayout::NCHW) 47 { 48 return false; 49 } 50 else 51 { 52 return true; 53 } 54 } 55 56 57 protected: 58 void Reconfigure() override; 59 60 private: 61 mutable arm_compute::CLConvolutionLayer m_ConvolutionLayer; 62 63 arm_compute::ConvolutionMethod m_ConvolutionMethod; 64 65 std::unique_ptr<ICLTensorProxy> m_InputProxy; 66 std::unique_ptr<ICLTensorProxy> m_WeightsProxy; 67 std::unique_ptr<ICLTensorProxy> m_BiasProxy; 68 std::unique_ptr<ICLTensorProxy> m_OutputProxy; 69 }; 70 71 } //namespace armnn 72 73