1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "ClWorkloadFactoryHelper.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorHelpers.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClContextControl.hpp> 16*89c4ff92SAndroid Build Coastguard Worker #include <cl/ClWorkloadFactory.hpp> 17*89c4ff92SAndroid Build Coastguard Worker #include <cl/OpenClTimer.hpp> 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorCopyUtils.hpp> 20*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/WorkloadTestUtils.hpp> 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLScheduler.h> 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker #include <iostream> 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker struct OpenClFixture 31*89c4ff92SAndroid Build Coastguard Worker { 32*89c4ff92SAndroid Build Coastguard Worker // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case. 33*89c4ff92SAndroid Build Coastguard Worker // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution 34*89c4ff92SAndroid Build Coastguard Worker // times from OpenClTimer. OpenClFixtureOpenClFixture35*89c4ff92SAndroid Build Coastguard Worker OpenClFixture() : m_ClContextControl(nullptr, nullptr, true) {} ~OpenClFixtureOpenClFixture36*89c4ff92SAndroid Build Coastguard Worker ~OpenClFixture() {} 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker ClContextControl m_ClContextControl; 39*89c4ff92SAndroid Build Coastguard Worker }; 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(OpenClFixture, "OpenClTimerBatchNorm") 42*89c4ff92SAndroid Build Coastguard Worker { 43*89c4ff92SAndroid Build Coastguard Worker //using FactoryType = ClWorkloadFactory; 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager(); 46*89c4ff92SAndroid Build Coastguard Worker ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker const unsigned int width = 2; 49*89c4ff92SAndroid Build Coastguard Worker const unsigned int height = 3; 50*89c4ff92SAndroid Build Coastguard Worker const unsigned int channels = 2; 51*89c4ff92SAndroid Build Coastguard Worker const unsigned int num = 1; 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo( {num, channels, height, width}, DataType::Float32); 54*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo({num, channels, height, width}, DataType::Float32); 55*89c4ff92SAndroid Build Coastguard Worker TensorInfo tensorInfo({channels}, DataType::Float32); 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker std::vector<float> input = 58*89c4ff92SAndroid Build Coastguard Worker { 59*89c4ff92SAndroid Build Coastguard Worker 1.f, 4.f, 60*89c4ff92SAndroid Build Coastguard Worker 4.f, 2.f, 61*89c4ff92SAndroid Build Coastguard Worker 1.f, 6.f, 62*89c4ff92SAndroid Build Coastguard Worker 63*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 64*89c4ff92SAndroid Build Coastguard Worker 4.f, 1.f, 65*89c4ff92SAndroid Build Coastguard Worker -2.f, 4.f 66*89c4ff92SAndroid Build Coastguard Worker }; 67*89c4ff92SAndroid Build Coastguard Worker 68*89c4ff92SAndroid Build Coastguard Worker // these values are per-channel of the input 69*89c4ff92SAndroid Build Coastguard Worker std::vector<float> mean = { 3.f, -2.f }; 70*89c4ff92SAndroid Build Coastguard Worker std::vector<float> variance = { 4.f, 9.f }; 71*89c4ff92SAndroid Build Coastguard Worker std::vector<float> beta = { 3.f, 2.f }; 72*89c4ff92SAndroid Build Coastguard Worker std::vector<float> gamma = { 2.f, 1.f }; 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN 75*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); 76*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); 77*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationQueueDescriptor data; 80*89c4ff92SAndroid Build Coastguard Worker WorkloadInfo info; 81*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle meanTensor(tensorInfo); 82*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle varianceTensor(tensorInfo); 83*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle betaTensor(tensorInfo); 84*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle gammaTensor(tensorInfo); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker AllocateAndCopyDataToITensorHandle(&meanTensor, mean.data()); 87*89c4ff92SAndroid Build Coastguard Worker AllocateAndCopyDataToITensorHandle(&varianceTensor, variance.data()); 88*89c4ff92SAndroid Build Coastguard Worker AllocateAndCopyDataToITensorHandle(&betaTensor, beta.data()); 89*89c4ff92SAndroid Build Coastguard Worker AllocateAndCopyDataToITensorHandle(&gammaTensor, gamma.data()); 90*89c4ff92SAndroid Build Coastguard Worker 91*89c4ff92SAndroid Build Coastguard Worker AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); 92*89c4ff92SAndroid Build Coastguard Worker AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); 93*89c4ff92SAndroid Build Coastguard Worker data.m_Mean = &meanTensor; 94*89c4ff92SAndroid Build Coastguard Worker data.m_Variance = &varianceTensor; 95*89c4ff92SAndroid Build Coastguard Worker data.m_Beta = &betaTensor; 96*89c4ff92SAndroid Build Coastguard Worker data.m_Gamma = &gammaTensor; 97*89c4ff92SAndroid Build Coastguard Worker data.m_Parameters.m_Eps = 0.0f; 98*89c4ff92SAndroid Build Coastguard Worker 99*89c4ff92SAndroid Build Coastguard Worker // for each channel: 100*89c4ff92SAndroid Build Coastguard Worker // substract mean, divide by standard deviation (with an epsilon to avoid div by 0) 101*89c4ff92SAndroid Build Coastguard Worker // multiply by gamma and add beta 102*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> workload = workloadFactory.CreateWorkload(LayerType::BatchNormalization, data, info); 103*89c4ff92SAndroid Build Coastguard Worker 104*89c4ff92SAndroid Build Coastguard Worker inputHandle->Allocate(); 105*89c4ff92SAndroid Build Coastguard Worker outputHandle->Allocate(); 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker CopyDataToITensorHandle(inputHandle.get(), input.data()); 108*89c4ff92SAndroid Build Coastguard Worker 109*89c4ff92SAndroid Build Coastguard Worker OpenClTimer openClTimer; 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(openClTimer.GetName(), "OpenClKernelTimer"); 112*89c4ff92SAndroid Build Coastguard Worker 113*89c4ff92SAndroid Build Coastguard Worker //Start the timer 114*89c4ff92SAndroid Build Coastguard Worker openClTimer.Start(); 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker //Execute the workload 117*89c4ff92SAndroid Build Coastguard Worker workload->Execute(); 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker //Stop the timer 120*89c4ff92SAndroid Build Coastguard Worker openClTimer.Stop(); 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(openClTimer.GetMeasurements().size(), 1); 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(openClTimer.GetMeasurements().front().m_Name, 125*89c4ff92SAndroid Build Coastguard Worker "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]"); 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker CHECK(openClTimer.GetMeasurements().front().m_Value > 0); 128*89c4ff92SAndroid Build Coastguard Worker 129*89c4ff92SAndroid Build Coastguard Worker } 130*89c4ff92SAndroid Build Coastguard Worker 131*89c4ff92SAndroid Build Coastguard Worker #endif //aarch64 or x86_64 132