xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/workloads/TosaRefPreCompiledWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "TosaRefPreCompiledWorkload.hpp"
7 
8 namespace armnn
9 {
10 
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor & descriptor,const WorkloadInfo & info)11 TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12                                                        const WorkloadInfo& info)
13     : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
14     , m_workloadInfo(info)
15 {
16     // Check that the workload is holding a pointer to a valid pre-compiled object
17     if (m_Data.m_PreCompiledObject == nullptr)
18     {
19         throw InvalidArgumentException(
20                 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21     }
22 }
23 
Execute() const24 void TosaRefPreCompiledWorkload::Execute() const
25 {
26     tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
27 
28     std::vector<std::string> inputNames = handler->GetInputs();
29     std::vector<std::string> outputNames = handler->GetOutputs();
30 
31     TosaReference::IModelRunner runner;
32     GraphStatus status;
33 
34     // Initialise the model runner with the TosaSerializationHandler
35     status = runner.initialize(*handler);
36     if(status != GraphStatus::TOSA_VALID)
37     {
38         throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
39     }
40 
41     // Set the inputs
42     for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
43     {
44         DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
45         switch (dataType)
46         {
47             case DataType::Float16:
48                 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
49                 break;
50             case DataType::Float32:
51                 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
52                 break;
53             case DataType::QAsymmU8:
54             case DataType::QAsymmS8:
55             case DataType::QSymmS8:
56             case DataType::QSymmS16:
57             case DataType::Signed32:
58                 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
59                 break;
60             case DataType::Signed64:
61                 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
62                 break;
63             case DataType::Boolean:
64                 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
65                 break;
66             default:
67                 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
68         }
69     }
70 
71     // Run the TOSA Reference Model
72     status = runner.run();
73     if(status != GraphStatus::TOSA_VALID)
74     {
75         throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
76     }
77 
78     // Gets the outputs
79     for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
80     {
81         DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
82         switch (dataType)
83         {
84             case DataType::Float16:
85                 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
86                 break;
87             case DataType::Float32:
88                 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
89                 break;
90             case DataType::QAsymmU8:
91             case DataType::QAsymmS8:
92             case DataType::QSymmS8:
93             case DataType::QSymmS16:
94             case DataType::Signed32:
95                 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
96                 break;
97             case DataType::Signed64:
98                 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
99                 break;
100             case DataType::Boolean:
101                 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
102                 break;
103             default:
104                 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
105         }
106     }
107 }
108 
109 template <typename T>
SetInput(TosaReference::IModelRunner & runner,std::string inputName,uint32_t inputIndex) const110 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
111                                           std::string inputName,
112                                           uint32_t inputIndex) const
113 {
114     std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
115     m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
116 
117     runner.setInput<T>(inputName, inputData);
118 }
119 
120 template <typename T>
GetOutput(TosaReference::IModelRunner & runner,std::string outputName,uint32_t outputIndex) const121 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
122                                            std::string outputName,
123                                            uint32_t outputIndex) const
124 {
125     std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
126 
127     m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
128 }
129 
// Validation hook for TosaRefPreCompiledWorkload.
//
// The string pointer argument is neither read nor written; every call reports success.
bool TosaRefPreCompiledWorkloadValidate(std::string* /*unused*/)
{
    constexpr bool isValid = true;
    return isValid;
}
134 
135 }    //namespace armnn
136