1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #if (defined(__aarch64__)) || (defined(__x86_64__)) // disable test failing on FireFly/Armv7 7 8 #include "ClWorkloadFactoryHelper.hpp" 9 10 #include <armnnTestUtils/TensorHelpers.hpp> 11 12 #include <armnn/backends/TensorHandle.hpp> 13 #include <armnn/backends/WorkloadFactory.hpp> 14 15 #include <cl/ClContextControl.hpp> 16 #include <cl/ClWorkloadFactory.hpp> 17 #include <cl/OpenClTimer.hpp> 18 19 #include <armnnTestUtils/TensorCopyUtils.hpp> 20 #include <armnnTestUtils/WorkloadTestUtils.hpp> 21 22 #include <arm_compute/runtime/CL/CLScheduler.h> 23 24 #include <doctest/doctest.h> 25 26 #include <iostream> 27 28 using namespace armnn; 29 30 struct OpenClFixture 31 { 32 // Initialising ClContextControl to ensure OpenCL is loaded correctly for each test case. 33 // NOTE: Profiling needs to be enabled in ClContextControl to be able to obtain execution 34 // times from OpenClTimer. OpenClFixtureOpenClFixture35 OpenClFixture() : m_ClContextControl(nullptr, nullptr, true) {} ~OpenClFixtureOpenClFixture36 ~OpenClFixture() {} 37 38 ClContextControl m_ClContextControl; 39 }; 40 41 TEST_CASE_FIXTURE(OpenClFixture, "OpenClTimerBatchNorm") 42 { 43 //using FactoryType = ClWorkloadFactory; 44 45 auto memoryManager = ClWorkloadFactoryHelper::GetMemoryManager(); 46 ClWorkloadFactory workloadFactory = ClWorkloadFactoryHelper::GetFactory(memoryManager); 47 48 const unsigned int width = 2; 49 const unsigned int height = 3; 50 const unsigned int channels = 2; 51 const unsigned int num = 1; 52 53 TensorInfo inputTensorInfo( {num, channels, height, width}, DataType::Float32); 54 TensorInfo outputTensorInfo({num, channels, height, width}, DataType::Float32); 55 TensorInfo tensorInfo({channels}, DataType::Float32); 56 57 std::vector<float> input = 58 { 59 1.f, 4.f, 60 4.f, 2.f, 61 1.f, 6.f, 62 63 1.f, 1.f, 64 4.f, 1.f, 65 -2.f, 4.f 66 }; 67 68 // these values are per-channel of the input 69 std::vector<float> mean = { 3.f, -2.f }; 70 std::vector<float> variance = { 4.f, 9.f }; 71 std::vector<float> beta = { 3.f, 2.f }; 72 std::vector<float> gamma = { 2.f, 1.f }; 73 74 ARMNN_NO_DEPRECATE_WARN_BEGIN 75 std::unique_ptr<ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo); 76 std::unique_ptr<ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo); 77 ARMNN_NO_DEPRECATE_WARN_END 78 79 BatchNormalizationQueueDescriptor data; 80 WorkloadInfo info; 81 ScopedTensorHandle meanTensor(tensorInfo); 82 ScopedTensorHandle varianceTensor(tensorInfo); 83 ScopedTensorHandle betaTensor(tensorInfo); 84 ScopedTensorHandle gammaTensor(tensorInfo); 85 86 AllocateAndCopyDataToITensorHandle(&meanTensor, mean.data()); 87 AllocateAndCopyDataToITensorHandle(&varianceTensor, variance.data()); 88 AllocateAndCopyDataToITensorHandle(&betaTensor, beta.data()); 89 AllocateAndCopyDataToITensorHandle(&gammaTensor, gamma.data()); 90 91 AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get()); 92 AddOutputToWorkload(data, info, outputTensorInfo, outputHandle.get()); 93 data.m_Mean = &meanTensor; 94 data.m_Variance = &varianceTensor; 95 data.m_Beta = &betaTensor; 96 data.m_Gamma = &gammaTensor; 97 data.m_Parameters.m_Eps = 0.0f; 98 99 // for each channel: 100 // substract mean, divide by standard deviation (with an epsilon to avoid div by 0) 101 // multiply by gamma and add beta 102 std::unique_ptr<IWorkload> workload = workloadFactory.CreateWorkload(LayerType::BatchNormalization, data, info); 103 104 inputHandle->Allocate(); 105 outputHandle->Allocate(); 106 107 CopyDataToITensorHandle(inputHandle.get(), input.data()); 108 109 OpenClTimer openClTimer; 110 111 CHECK_EQ(openClTimer.GetName(), "OpenClKernelTimer"); 112 113 //Start the timer 114 openClTimer.Start(); 115 116 //Execute the workload 117 workload->Execute(); 118 119 //Stop the timer 120 openClTimer.Stop(); 121 122 CHECK_EQ(openClTimer.GetMeasurements().size(), 1); 123 124 CHECK_EQ(openClTimer.GetMeasurements().front().m_Name, 125 "OpenClKernelTimer/0: batchnormalization_layer_nchw GWS[1,3,2]"); 126 127 CHECK(openClTimer.GetMeasurements().front().m_Value > 0); 128 129 } 130 131 #endif //aarch64 or x86_64 132