xref: /aosp_15_r20/external/armnn/src/backends/neon/NeonTensorHandleFactory.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "NeonTensorHandleFactory.hpp"
7 #include "NeonTensorHandle.hpp"
8 
9 #include "Layer.hpp"
10 
11 #include <armnn/utility/IgnoreUnused.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13 #include <armnn/utility/PolymorphicDowncast.hpp>
14 
15 namespace armnn
16 {
17 
18 using FactoryId = ITensorHandleFactory::FactoryId;
19 
CreateSubTensorHandle(ITensorHandle & parent,const TensorShape & subTensorShape,const unsigned int * subTensorOrigin) const20 std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
21                                                                               const TensorShape& subTensorShape,
22                                                                               const unsigned int* subTensorOrigin)
23                                                                               const
24 {
25     const arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
26 
27     arm_compute::Coordinates coords;
28     coords.set_num_dimensions(subTensorShape.GetNumDimensions());
29     for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
30     {
31         // Arm compute indexes tensor coords in reverse order.
32         unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
33         coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
34     }
35 
36     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
37 
38     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
39     {
40         return nullptr;
41     }
42 
43     return std::make_unique<NeonSubTensorHandle>(
44             PolymorphicDowncast<IAclTensorHandle*>(&parent), shape, coords);
45 }
46 
CreateTensorHandle(const TensorInfo & tensorInfo) const47 std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
48 {
49     return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
50 }
51 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const52 std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
53                                                                            DataLayout dataLayout) const
54 {
55     return NeonTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
56 }
57 
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const58 std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
59                                                                            const bool IsMemoryManaged) const
60 {
61     auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo);
62     if (IsMemoryManaged)
63     {
64         tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
65     }
66     // If we are not Managing the Memory then we must be importing
67     tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
68     tensorHandle->SetImportFlags(GetImportFlags());
69 
70     return tensorHandle;
71 }
72 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const73 std::unique_ptr<ITensorHandle> NeonTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
74                                                                            DataLayout dataLayout,
75                                                                            const bool IsMemoryManaged) const
76 {
77     auto tensorHandle = std::make_unique<NeonTensorHandle>(tensorInfo, dataLayout);
78     if (IsMemoryManaged)
79     {
80         tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
81     }
82     // If we are not Managing the Memory then we must be importing
83     tensorHandle->SetImportEnabledFlag(!IsMemoryManaged);
84     tensorHandle->SetImportFlags(GetImportFlags());
85 
86     return tensorHandle;
87 }
88 
GetIdStatic()89 const FactoryId& NeonTensorHandleFactory::GetIdStatic()
90 {
91     static const FactoryId s_Id(NeonTensorHandleFactoryId());
92     return s_Id;
93 }
94 
GetId() const95 const FactoryId& NeonTensorHandleFactory::GetId() const
96 {
97     return GetIdStatic();
98 }
99 
SupportsInPlaceComputation() const100 bool NeonTensorHandleFactory::SupportsInPlaceComputation() const
101 {
102     return true;
103 }
104 
SupportsSubTensors() const105 bool NeonTensorHandleFactory::SupportsSubTensors() const
106 {
107     return true;
108 }
109 
GetExportFlags() const110 MemorySourceFlags NeonTensorHandleFactory::GetExportFlags() const
111 {
112     return m_ExportFlags;
113 }
114 
GetImportFlags() const115 MemorySourceFlags NeonTensorHandleFactory::GetImportFlags() const
116 {
117     return m_ImportFlags;
118 }
119 
GetCapabilities(const IConnectableLayer * layer,const IConnectableLayer * connectedLayer,CapabilityClass capabilityClass)120 std::vector<Capability> NeonTensorHandleFactory::GetCapabilities(const IConnectableLayer* layer,
121                                                                  const IConnectableLayer* connectedLayer,
122                                                                  CapabilityClass capabilityClass)
123 
124 {
125     IgnoreUnused(connectedLayer);
126     std::vector<Capability> capabilities;
127     if (capabilityClass == CapabilityClass::PaddingRequired)
128     {
129         auto search = paddingRequiredLayers.find((PolymorphicDowncast<const Layer*>(layer))->GetType());
130         if ( search != paddingRequiredLayers.end())
131         {
132             Capability paddingCapability(CapabilityClass::PaddingRequired, true);
133             capabilities.push_back(paddingCapability);
134         }
135     }
136     return capabilities;
137 }
138 
139 } // namespace armnn
140