xref: /aosp_15_r20/external/armnn/delegate/test/SoftmaxTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "SoftmaxTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE ("Softmax_GpuAccTests")
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Softmax_Standard_Beta_GpuAcc_Test")
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
23*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput = {0.00994190481, 0.0445565246, 0.0734612942, 0.329230666, 0.542809606,
24*89c4ff92SAndroid Build Coastguard Worker                                          0.710742831, 0.158588171, 0.0961885825, 0.0214625746, 0.0130177103};
25*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_SOFTMAX, backends, 1, expectedOutput);
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Softmax_Different_Beta_GpuAcc_Test")
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
31*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput = {0.0946234912, 0.148399189, 0.172415257, 0.270400971, 0.314161092, 0.352414012,
32*89c4ff92SAndroid Build Coastguard Worker                                          0.224709094, 0.193408906, 0.123322964, 0.106145054};
33*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_SOFTMAX, backends, 0.3, expectedOutput);
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Log_Softmax_GpuAcc_Test")
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
40*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput =
41*89c4ff92SAndroid Build Coastguard Worker         {-4.61099672, -3.11099672, -2.61099672, -1.11099672, -0.610996664,
42*89c4ff92SAndroid Build Coastguard Worker          -0.341444582, -1.84144461, -2.34144449, -3.84144449, -4.34144449};
43*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_LOG_SOFTMAX, backends, 0, expectedOutput);
44*89c4ff92SAndroid Build Coastguard Worker }
45*89c4ff92SAndroid Build Coastguard Worker } // TEST_SUITE ("Softmax_GpuAccTests")
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE ("Softmax_CpuRefTests")
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Softmax_Standard_Beta_CpuRef_Test")
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
53*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput = {
54*89c4ff92SAndroid Build Coastguard Worker         0.00994190481, 0.0445565246, 0.0734612942, 0.329230666, 0.542809606,
55*89c4ff92SAndroid Build Coastguard Worker         0.710742831, 0.158588171, 0.0961885825, 0.0214625746, 0.0130177103};
56*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_SOFTMAX, backends, 1, expectedOutput);
57*89c4ff92SAndroid Build Coastguard Worker }
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Softmax_Different_Beta_CpuRef_Test")
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
62*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput = {
63*89c4ff92SAndroid Build Coastguard Worker         0.0946234912, 0.148399189, 0.172415257, 0.270400971, 0.314161092,
64*89c4ff92SAndroid Build Coastguard Worker         0.352414012, 0.224709094, 0.193408906, 0.123322964, 0.106145054};
65*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_SOFTMAX, backends, 0.3, expectedOutput);
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Log_Softmax_CpuRef_Test")
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput =
72*89c4ff92SAndroid Build Coastguard Worker         {-4.61099672, -3.11099672, -2.61099672, -1.11099672, -0.610996664,
73*89c4ff92SAndroid Build Coastguard Worker          -0.341444582, -1.84144461, -2.34144449, -3.84144449, -4.34144449};
74*89c4ff92SAndroid Build Coastguard Worker     SoftmaxTestCase(tflite::BuiltinOperator_LOG_SOFTMAX, backends, 0, expectedOutput);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker } // TEST_SUITE ("Softmax_CpuRefTests")
77*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate
78