1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/Workload.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadInfo.hpp>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker namespace armnn
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker class ITensorHandle;
17*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
AddInputToWorkload(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)23*89c4ff92SAndroid Build Coastguard Worker void AddInputToWorkload(QueueDescriptor& descriptor,
24*89c4ff92SAndroid Build Coastguard Worker armnn::WorkloadInfo& info,
25*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo,
26*89c4ff92SAndroid Build Coastguard Worker armnn::ITensorHandle* tensorHandle)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Inputs.push_back(tensorHandle);
29*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos.push_back(tensorInfo);
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
AddOutputToWorkload(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)33*89c4ff92SAndroid Build Coastguard Worker void AddOutputToWorkload(QueueDescriptor& descriptor,
34*89c4ff92SAndroid Build Coastguard Worker armnn::WorkloadInfo& info,
35*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo,
36*89c4ff92SAndroid Build Coastguard Worker armnn::ITensorHandle* tensorHandle)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Outputs.push_back(tensorHandle);
39*89c4ff92SAndroid Build Coastguard Worker info.m_OutputTensorInfos.push_back(tensorInfo);
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
SetWorkloadInput(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,unsigned int index,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)43*89c4ff92SAndroid Build Coastguard Worker void SetWorkloadInput(QueueDescriptor& descriptor,
44*89c4ff92SAndroid Build Coastguard Worker armnn::WorkloadInfo& info,
45*89c4ff92SAndroid Build Coastguard Worker unsigned int index,
46*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo,
47*89c4ff92SAndroid Build Coastguard Worker armnn::ITensorHandle* tensorHandle)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Inputs[index] = tensorHandle;
50*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos[index] = tensorInfo;
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker template <typename QueueDescriptor>
SetWorkloadOutput(QueueDescriptor & descriptor,armnn::WorkloadInfo & info,unsigned int index,const armnn::TensorInfo & tensorInfo,armnn::ITensorHandle * tensorHandle)54*89c4ff92SAndroid Build Coastguard Worker void SetWorkloadOutput(QueueDescriptor& descriptor,
55*89c4ff92SAndroid Build Coastguard Worker armnn::WorkloadInfo& info,
56*89c4ff92SAndroid Build Coastguard Worker unsigned int index,
57*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo,
58*89c4ff92SAndroid Build Coastguard Worker armnn::ITensorHandle* tensorHandle)
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Outputs[index] = tensorHandle;
61*89c4ff92SAndroid Build Coastguard Worker info.m_OutputTensorInfos[index] = tensorInfo;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
ExecuteWorkload(armnn::IWorkload & workload,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,bool memoryManagementRequested=true)64*89c4ff92SAndroid Build Coastguard Worker inline void ExecuteWorkload(armnn::IWorkload& workload,
65*89c4ff92SAndroid Build Coastguard Worker const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
66*89c4ff92SAndroid Build Coastguard Worker bool memoryManagementRequested = true)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker const bool manageMemory = memoryManager && memoryManagementRequested;
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker // Acquire working memory (if needed)
71*89c4ff92SAndroid Build Coastguard Worker if (manageMemory)
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker memoryManager->Acquire();
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker // Perform PostAllocationConfiguration
77*89c4ff92SAndroid Build Coastguard Worker workload.PostAllocationConfigure();
78*89c4ff92SAndroid Build Coastguard Worker
79*89c4ff92SAndroid Build Coastguard Worker // Execute the workload
80*89c4ff92SAndroid Build Coastguard Worker workload.Execute();
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker // Release working memory (if needed)
83*89c4ff92SAndroid Build Coastguard Worker if (manageMemory)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker memoryManager->Release();
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)89*89c4ff92SAndroid Build Coastguard Worker inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker if (!weightsType)
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker return weightsType;
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker switch(weightsType.value())
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::BFloat16:
99*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float16:
100*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float32:
101*89c4ff92SAndroid Build Coastguard Worker return weightsType;
102*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmS8:
103*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmU8:
104*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS8:
105*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS16:
106*89c4ff92SAndroid Build Coastguard Worker return armnn::DataType::Signed32;
107*89c4ff92SAndroid Build Coastguard Worker default:
108*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker return armnn::EmptyOptional();
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
114