xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/WorkloadTestUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/Workload.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadInfo.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace armnn
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker class ITensorHandle;
17*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
AddInputToWorkload(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)23*89c4ff92SAndroid Build Coastguard Worker void AddInputToWorkload(QueueDescriptor& descriptor,
24*89c4ff92SAndroid Build Coastguard Worker     armnn::WorkloadInfo& info,
25*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& tensorInfo,
26*89c4ff92SAndroid Build Coastguard Worker     armnn::ITensorHandle* tensorHandle)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Inputs.push_back(tensorHandle);
29*89c4ff92SAndroid Build Coastguard Worker     info.m_InputTensorInfos.push_back(tensorInfo);
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
AddOutputToWorkload(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)33*89c4ff92SAndroid Build Coastguard Worker void AddOutputToWorkload(QueueDescriptor& descriptor,
34*89c4ff92SAndroid Build Coastguard Worker     armnn::WorkloadInfo& info,
35*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& tensorInfo,
36*89c4ff92SAndroid Build Coastguard Worker     armnn::ITensorHandle* tensorHandle)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Outputs.push_back(tensorHandle);
39*89c4ff92SAndroid Build Coastguard Worker     info.m_OutputTensorInfos.push_back(tensorInfo);
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
SetWorkloadInput(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,unsigned int index,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)43*89c4ff92SAndroid Build Coastguard Worker void SetWorkloadInput(QueueDescriptor& descriptor,
44*89c4ff92SAndroid Build Coastguard Worker     armnn::WorkloadInfo& info,
45*89c4ff92SAndroid Build Coastguard Worker     unsigned int index,
46*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& tensorInfo,
47*89c4ff92SAndroid Build Coastguard Worker     armnn::ITensorHandle* tensorHandle)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Inputs[index] = tensorHandle;
50*89c4ff92SAndroid Build Coastguard Worker     info.m_InputTensorInfos[index] = tensorInfo;
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
SetWorkloadOutput(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,unsigned int index,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)54*89c4ff92SAndroid Build Coastguard Worker void SetWorkloadOutput(QueueDescriptor& descriptor,
55*89c4ff92SAndroid Build Coastguard Worker     armnn::WorkloadInfo& info,
56*89c4ff92SAndroid Build Coastguard Worker     unsigned int index,
57*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& tensorInfo,
58*89c4ff92SAndroid Build Coastguard Worker     armnn::ITensorHandle* tensorHandle)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Outputs[index] = tensorHandle;
61*89c4ff92SAndroid Build Coastguard Worker     info.m_OutputTensorInfos[index] = tensorInfo;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
ExecuteWorkload(armnn::IWorkload & workload,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,bool memoryManagementRequested=true)64*89c4ff92SAndroid Build Coastguard Worker inline void ExecuteWorkload(armnn::IWorkload& workload,
65*89c4ff92SAndroid Build Coastguard Worker                             const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
66*89c4ff92SAndroid Build Coastguard Worker                             bool memoryManagementRequested = true)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker     const bool manageMemory = memoryManager && memoryManagementRequested;
69*89c4ff92SAndroid Build Coastguard Worker 
70*89c4ff92SAndroid Build Coastguard Worker     // Acquire working memory (if needed)
71*89c4ff92SAndroid Build Coastguard Worker     if (manageMemory)
72*89c4ff92SAndroid Build Coastguard Worker     {
73*89c4ff92SAndroid Build Coastguard Worker         memoryManager->Acquire();
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     // Perform PostAllocationConfiguration
77*89c4ff92SAndroid Build Coastguard Worker     workload.PostAllocationConfigure();
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     // Execute the workload
80*89c4ff92SAndroid Build Coastguard Worker     workload.Execute();
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     // Release working memory (if needed)
83*89c4ff92SAndroid Build Coastguard Worker     if (manageMemory)
84*89c4ff92SAndroid Build Coastguard Worker     {
85*89c4ff92SAndroid Build Coastguard Worker         memoryManager->Release();
86*89c4ff92SAndroid Build Coastguard Worker     }
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker 
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)89*89c4ff92SAndroid Build Coastguard Worker inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker     if (!weightsType)
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         return weightsType;
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     switch(weightsType.value())
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::BFloat16:
99*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float16:
100*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float32:
101*89c4ff92SAndroid Build Coastguard Worker             return weightsType;
102*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmS8:
103*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmU8:
104*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS8:
105*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS16:
106*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed32;
107*89c4ff92SAndroid Build Coastguard Worker         default:
108*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
109*89c4ff92SAndroid Build Coastguard Worker     }
110*89c4ff92SAndroid Build Coastguard Worker     return armnn::EmptyOptional();
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
114