1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <reference/workloads/ArgMinMax.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("RefArgMinMax") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ArgMinTest") 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); 15*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f}); 18*89c4ff92SAndroid Build Coastguard Worker std::vector<int64_t> outputValues(outputInfo.GetNumElements()); 19*89c4ff92SAndroid Build Coastguard Worker std::vector<int64_t> expectedValues({ 0, 1, 0 }); 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), 22*89c4ff92SAndroid Build Coastguard Worker outputValues.data(), 23*89c4ff92SAndroid Build Coastguard Worker inputInfo, 24*89c4ff92SAndroid Build Coastguard Worker outputInfo, 25*89c4ff92SAndroid Build Coastguard Worker armnn::ArgMinMaxFunction::Min, 26*89c4ff92SAndroid Build Coastguard Worker -2); 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker } 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ArgMaxTest") 33*89c4ff92SAndroid Build Coastguard Worker { 34*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); 35*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f }); 38*89c4ff92SAndroid Build Coastguard Worker std::vector<int64_t> outputValues(outputInfo.GetNumElements()); 39*89c4ff92SAndroid Build Coastguard Worker std::vector<int64_t> expectedValues({ 1, 0, 1 }); 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), 42*89c4ff92SAndroid Build Coastguard Worker outputValues.data(), 43*89c4ff92SAndroid Build Coastguard Worker inputInfo, 44*89c4ff92SAndroid Build Coastguard Worker outputInfo, 45*89c4ff92SAndroid Build Coastguard Worker armnn::ArgMinMaxFunction::Max, 46*89c4ff92SAndroid Build Coastguard Worker -2); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker } 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker }