//
// Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <aclCommon/ArmComputeTensorHandle.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>

#include <Half.hpp>

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <arm_compute/runtime/CL/CLTensor.h>
#include <arm_compute/runtime/CL/CLSubTensor.h>
#include <arm_compute/runtime/IMemoryGroup.h>
#include <arm_compute/runtime/MemoryGroup.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Coordinates.h>

#include <aclCommon/IClTensorHandle.hpp>

namespace armnn
{

class ClTensorHandle : public IClTensorHandle
{
public:
    ClTensorHandle(const TensorInfo& tensorInfo)
                     : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
                       m_Imported(false),
                       m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
    }

    ClTensorHandle(const TensorInfo& tensorInfo,
                   DataLayout dataLayout,
                   MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
                   : m_ImportFlags(importFlags),
                     m_Imported(false),
                     m_IsImportEnabled(false)
    {
        armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
    }

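    // A minimal construction sketch (illustrative only; the shape, data type and layout
    // below are hypothetical):
    //
    //     armnn::TensorInfo info({ 1, 16, 16, 3 }, armnn::DataType::Float32);
    //     auto handle = std::make_unique<armnn::ClTensorHandle>(info, armnn::DataLayout::NHWC);
    //
    // BuildArmComputeTensor only configures the CLTensor's metadata; no OpenCL memory is
    // acquired until Allocate() is called or the tensor is handed to a memory group.
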
    arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
    virtual void Allocate() override
    {
        // If we have enabled Importing, don't allocate the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
        }
        else
        {
            armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
        }
    }

    virtual void Manage() override
    {
        // If we have enabled Importing, don't manage the tensor
        if (m_IsImportEnabled)
        {
            throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
        }
        else
        {
            assert(m_MemoryGroup != nullptr);
            m_MemoryGroup->manage(&m_Tensor);
        }
    }

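    // A hedged sketch of the intended call order when import is not enabled (in ArmNN this
    // is driven by the backend's memory manager; the bare memory group below is only for
    // illustration):
    //
    //     auto group = std::make_shared<arm_compute::MemoryGroup>();
    //     handle->SetMemoryGroup(group);
    //     handle->Manage();      // registers m_Tensor with the memory group
    //     handle->Allocate();    // initialises the (empty) tensor via its allocator
    //
    // With import enabled, both Manage() and Allocate() throw MemoryImportException instead.
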
    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }

    virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }

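    // Map()/Unmap() bracket host-side access to the CL buffer. A usage sketch (blocking
    // map, read-only access through the returned pointer):
    //
    //     const void* data = handle->Map(true);
    //     // ... read the tensor contents on the host ...
    //     handle->Unmap();
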
    virtual ITensorHandle* GetParent() const override { return nullptr; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
    {
        m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
    }

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

    void SetImportFlags(MemorySourceFlags importFlags)
    {
        m_ImportFlags = importFlags;
    }

    MemorySourceFlags GetImportFlags() const override
    {
        return m_ImportFlags;
    }

    void SetImportEnabledFlag(bool importEnabledFlag)
    {
        m_IsImportEnabled = importEnabledFlag;
    }

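    // This handle does not support zero-copy import: a memory source that matches
    // m_ImportFlags still results in a MemoryImportException, and any other source
    // simply returns false.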
    virtual bool Import(void* memory, MemorySource source) override
    {
        armnn::IgnoreUnused(memory);
        if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
        {
            throw MemoryImportException("ClTensorHandle::Incorrect import flag");
        }
        m_Imported = false;
        return false;
    }

    virtual bool CanBeImported(void* memory, MemorySource source) override
    {
        // This TensorHandle can never import.
        armnn::IgnoreUnused(memory, source);
        return false;
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<armnn::ClTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<armnn::ClTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    arm_compute::CLTensor m_Tensor;
    std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
    MemorySourceFlags m_ImportFlags;
    bool m_Imported;
    bool m_IsImportEnabled;
};

class ClSubTensorHandle : public IClTensorHandle
{
public:
    ClSubTensorHandle(IClTensorHandle* parent,
                      const arm_compute::TensorShape& shape,
                      const arm_compute::Coordinates& coords)
    : m_Tensor(&parent->GetTensor(), shape, coords)
    {
        parentHandle = parent;
    }

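    // A hedged sketch of creating a sub-tensor view over an existing handle (the shape and
    // coordinates are hypothetical, and `parent` is assumed to point at an IClTensorHandle
    // such as the ClTensorHandle above). The view shares the parent's CL buffer; no new
    // device memory is allocated:
    //
    //     arm_compute::TensorShape subShape(16, 16, 1);
    //     arm_compute::Coordinates coords(0, 0, 0);
    //     auto subHandle = std::make_unique<armnn::ClSubTensorHandle>(parent, subShape, coords);
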
    arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
    arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }

    virtual void Allocate() override {}
    virtual void Manage() override {}

    virtual const void* Map(bool blocking = true) const override
    {
        const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
        return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
    }
    virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }

    virtual ITensorHandle* GetParent() const override { return parentHandle; }

    virtual arm_compute::DataType GetDataType() const override
    {
        return m_Tensor.info()->data_type();
    }

    virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}

    TensorShape GetStrides() const override
    {
        return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
    }

    TensorShape GetShape() const override
    {
        return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
    }

private:
    // Only used for testing
    void CopyOutTo(void* memory) const override
    {
        const_cast<ClSubTensorHandle*>(this)->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<float*>(memory));
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<uint8_t*>(memory));
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<armnn::Half*>(memory));
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int8_t*>(memory));
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int16_t*>(memory));
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
                                                                 static_cast<int32_t*>(memory));
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        const_cast<ClSubTensorHandle*>(this)->Unmap();
    }

    // Only used for testing
    void CopyInFrom(const void* memory) override
    {
        this->Map(true);
        switch(this->GetDataType())
        {
            case arm_compute::DataType::F32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::U8:
            case arm_compute::DataType::QASYMM8:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::F16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::QSYMM8:
            case arm_compute::DataType::QSYMM8_PER_CHANNEL:
            case arm_compute::DataType::QASYMM8_SIGNED:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S16:
            case arm_compute::DataType::QSYMM16:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
                                                                 this->GetTensor());
                break;
            case arm_compute::DataType::S32:
                armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
                                                                 this->GetTensor());
                break;
            default:
            {
                throw armnn::UnimplementedException();
            }
        }
        this->Unmap();
    }

    mutable arm_compute::CLSubTensor m_Tensor;
    ITensorHandle* parentHandle = nullptr;
};

} // namespace armnn