xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/TosaRefBackend.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefBackend.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefWorkloadFactory.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefLayerSupport.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefTensorHandleFactory.hpp"
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <tosaCommon/TosaMappings.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendContext.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/DefaultAllocator.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/SubgraphUtils.hpp>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker namespace armnn
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker 
// Utility function to construct a valid Deleter for TosaSerializationHandler ptrs passed back to ArmNN.
// The blob arrives type-erased as const void*, so the caller names the concrete type T it should be deleted as.
// Passing a null blob is harmless: deleting a null pointer is a no-op.
template <typename T>
void DeleteAsType(const void* const blob)
{
    // Deleting through a pointer to an incomplete type is undefined behaviour, so reject it at compile time.
    static_assert(sizeof(T) > 0, "DeleteAsType requires T to be a complete type");
    delete static_cast<const T*>(blob);
}
31*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()32*89c4ff92SAndroid Build Coastguard Worker const BackendId& TosaRefBackend::GetIdStatic()
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker     static const BackendId s_Id{TosaRefBackendId()};
35*89c4ff92SAndroid Build Coastguard Worker     return s_Id;
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker 
// Creates a TosaRefWorkloadFactory that uses the supplied memory manager.
// The generic memory manager pointer is downcast to TosaRefMemoryManager before being handed to the factory.
IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto tosaMemoryManager = PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager);
    return std::make_unique<TosaRefWorkloadFactory>(std::move(tosaMemoryManager));
}
43*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(class TensorHandleFactoryRegistry & tensorHandleFactoryRegistry) const44*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr TosaRefBackend::CreateWorkloadFactory(
45*89c4ff92SAndroid Build Coastguard Worker     class TensorHandleFactoryRegistry& tensorHandleFactoryRegistry) const
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<TosaRefMemoryManager>();
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterMemoryManager(memoryManager);
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
52*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
53*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
54*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
55*89c4ff92SAndroid Build Coastguard Worker     tensorHandleFactoryRegistry.RegisterFactory(std::move(factory));
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<TosaRefWorkloadFactory>(PolymorphicPointerDowncast<TosaRefMemoryManager>(memoryManager));
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions &) const60*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr TosaRefBackend::CreateBackendContext(const IRuntime::CreationOptions&) const
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker     return IBackendContextPtr{};
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)65*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr TosaRefBackend::CreateBackendProfilingContext(
66*89c4ff92SAndroid Build Coastguard Worker     const IRuntime::CreationOptions&, IBackendProfilingPtr&)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker     return IBackendProfilingContextPtr{};
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const71*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr TosaRefBackend::CreateMemoryManager() const
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<TosaRefMemoryManager>();
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const76*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr TosaRefBackend::GetLayerSupport() const
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport{new TosaRefLayerSupport};
79*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const82*89c4ff92SAndroid Build Coastguard Worker OptimizationViews TosaRefBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
83*89c4ff92SAndroid Build Coastguard Worker                                                        const ModelOptions& modelOptions) const
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews(modelOptions);
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     auto handler = std::make_unique<TosaSerializationHandler>();
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> graphInputs;
90*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> graphOutputs;
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker     std::vector<TosaSerializationOperator*> operators;
93*89c4ff92SAndroid Build Coastguard Worker     std::vector<TosaSerializationTensor*> tensors;
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker     auto it = subgraph.endIConnectable();
96*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         --it;
99*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker         if(base.GetType() == armnn::LayerType::Input ||
102*89c4ff92SAndroid Build Coastguard Worker            base.GetType() == armnn::LayerType::Output)
103*89c4ff92SAndroid Build Coastguard Worker         {
104*89c4ff92SAndroid Build Coastguard Worker             continue;
105*89c4ff92SAndroid Build Coastguard Worker         }
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker         tosa::TosaSerializationBasicBlock* mappings = GetTosaMappingFromLayer(&base);
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker         // Loop through inputs to see if there are any graph inputs, if so save them.
110*89c4ff92SAndroid Build Coastguard Worker         // If it's an input to the graph "input" can be found in the string.
111*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t i = 0; i < mappings->GetInputs().size(); i++)
112*89c4ff92SAndroid Build Coastguard Worker         {
113*89c4ff92SAndroid Build Coastguard Worker             std::basic_string<char> blockInputName = mappings->GetInputs()[i];
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker             if (blockInputName.find("input") != std::string::npos)
116*89c4ff92SAndroid Build Coastguard Worker             {
117*89c4ff92SAndroid Build Coastguard Worker                 graphInputs.push_back(blockInputName);
118*89c4ff92SAndroid Build Coastguard Worker             }
119*89c4ff92SAndroid Build Coastguard Worker         }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker         // Loop through outputs to see if there are any graph outputs, if so save them.
122*89c4ff92SAndroid Build Coastguard Worker         // If it's an output to the graph "output" can be found in the string.
123*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t i = 0; i < mappings->GetOutputs().size(); i++)
124*89c4ff92SAndroid Build Coastguard Worker         {
125*89c4ff92SAndroid Build Coastguard Worker             std::basic_string<char> blockOutputName = mappings->GetOutputs()[i];
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker             if (blockOutputName.find("output") != std::string::npos)
128*89c4ff92SAndroid Build Coastguard Worker             {
129*89c4ff92SAndroid Build Coastguard Worker                 graphOutputs.push_back(blockOutputName);
130*89c4ff92SAndroid Build Coastguard Worker             }
131*89c4ff92SAndroid Build Coastguard Worker         }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker         auto blockOperators = mappings->GetOperators();
134*89c4ff92SAndroid Build Coastguard Worker         operators.insert(operators.end(), blockOperators.begin(), blockOperators.end());
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker         auto blockTensors = mappings->GetTensors();
137*89c4ff92SAndroid Build Coastguard Worker         tensors.insert(tensors.end(), blockTensors.begin(), blockTensors.end());
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker     // Add all mappings to main block, the TOSA Reference Model requires the full graph to be in one block called main.
141*89c4ff92SAndroid Build Coastguard Worker     auto* block = new TosaSerializationBasicBlock("main", operators, tensors, graphInputs, graphOutputs);
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker     handler.get()->GetBlocks().push_back(block);
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     auto compiledBlob =
146*89c4ff92SAndroid Build Coastguard Worker             std::make_unique<PreCompiledObjectPtr>(handler.release(), DeleteAsType<TosaSerializationHandler>);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* preCompiledLayer = optimizationViews.GetINetwork()->AddPrecompiledLayer(
149*89c4ff92SAndroid Build Coastguard Worker             PreCompiledDescriptor(subgraph.GetNumInputSlots(), subgraph.GetNumOutputSlots()),
150*89c4ff92SAndroid Build Coastguard Worker             std::move(*compiledBlob),
151*89c4ff92SAndroid Build Coastguard Worker             armnn::Optional<BackendId>(GetId()),
152*89c4ff92SAndroid Build Coastguard Worker             "TOSA_Pre_Compiled_Layer");
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     // Copy the output tensor infos from sub-graph
155*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < subgraph.GetNumOutputSlots(); i++)
156*89c4ff92SAndroid Build Coastguard Worker     {
157*89c4ff92SAndroid Build Coastguard Worker         preCompiledLayer->GetOutputSlot(i).SetTensorInfo(subgraph.GetIOutputSlot(i)->GetTensorInfo());
158*89c4ff92SAndroid Build Coastguard Worker     }
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker     optimizationViews.AddSubstitution({ std::move(subgraph), SubgraphView(preCompiledLayer) });
161*89c4ff92SAndroid Build Coastguard Worker     return optimizationViews;
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker 
GetHandleFactoryPreferences() const165*89c4ff92SAndroid Build Coastguard Worker std::vector<ITensorHandleFactory::FactoryId> TosaRefBackend::GetHandleFactoryPreferences() const
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker     return std::vector<ITensorHandleFactory::FactoryId> { TosaRefTensorHandleFactory::GetIdStatic() };
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
RegisterTensorHandleFactories(class TensorHandleFactoryRegistry & registry)170*89c4ff92SAndroid Build Coastguard Worker void TosaRefBackend::RegisterTensorHandleFactories(class TensorHandleFactoryRegistry& registry)
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = std::make_shared<TosaRefMemoryManager>();
173*89c4ff92SAndroid Build Coastguard Worker 
174*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker     auto factory = std::make_unique<TosaRefTensorHandleFactory>(memoryManager);
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker     // Register copy and import factory pair
179*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), factory->GetId());
180*89c4ff92SAndroid Build Coastguard Worker     // Register the factory
181*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker 
GetDefaultAllocator() const184*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> TosaRefBackend::GetDefaultAllocator() const
185*89c4ff92SAndroid Build Coastguard Worker {
186*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<DefaultAllocator>();
187*89c4ff92SAndroid Build Coastguard Worker }
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
190