1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefDebugWorkload.hpp"
7 #include "Debug.hpp"
8 #include "RefWorkloadUtils.hpp"
9
10 #include <ResolveType.hpp>
11
12 #include <cstring>
13
14 namespace armnn
15 {
16
17 template<armnn::DataType DataType>
Execute() const18 void RefDebugWorkload<DataType>::Execute() const
19 {
20 Execute(m_Data.m_Inputs);
21 }
22
23 template<armnn::DataType DataType>
ExecuteAsync(ExecutionData & executionData)24 void RefDebugWorkload<DataType>::ExecuteAsync(ExecutionData& executionData)
25 {
26 WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
27 Execute(workingMemDescriptor->m_Inputs);
28 }
29
30 template<armnn::DataType DataType>
Execute(std::vector<ITensorHandle * > inputs) const31 void RefDebugWorkload<DataType>::Execute(std::vector<ITensorHandle*> inputs) const
32 {
33 using T = ResolveType<DataType>;
34
35 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
36
37 const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
38
39 const T* inputData = GetInputTensorData<T>(0, m_Data);
40 T* outputData = GetOutputTensorData<T>(0, m_Data);
41
42 if (m_Callback)
43 {
44 m_Callback(m_Data.m_Guid, m_Data.m_SlotIndex, inputs[0]);
45 }
46 else
47 {
48 Debug(inputInfo, inputData, m_Data.m_Guid, m_Data.m_LayerName, m_Data.m_SlotIndex, m_Data.m_LayerOutputToFile);
49 }
50
51 std::memcpy(outputData, inputData, inputInfo.GetNumElements()*sizeof(T));
52 }
53
54 template<armnn::DataType DataType>
RegisterDebugCallback(const DebugCallbackFunction & func)55 void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFunction& func)
56 {
57 m_Callback = func;
58 }
59
60 template class RefDebugWorkload<DataType::BFloat16>;
61 template class RefDebugWorkload<DataType::Float16>;
62 template class RefDebugWorkload<DataType::Float32>;
63 template class RefDebugWorkload<DataType::QAsymmU8>;
64 template class RefDebugWorkload<DataType::QAsymmS8>;
65 template class RefDebugWorkload<DataType::QSymmS16>;
66 template class RefDebugWorkload<DataType::QSymmS8>;
67 template class RefDebugWorkload<DataType::Signed32>;
68
69 } // namespace armnn
70