xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefDebugWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefDebugWorkload.hpp"
7 #include "Debug.hpp"
8 #include "RefWorkloadUtils.hpp"
9 
10 #include <ResolveType.hpp>
11 
12 #include <cstring>
13 
14 namespace armnn
15 {
16 
17 template<armnn::DataType DataType>
Execute() const18 void RefDebugWorkload<DataType>::Execute() const
19 {
20     Execute(m_Data.m_Inputs);
21 }
22 
23 template<armnn::DataType DataType>
ExecuteAsync(ExecutionData & executionData)24 void RefDebugWorkload<DataType>::ExecuteAsync(ExecutionData& executionData)
25 {
26     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
27     Execute(workingMemDescriptor->m_Inputs);
28 }
29 
30 template<armnn::DataType DataType>
Execute(std::vector<ITensorHandle * > inputs) const31 void RefDebugWorkload<DataType>::Execute(std::vector<ITensorHandle*> inputs) const
32 {
33     using T = ResolveType<DataType>;
34 
35     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute");
36 
37     const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
38 
39     const T* inputData = GetInputTensorData<T>(0, m_Data);
40     T* outputData = GetOutputTensorData<T>(0, m_Data);
41 
42     if (m_Callback)
43     {
44         m_Callback(m_Data.m_Guid, m_Data.m_SlotIndex, inputs[0]);
45     }
46     else
47     {
48         Debug(inputInfo, inputData, m_Data.m_Guid, m_Data.m_LayerName, m_Data.m_SlotIndex, m_Data.m_LayerOutputToFile);
49     }
50 
51     std::memcpy(outputData, inputData, inputInfo.GetNumElements()*sizeof(T));
52 }
53 
54 template<armnn::DataType DataType>
RegisterDebugCallback(const DebugCallbackFunction & func)55 void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFunction& func)
56 {
57     m_Callback = func;
58 }
59 
60 template class RefDebugWorkload<DataType::BFloat16>;
61 template class RefDebugWorkload<DataType::Float16>;
62 template class RefDebugWorkload<DataType::Float32>;
63 template class RefDebugWorkload<DataType::QAsymmU8>;
64 template class RefDebugWorkload<DataType::QAsymmS8>;
65 template class RefDebugWorkload<DataType::QSymmS16>;
66 template class RefDebugWorkload<DataType::QSymmS8>;
67 template class RefDebugWorkload<DataType::Signed32>;
68 
69 } // namespace armnn
70