1 // 2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <aclCommon/ArmComputeTensorHandle.hpp> 8 #include <aclCommon/ArmComputeTensorUtils.hpp> 9 10 #include <Half.hpp> 11 12 #include <armnn/utility/PolymorphicDowncast.hpp> 13 14 #include <arm_compute/runtime/CL/CLTensor.h> 15 #include <arm_compute/runtime/CL/CLSubTensor.h> 16 #include <arm_compute/runtime/IMemoryGroup.h> 17 #include <arm_compute/runtime/MemoryGroup.h> 18 #include <arm_compute/core/TensorShape.h> 19 #include <arm_compute/core/Coordinates.h> 20 21 #include <aclCommon/IClTensorHandle.hpp> 22 23 namespace armnn 24 { 25 26 class ClTensorHandle : public IClTensorHandle 27 { 28 public: ClTensorHandle(const TensorInfo & tensorInfo)29 ClTensorHandle(const TensorInfo& tensorInfo) 30 : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), 31 m_Imported(false), 32 m_IsImportEnabled(false) 33 { 34 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 35 } 36 ClTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Undefined))37 ClTensorHandle(const TensorInfo& tensorInfo, 38 DataLayout dataLayout, 39 MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined)) 40 : m_ImportFlags(importFlags), 41 m_Imported(false), 42 m_IsImportEnabled(false) 43 { 44 armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 45 } 46 GetTensor()47 arm_compute::CLTensor& GetTensor() override { return m_Tensor; } GetTensor() const48 arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } Allocate()49 virtual void Allocate() override 50 { 51 // If we have enabled Importing, don't allocate the tensor 52 if (m_IsImportEnabled) 53 { 54 throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing"); 55 } 56 else 57 { 58 armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); 59 } 60 61 } 62 Manage()63 virtual void Manage() override 64 { 65 // If we have enabled Importing, don't manage the tensor 66 if (m_IsImportEnabled) 67 { 68 throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing"); 69 } 70 else 71 { 72 assert(m_MemoryGroup != nullptr); 73 m_MemoryGroup->manage(&m_Tensor); 74 } 75 } 76 Map(bool blocking=true) const77 virtual const void* Map(bool blocking = true) const override 78 { 79 const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking); 80 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 81 } 82 Unmap() const83 virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); } 84 GetParent() const85 virtual ITensorHandle* GetParent() const override { return nullptr; } 86 GetDataType() const87 virtual arm_compute::DataType GetDataType() const override 88 { 89 return m_Tensor.info()->data_type(); 90 } 91 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)92 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 93 { 94 m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 95 } 96 GetStrides() const97 TensorShape GetStrides() const override 98 { 99 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 100 } 101 GetShape() const102 TensorShape GetShape() const override 103 { 104 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 105 } 106 SetImportFlags(MemorySourceFlags importFlags)107 void SetImportFlags(MemorySourceFlags importFlags) 108 { 109 m_ImportFlags = importFlags; 110 } 111 GetImportFlags() const112 MemorySourceFlags GetImportFlags() const override 113 { 114 return m_ImportFlags; 115 } 116 SetImportEnabledFlag(bool importEnabledFlag)117 void SetImportEnabledFlag(bool importEnabledFlag) 118 { 119 m_IsImportEnabled = importEnabledFlag; 120 } 121 Import(void * memory,MemorySource source)122 virtual bool Import(void* memory, MemorySource source) override 123 { 124 armnn::IgnoreUnused(memory); 125 if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) 126 { 127 throw MemoryImportException("ClTensorHandle::Incorrect import flag"); 128 } 129 m_Imported = false; 130 return false; 131 } 132 CanBeImported(void * memory,MemorySource source)133 virtual bool CanBeImported(void* memory, MemorySource source) override 134 { 135 // This TensorHandle can never import. 136 armnn::IgnoreUnused(memory, source); 137 return false; 138 } 139 140 private: 141 // Only used for testing CopyOutTo(void * memory) const142 void CopyOutTo(void* memory) const override 143 { 144 const_cast<armnn::ClTensorHandle*>(this)->Map(true); 145 switch(this->GetDataType()) 146 { 147 case arm_compute::DataType::F32: 148 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 149 static_cast<float*>(memory)); 150 break; 151 case arm_compute::DataType::U8: 152 case arm_compute::DataType::QASYMM8: 153 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 154 static_cast<uint8_t*>(memory)); 155 break; 156 case arm_compute::DataType::QSYMM8: 157 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 158 case arm_compute::DataType::QASYMM8_SIGNED: 159 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 160 static_cast<int8_t*>(memory)); 161 break; 162 case arm_compute::DataType::F16: 163 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 164 static_cast<armnn::Half*>(memory)); 165 break; 166 case arm_compute::DataType::S16: 167 case arm_compute::DataType::QSYMM16: 168 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 169 static_cast<int16_t*>(memory)); 170 break; 171 case arm_compute::DataType::S32: 172 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 173 static_cast<int32_t*>(memory)); 174 break; 175 default: 176 { 177 throw armnn::UnimplementedException(); 178 } 179 } 180 const_cast<armnn::ClTensorHandle*>(this)->Unmap(); 181 } 182 183 // Only used for testing CopyInFrom(const void * memory)184 void CopyInFrom(const void* memory) override 185 { 186 this->Map(true); 187 switch(this->GetDataType()) 188 { 189 case arm_compute::DataType::F32: 190 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 191 this->GetTensor()); 192 break; 193 case arm_compute::DataType::U8: 194 case arm_compute::DataType::QASYMM8: 195 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 196 this->GetTensor()); 197 break; 198 case arm_compute::DataType::F16: 199 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 200 this->GetTensor()); 201 break; 202 case arm_compute::DataType::S16: 203 case arm_compute::DataType::QSYMM8: 204 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 205 case arm_compute::DataType::QASYMM8_SIGNED: 206 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 207 this->GetTensor()); 208 break; 209 case arm_compute::DataType::QSYMM16: 210 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 211 this->GetTensor()); 212 break; 213 case arm_compute::DataType::S32: 214 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 215 this->GetTensor()); 216 break; 217 default: 218 { 219 throw armnn::UnimplementedException(); 220 } 221 } 222 this->Unmap(); 223 } 224 225 arm_compute::CLTensor m_Tensor; 226 std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 227 MemorySourceFlags m_ImportFlags; 228 bool m_Imported; 229 bool m_IsImportEnabled; 230 }; 231 232 class ClSubTensorHandle : public IClTensorHandle 233 { 234 public: ClSubTensorHandle(IClTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)235 ClSubTensorHandle(IClTensorHandle* parent, 236 const arm_compute::TensorShape& shape, 237 const arm_compute::Coordinates& coords) 238 : m_Tensor(&parent->GetTensor(), shape, coords) 239 { 240 parentHandle = parent; 241 } 242 GetTensor()243 arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; } GetTensor() const244 arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; } 245 Allocate()246 virtual void Allocate() override {} Manage()247 virtual void Manage() override {} 248 Map(bool blocking=true) const249 virtual const void* Map(bool blocking = true) const override 250 { 251 const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking); 252 return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 253 } Unmap() const254 virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); } 255 GetParent() const256 virtual ITensorHandle* GetParent() const override { return parentHandle; } 257 GetDataType() const258 virtual arm_compute::DataType GetDataType() const override 259 { 260 return m_Tensor.info()->data_type(); 261 } 262 SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)263 virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 264 GetStrides() const265 TensorShape GetStrides() const override 266 { 267 return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 268 } 269 GetShape() const270 TensorShape GetShape() const override 271 { 272 return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 273 } 274 275 private: 276 // Only used for testing CopyOutTo(void * memory) const277 void CopyOutTo(void* memory) const override 278 { 279 const_cast<ClSubTensorHandle*>(this)->Map(true); 280 switch(this->GetDataType()) 281 { 282 case arm_compute::DataType::F32: 283 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 284 static_cast<float*>(memory)); 285 break; 286 case arm_compute::DataType::U8: 287 case arm_compute::DataType::QASYMM8: 288 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 289 static_cast<uint8_t*>(memory)); 290 break; 291 case arm_compute::DataType::F16: 292 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 293 static_cast<armnn::Half*>(memory)); 294 break; 295 case arm_compute::DataType::QSYMM8: 296 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 297 case arm_compute::DataType::QASYMM8_SIGNED: 298 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 299 static_cast<int8_t*>(memory)); 300 break; 301 case arm_compute::DataType::S16: 302 case arm_compute::DataType::QSYMM16: 303 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 304 static_cast<int16_t*>(memory)); 305 break; 306 case arm_compute::DataType::S32: 307 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 308 static_cast<int32_t*>(memory)); 309 break; 310 default: 311 { 312 throw armnn::UnimplementedException(); 313 } 314 } 315 const_cast<ClSubTensorHandle*>(this)->Unmap(); 316 } 317 318 // Only used for testing CopyInFrom(const void * memory)319 void CopyInFrom(const void* memory) override 320 { 321 this->Map(true); 322 switch(this->GetDataType()) 323 { 324 case arm_compute::DataType::F32: 325 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 326 this->GetTensor()); 327 break; 328 case arm_compute::DataType::U8: 329 case arm_compute::DataType::QASYMM8: 330 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 331 this->GetTensor()); 332 break; 333 case arm_compute::DataType::F16: 334 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 335 this->GetTensor()); 336 break; 337 case arm_compute::DataType::QSYMM8: 338 case arm_compute::DataType::QSYMM8_PER_CHANNEL: 339 case arm_compute::DataType::QASYMM8_SIGNED: 340 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 341 this->GetTensor()); 342 break; 343 case arm_compute::DataType::S16: 344 case arm_compute::DataType::QSYMM16: 345 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 346 this->GetTensor()); 347 break; 348 case arm_compute::DataType::S32: 349 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 350 this->GetTensor()); 351 break; 352 default: 353 { 354 throw armnn::UnimplementedException(); 355 } 356 } 357 this->Unmap(); 358 } 359 360 mutable arm_compute::CLSubTensor m_Tensor; 361 ITensorHandle* parentHandle = nullptr; 362 }; 363 364 } // namespace armnn 365