xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/MockBackend.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Deprecated.hpp>
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/Exceptions.hpp>
10 #include <armnn/IRuntime.hpp>
11 #include <armnn/Optional.hpp>
12 #include <armnn/Tensor.hpp>
13 #include <armnn/Types.hpp>
14 #include <armnn/backends/IBackendInternal.hpp>
15 #include <armnn/backends/MemCopyWorkload.hpp>
16 #include <armnn/backends/ITensorHandle.hpp>
17 #include <armnn/backends/IWorkload.hpp>
18 #include <armnn/backends/OptimizationViews.hpp>
19 #include <armnn/backends/SubgraphView.hpp>
20 #include <armnn/backends/WorkloadData.hpp>
21 #include <armnn/backends/WorkloadFactory.hpp>
22 #include <armnn/backends/WorkloadInfo.hpp>
23 #include <armnn/utility/IgnoreUnused.hpp>
24 #include <armnn/utility/PolymorphicDowncast.hpp>
25 #include <armnnTestUtils/MockTensorHandle.hpp>
26 #include <backendsCommon/LayerSupportBase.hpp>
27 
28 #include <client/include/CounterValue.hpp>
29 #include <client/include/ISendTimelinePacket.hpp>
30 #include <client/include/Timestamp.hpp>
31 #include <client/include/backends/IBackendProfiling.hpp>
32 #include <client/include/backends/IBackendProfilingContext.hpp>
33 #include <common/include/Optional.hpp>
34 
35 #include <atomic>
36 #include <cstdint>
37 #include <memory>
38 #include <string>
39 #include <utility>
40 #include <vector>
41 
42 namespace armnn
43 {
// Forward declarations of types that this header only names in declarations
// (return types, parameters, smart-pointer members); their definitions are not needed here.
struct QuantizedLstmInputParamsInfo;
struct LstmInputParamsInfo;
class MockMemoryManager;
class ICustomAllocator;
class BackendId;
49 
50 // A bare bones Mock backend to enable unit testing of simple tensor manipulation features.
51 class MockBackend : public IBackendInternal
52 {
53 public:
54     MockBackend() = default;
55 
56     ~MockBackend() = default;
57 
58     static const BackendId& GetIdStatic();
59 
GetId() const60     const BackendId& GetId() const override
61     {
62         return GetIdStatic();
63     }
64     IBackendInternal::IWorkloadFactoryPtr
65         CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;
66 
67     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;
68 
69     IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;
70 
71     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
72     IBackendInternal::IBackendProfilingContextPtr
73     CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
74                                   IBackendProfilingPtr& backendProfiling) override;
75 
76     OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;
77 
78     std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override;
79 };
80 
81 class MockWorkloadFactory : public IWorkloadFactory
82 {
83 
84 public:
85     explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
86     MockWorkloadFactory();
87 
~MockWorkloadFactory()88     ~MockWorkloadFactory()
89     {}
90 
91     const BackendId& GetBackendId() const override;
92 
SupportsSubTensors() const93     bool SupportsSubTensors() const override
94     {
95         return false;
96     }
97 
98     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
CreateSubTensorHandle(ITensorHandle &,TensorShape const &,unsigned int const *) const99     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
100                                                          TensorShape const&,
101                                                          unsigned int const*) const override
102     {
103         return nullptr;
104     }
105 
106     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged=true) const107     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
108                                                       const bool IsMemoryManaged = true) const override
109     {
110         IgnoreUnused(IsMemoryManaged);
111         return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
112     };
113 
114     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged=true) const115     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
116                                                       DataLayout dataLayout,
117                                                       const bool IsMemoryManaged = true) const override
118     {
119         IgnoreUnused(dataLayout, IsMemoryManaged);
120         return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
121     };
122 
123     ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
124         "Use ABI stable "
125         "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
126         "23.08")
CreateInput(const InputQueueDescriptor & descriptor,const WorkloadInfo & info) const127     std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
128                                            const WorkloadInfo& info) const override
129     {
130         if (info.m_InputTensorInfos.empty())
131         {
132             throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
133         }
134         if (info.m_OutputTensorInfos.empty())
135         {
136             throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
137         }
138 
139         if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
140         {
141             throw InvalidArgumentException(
142                 "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
143         }
144 
145         return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
146     };
147 
148     std::unique_ptr<IWorkload>
149         CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;
150 
151 private:
152     mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
153 };
154 
/// Scope helper whose constructor and destructor are defined out of line.
///
/// NOTE(review): judging by the name, construction presumably registers MockBackend and destruction
/// unregisters it - confirm against the definitions before relying on that.
class MockBackendInitialiser
{
public:
    MockBackendInitialiser();   // defined out of line
    ~MockBackendInitialiser();  // defined out of line
};
161 
162 class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext
163 {
164 public:
MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr & backendProfiling)165     MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
166         : m_BackendProfiling(std::move(backendProfiling))
167         , m_CapturePeriod(0)
168         , m_IsTimelineEnabled(true)
169     {}
170 
171     ~MockBackendProfilingContext() = default;
172 
GetBackendProfiling()173     IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
174     {
175         return m_BackendProfiling;
176     }
177 
RegisterCounters(uint16_t currentMaxGlobalCounterId)178     uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
179     {
180         std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar =
181             m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
182 
183         std::string categoryName("MockCounters");
184         counterRegistrar->RegisterCategory(categoryName);
185 
186         counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
187 
188         counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
189                                           "Another notional counter");
190 
191         std::string units("microseconds");
192         uint16_t nextMaxGlobalCounterId =
193                         counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
194                                                           "A dummy four core counter", units, 4);
195         return nextMaxGlobalCounterId;
196     }
197 
ActivateCounters(uint32_t capturePeriod,const std::vector<uint16_t> & counterIds)198     arm::pipe::Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
199     {
200         if (capturePeriod == 0 || counterIds.size() == 0)
201         {
202             m_ActiveCounters.clear();
203         }
204         else if (capturePeriod == 15939u)
205         {
206             return arm::pipe::Optional<std::string>("ActivateCounters example test error");
207         }
208         m_CapturePeriod  = capturePeriod;
209         m_ActiveCounters = counterIds;
210         return arm::pipe::Optional<std::string>();
211     }
212 
ReportCounterValues()213     std::vector<arm::pipe::Timestamp> ReportCounterValues()
214     {
215         std::vector<arm::pipe::CounterValue> counterValues;
216 
217         for (auto counterId : m_ActiveCounters)
218         {
219             counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u });
220         }
221 
222         uint64_t timestamp = m_CapturePeriod;
223         return { arm::pipe::Timestamp{ timestamp, counterValues } };
224     }
225 
EnableProfiling(bool)226     bool EnableProfiling(bool)
227     {
228         auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
229         sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
230         sendTimelinePacket->Commit();
231         return true;
232     }
233 
EnableTimelineReporting(bool isEnabled)234     bool EnableTimelineReporting(bool isEnabled)
235     {
236         m_IsTimelineEnabled = isEnabled;
237         return isEnabled;
238     }
239 
TimelineReportingEnabled()240     bool TimelineReportingEnabled()
241     {
242         return m_IsTimelineEnabled;
243     }
244 
245 private:
246     IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
247     uint32_t m_CapturePeriod;
248     std::vector<uint16_t> m_ActiveCounters;
249     std::atomic<bool> m_IsTimelineEnabled;
250 };
251 
252 class MockBackendProfilingService
253 {
254 public:
255     // Getter for the singleton instance
Instance()256     static MockBackendProfilingService& Instance()
257     {
258         static MockBackendProfilingService instance;
259         return instance;
260     }
261 
GetContext()262     MockBackendProfilingContext* GetContext()
263     {
264         return m_sharedContext.get();
265     }
266 
SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)267     void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
268     {
269         m_sharedContext = shared;
270     }
271 
272 private:
273     std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
274 };
275 
276 class MockLayerSupport : public LayerSupportBase
277 {
278 public:
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> &,const Optional<QuantizedLstmInputParamsInfo> &,Optional<std::string &> reasonIfUnsupported) const279     bool IsLayerSupported(const LayerType& type,
280                           const std::vector<TensorInfo>& infos,
281                           const BaseDescriptor& descriptor,
282                           const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
283                           const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
284                           Optional<std::string&> reasonIfUnsupported) const override
285     {
286         switch(type)
287         {
288             case LayerType::Input:
289                 return IsInputSupported(infos[0], reasonIfUnsupported);
290             case LayerType::Output:
291                 return IsOutputSupported(infos[0], reasonIfUnsupported);
292             case LayerType::Addition:
293                 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
294             case LayerType::Convolution2d:
295             {
296                 if (infos.size() != 4)
297                 {
298                     throw InvalidArgumentException("Invalid number of TransposeConvolution2d "
299                                                    "TensorInfos. TensorInfos should be of format: "
300                                                    "{input, output, weights, biases}.");
301                 }
302 
303                 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
304                 if (infos[3] == TensorInfo())
305                 {
306                     return IsConvolution2dSupported(infos[0],
307                                                     infos[1],
308                                                     desc,
309                                                     infos[2],
310                                                     EmptyOptional(),
311                                                     reasonIfUnsupported);
312                 }
313                 else
314                 {
315                     return IsConvolution2dSupported(infos[0],
316                                                     infos[1],
317                                                     desc,
318                                                     infos[2],
319                                                     infos[3],
320                                                     reasonIfUnsupported);
321                 }
322             }
323             case LayerType::ElementwiseBinary:
324             {
325                 auto elementwiseDesc = *(PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&descriptor));
326                 return (elementwiseDesc.m_Operation == BinaryOperation::Add);
327             }
328             default:
329                 return false;
330         }
331     }
332 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const333     bool IsInputSupported(const TensorInfo& /*input*/,
334                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
335     {
336         return true;
337     }
338 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const339     bool IsOutputSupported(const TensorInfo& /*input*/,
340                            Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
341     {
342         return true;
343     }
344 
IsAdditionSupported(const TensorInfo &,const TensorInfo &,const TensorInfo &,Optional<std::string &>) const345     bool IsAdditionSupported(const TensorInfo& /*input0*/,
346                              const TensorInfo& /*input1*/,
347                              const TensorInfo& /*output*/,
348                              Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
349     {
350         return true;
351     }
352 
IsConvolution2dSupported(const TensorInfo &,const TensorInfo &,const Convolution2dDescriptor &,const TensorInfo &,const Optional<TensorInfo> &,Optional<std::string &>) const353     bool IsConvolution2dSupported(const TensorInfo& /*input*/,
354                                   const TensorInfo& /*output*/,
355                                   const Convolution2dDescriptor& /*descriptor*/,
356                                   const TensorInfo& /*weights*/,
357                                   const Optional<TensorInfo>& /*biases*/,
358                                   Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
359     {
360         return true;
361     }
362 };
363 
364 }    // namespace armnn
365