xref: /aosp_15_r20/external/armnn/src/backends/cl/test/OpenClTimerTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "ClWorkloadFactoryHelper.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorHelpers.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClContextControl.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClWorkloadFactory.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <cl/OpenClTimer.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorCopyUtils.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/WorkloadTestUtils.hpp>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLScheduler.h>
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker struct OpenClFixture
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker     // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case.
33*89c4ff92SAndroid Build Coastguard Worker     // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution
34*89c4ff92SAndroid Build Coastguard Worker     // times from OpenClTimer.
OpenClFixtureOpenClFixture35*89c4ff92SAndroid Build Coastguard Worker     OpenClFixture() : m_ClContextControl(nullptr, nullptr, true) {}
~OpenClFixtureOpenClFixture36*89c4ff92SAndroid Build Coastguard Worker     ~OpenClFixture() {}
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     ClContextControl m_ClContextControl;
39*89c4ff92SAndroid Build Coastguard Worker };
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(OpenClFixture, "OpenClTimerBatchNorm")
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker //using FactoryType = ClWorkloadFactory;
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager();
46*89c4ff92SAndroid Build Coastguard Worker     ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     const unsigned int width    = 2;
49*89c4ff92SAndroid Build Coastguard Worker     const unsigned int height   = 3;
50*89c4ff92SAndroid Build Coastguard Worker     const unsigned int channels = 2;
51*89c4ff92SAndroid Build Coastguard Worker     const unsigned int num      = 1;
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo( {num, channels, height, width}, DataType::Float32);
54*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo({num, channels, height, width}, DataType::Float32);
55*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensorInfo({channels}, DataType::Float32);
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> input =
58*89c4ff92SAndroid Build Coastguard Worker     {
59*89c4ff92SAndroid Build Coastguard Worker          1.f, 4.f,
60*89c4ff92SAndroid Build Coastguard Worker          4.f, 2.f,
61*89c4ff92SAndroid Build Coastguard Worker          1.f, 6.f,
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker          1.f, 1.f,
64*89c4ff92SAndroid Build Coastguard Worker          4.f, 1.f,
65*89c4ff92SAndroid Build Coastguard Worker         -2.f, 4.f
66*89c4ff92SAndroid Build Coastguard Worker     };
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     // these values are per-channel of the input
69*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> mean     = { 3.f, -2.f };
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> variance = { 4.f,  9.f };
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> beta     = { 3.f,  2.f };
72*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> gamma    = { 2.f,  1.f };
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
75*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
76*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
77*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker     BatchNormalizationQueueDescriptor data;
80*89c4ff92SAndroid Build Coastguard Worker     WorkloadInfo info;
81*89c4ff92SAndroid Build Coastguard Worker     ScopedTensorHandle meanTensor(tensorInfo);
82*89c4ff92SAndroid Build Coastguard Worker     ScopedTensorHandle varianceTensor(tensorInfo);
83*89c4ff92SAndroid Build Coastguard Worker     ScopedTensorHandle betaTensor(tensorInfo);
84*89c4ff92SAndroid Build Coastguard Worker     ScopedTensorHandle gammaTensor(tensorInfo);
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     AllocateAndCopyDataToITensorHandle(&meanTensor, mean.data());
87*89c4ff92SAndroid Build Coastguard Worker     AllocateAndCopyDataToITensorHandle(&varianceTensor, variance.data());
88*89c4ff92SAndroid Build Coastguard Worker     AllocateAndCopyDataToITensorHandle(&betaTensor, beta.data());
89*89c4ff92SAndroid Build Coastguard Worker     AllocateAndCopyDataToITensorHandle(&gammaTensor, gamma.data());
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());
92*89c4ff92SAndroid Build Coastguard Worker     AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get());
93*89c4ff92SAndroid Build Coastguard Worker     data.m_Mean             = &meanTensor;
94*89c4ff92SAndroid Build Coastguard Worker     data.m_Variance         = &varianceTensor;
95*89c4ff92SAndroid Build Coastguard Worker     data.m_Beta             = &betaTensor;
96*89c4ff92SAndroid Build Coastguard Worker     data.m_Gamma            = &gammaTensor;
97*89c4ff92SAndroid Build Coastguard Worker     data.m_Parameters.m_Eps = 0.0f;
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     // for each channel:
100*89c4ff92SAndroid Build Coastguard Worker     // substract mean, divide by standard deviation (with an epsilon to avoid div by 0)
101*89c4ff92SAndroid Build Coastguard Worker     // multiply by gamma and add beta
102*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IWorkload> workload = workloadFactory.CreateWorkload(LayerType::BatchNormalization, data, info);
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     inputHandle->Allocate();
105*89c4ff92SAndroid Build Coastguard Worker     outputHandle->Allocate();
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     CopyDataToITensorHandle(inputHandle.get(), input.data());
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     OpenClTimer openClTimer;
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(openClTimer.GetName(), "OpenClKernelTimer");
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     //Start the timer
114*89c4ff92SAndroid Build Coastguard Worker     openClTimer.Start();
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     //Execute the workload
117*89c4ff92SAndroid Build Coastguard Worker     workload->Execute();
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     //Stop the timer
120*89c4ff92SAndroid Build Coastguard Worker     openClTimer.Stop();
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(openClTimer.GetMeasurements().size(), 1);
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(openClTimer.GetMeasurements().front().m_Name,
125*89c4ff92SAndroid Build Coastguard Worker                       "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]");
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     CHECK(openClTimer.GetMeasurements().front().m_Value > 0);
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker #endif //aarch64 or x86_64
132