1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "TosaRefPreCompiledWorkload.hpp"
7
8 namespace armnn
9 {
10
TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor & descriptor,const WorkloadInfo & info)11 TosaRefPreCompiledWorkload::TosaRefPreCompiledWorkload(const PreCompiledQueueDescriptor& descriptor,
12 const WorkloadInfo& info)
13 : BaseWorkload<PreCompiledQueueDescriptor>(descriptor, info)
14 , m_workloadInfo(info)
15 {
16 // Check that the workload is holding a pointer to a valid pre-compiled object
17 if (m_Data.m_PreCompiledObject == nullptr)
18 {
19 throw InvalidArgumentException(
20 "TosaRefPreCompiledWorkload requires a valid pre-compiled object (TosaSerializationHandler).");
21 }
22 }
23
Execute() const24 void TosaRefPreCompiledWorkload::Execute() const
25 {
26 tosa::TosaSerializationHandler* handler = static_cast<tosa::TosaSerializationHandler*>(m_Data.m_PreCompiledObject);
27
28 std::vector<std::string> inputNames = handler->GetInputs();
29 std::vector<std::string> outputNames = handler->GetOutputs();
30
31 TosaReference::IModelRunner runner;
32 GraphStatus status;
33
34 // Initialise the model runner with the TosaSerializationHandler
35 status = runner.initialize(*handler);
36 if(status != GraphStatus::TOSA_VALID)
37 {
38 throw armnn::Exception("An error has occurred while initialising the TOSA Reference Model.");
39 }
40
41 // Set the inputs
42 for (uint32_t inputSlotIdx = 0; inputSlotIdx < inputNames.size(); ++inputSlotIdx)
43 {
44 DataType dataType = m_workloadInfo.m_InputTensorInfos[inputSlotIdx].GetDataType();
45 switch (dataType)
46 {
47 case DataType::Float16:
48 SetInput<half_float::half>(runner, inputNames[inputSlotIdx], inputSlotIdx);
49 break;
50 case DataType::Float32:
51 SetInput<float>(runner, inputNames[inputSlotIdx], inputSlotIdx);
52 break;
53 case DataType::QAsymmU8:
54 case DataType::QAsymmS8:
55 case DataType::QSymmS8:
56 case DataType::QSymmS16:
57 case DataType::Signed32:
58 SetInput<int32_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
59 break;
60 case DataType::Signed64:
61 SetInput<int64_t>(runner, inputNames[inputSlotIdx], inputSlotIdx);
62 break;
63 case DataType::Boolean:
64 SetInput<unsigned char>(runner, inputNames[inputSlotIdx], inputSlotIdx);
65 break;
66 default:
67 throw armnn::Exception("Input data type is unsupported in TOSA Reference Backend.");
68 }
69 }
70
71 // Run the TOSA Reference Model
72 status = runner.run();
73 if(status != GraphStatus::TOSA_VALID)
74 {
75 throw armnn::Exception("An error has occurred while running the TOSA Reference Model.");
76 }
77
78 // Gets the outputs
79 for (uint32_t outputSlotIdx = 0; outputSlotIdx < outputNames.size(); ++outputSlotIdx)
80 {
81 DataType dataType = m_workloadInfo.m_OutputTensorInfos[outputSlotIdx].GetDataType();
82 switch (dataType)
83 {
84 case DataType::Float16:
85 GetOutput<half_float::half>(runner, outputNames[outputSlotIdx], outputSlotIdx);
86 break;
87 case DataType::Float32:
88 GetOutput<float>(runner, outputNames[outputSlotIdx], outputSlotIdx);
89 break;
90 case DataType::QAsymmU8:
91 case DataType::QAsymmS8:
92 case DataType::QSymmS8:
93 case DataType::QSymmS16:
94 case DataType::Signed32:
95 GetOutput<int32_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
96 break;
97 case DataType::Signed64:
98 GetOutput<int64_t>(runner, outputNames[outputSlotIdx], outputSlotIdx);
99 break;
100 case DataType::Boolean:
101 GetOutput<unsigned char>(runner, outputNames[outputSlotIdx], outputSlotIdx);
102 break;
103 default:
104 throw armnn::Exception("Output data type is unsupported in TOSA Reference Backend.");
105 }
106 }
107 }
108
109 template <typename T>
SetInput(TosaReference::IModelRunner & runner,std::string inputName,uint32_t inputIndex) const110 void TosaRefPreCompiledWorkload::SetInput(TosaReference::IModelRunner& runner,
111 std::string inputName,
112 uint32_t inputIndex) const
113 {
114 std::vector<T> inputData(m_Data.m_Inputs[inputIndex]->GetShape().GetNumElements());
115 m_Data.m_Inputs[inputIndex]->CopyOutTo(inputData.data());
116
117 runner.setInput<T>(inputName, inputData);
118 }
119
120 template <typename T>
GetOutput(TosaReference::IModelRunner & runner,std::string outputName,uint32_t outputIndex) const121 void TosaRefPreCompiledWorkload::GetOutput(TosaReference::IModelRunner& runner,
122 std::string outputName,
123 uint32_t outputIndex) const
124 {
125 std::vector<T> actualOutputs = runner.getOutput<T>(outputName);
126
127 m_Data.m_Outputs[outputIndex]->CopyInFrom(actualOutputs.data());
128 }
129
/// Validation hook for pre-compiled TOSA Reference workloads.
///
/// It performs no checks: it always reports success, and the string pointer argument is ignored
/// (it may be null).
bool TosaRefPreCompiledWorkloadValidate(std::string* /*unused*/)
{
    constexpr bool alwaysSupported = true;
    return alwaysSupported;
}
134
135 } //namespace armnn
136