xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/MakeWorkloadHelper.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker namespace armnn
8*89c4ff92SAndroid Build Coastguard Worker {
9*89c4ff92SAndroid Build Coastguard Worker namespace
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker // Make a workload of the specified WorkloadType.
13*89c4ff92SAndroid Build Coastguard Worker template<typename WorkloadType>
14*89c4ff92SAndroid Build Coastguard Worker struct MakeWorkloadForType
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     template<typename QueueDescriptorType, typename... Args>
Funcarmnn::__anon1c45edbd0111::MakeWorkloadForType17*89c4ff92SAndroid Build Coastguard Worker     static std::unique_ptr<WorkloadType> Func(const QueueDescriptorType& descriptor,
18*89c4ff92SAndroid Build Coastguard Worker                                               const WorkloadInfo& info,
19*89c4ff92SAndroid Build Coastguard Worker                                               Args&&... args)
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<WorkloadType>(descriptor, info, std::forward<Args>(args)...);
22*89c4ff92SAndroid Build Coastguard Worker     }
23*89c4ff92SAndroid Build Coastguard Worker };
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker // Specialization for void workload type used for unsupported workloads.
26*89c4ff92SAndroid Build Coastguard Worker template<>
27*89c4ff92SAndroid Build Coastguard Worker struct MakeWorkloadForType<NullWorkload>
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     template<typename QueueDescriptorType, typename... Args>
Funcarmnn::__anon1c45edbd0111::MakeWorkloadForType30*89c4ff92SAndroid Build Coastguard Worker     static std::unique_ptr<NullWorkload> Func(const QueueDescriptorType& descriptor,
31*89c4ff92SAndroid Build Coastguard Worker                                               const WorkloadInfo& info,
32*89c4ff92SAndroid Build Coastguard Worker                                               Args&&... args)
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(descriptor);
35*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(info);
36*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(args...);
37*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
38*89c4ff92SAndroid Build Coastguard Worker     }
39*89c4ff92SAndroid Build Coastguard Worker };
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker // Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
42*89c4ff92SAndroid Build Coastguard Worker // Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
43*89c4ff92SAndroid Build Coastguard Worker template <typename Float16Workload, typename Float32Workload, typename Uint8Workload, typename Int32Workload,
44*89c4ff92SAndroid Build Coastguard Worker           typename BooleanWorkload, typename Int8Workload, typename QueueDescriptorType, typename... Args>
MakeWorkloadHelper(const QueueDescriptorType & descriptor,const WorkloadInfo & info,Args &&...args)45*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
46*89c4ff92SAndroid Build Coastguard Worker                                               const WorkloadInfo& info,
47*89c4ff92SAndroid Build Coastguard Worker                                               Args&&... args)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker     const DataType dataType = !info.m_InputTensorInfos.empty() ?
50*89c4ff92SAndroid Build Coastguard Worker         info.m_InputTensorInfos[0].GetDataType()
51*89c4ff92SAndroid Build Coastguard Worker         : info.m_OutputTensorInfos[0].GetDataType();
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     switch (dataType)
54*89c4ff92SAndroid Build Coastguard Worker     {
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker         case DataType::Float16:
57*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<Float16Workload>::Func(descriptor, info, std::forward<Args>(args)...);
58*89c4ff92SAndroid Build Coastguard Worker         case DataType::Float32:
59*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<Float32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
60*89c4ff92SAndroid Build Coastguard Worker         case DataType::QAsymmU8:
61*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<Uint8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
62*89c4ff92SAndroid Build Coastguard Worker         case DataType::QSymmS8:
63*89c4ff92SAndroid Build Coastguard Worker         case DataType::QAsymmS8:
64*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<Int8Workload>::Func(descriptor, info, std::forward<Args>(args)...);
65*89c4ff92SAndroid Build Coastguard Worker         case DataType::Signed32:
66*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<Int32Workload>::Func(descriptor, info, std::forward<Args>(args)...);
67*89c4ff92SAndroid Build Coastguard Worker         case DataType::Boolean:
68*89c4ff92SAndroid Build Coastguard Worker             return MakeWorkloadForType<BooleanWorkload>::Func(descriptor, info, std::forward<Args>(args)...);
69*89c4ff92SAndroid Build Coastguard Worker         case DataType::BFloat16:
70*89c4ff92SAndroid Build Coastguard Worker         case DataType::QSymmS16:
71*89c4ff92SAndroid Build Coastguard Worker             return nullptr;
72*89c4ff92SAndroid Build Coastguard Worker         default:
73*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Unknown DataType.");
74*89c4ff92SAndroid Build Coastguard Worker             return nullptr;
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker // Makes a workload for one the specified types based on the data type requirements of the tensorinfo.
79*89c4ff92SAndroid Build Coastguard Worker // Calling this method is the equivalent of calling the five typed MakeWorkload method with <FloatWorkload,
80*89c4ff92SAndroid Build Coastguard Worker // FloatWorkload, Uint8Workload, NullWorkload, NullWorkload, NullWorkload>.
81*89c4ff92SAndroid Build Coastguard Worker // Specify type void as the WorkloadType for unsupported DataType/WorkloadType combos.
82*89c4ff92SAndroid Build Coastguard Worker template <typename FloatWorkload, typename Uint8Workload, typename QueueDescriptorType, typename... Args>
MakeWorkloadHelper(const QueueDescriptorType & descriptor,const WorkloadInfo & info,Args &&...args)83*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> MakeWorkloadHelper(const QueueDescriptorType& descriptor,
84*89c4ff92SAndroid Build Coastguard Worker                                               const WorkloadInfo& info,
85*89c4ff92SAndroid Build Coastguard Worker                                               Args&&... args)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     return MakeWorkloadHelper<FloatWorkload, FloatWorkload, Uint8Workload, NullWorkload, NullWorkload, NullWorkload>(
88*89c4ff92SAndroid Build Coastguard Worker         descriptor,
89*89c4ff92SAndroid Build Coastguard Worker         info,
90*89c4ff92SAndroid Build Coastguard Worker         std::forward<Args>(args)...);
91*89c4ff92SAndroid Build Coastguard Worker }
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker } //namespace
94*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn
95