xref: /aosp_15_r20/external/armnn/src/backends/cl/test/Fp16SupportTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/Descriptors.hpp>
7 #include <armnn/IRuntime.hpp>
8 #include <armnn/INetwork.hpp>
9 #include <Half.hpp>
10 
11 #include <Graph.hpp>
12 #include <Optimizer.hpp>
13 #include <armnn/backends/TensorHandle.hpp>
14 #include <armnn/utility/IgnoreUnused.hpp>
15 
16 #include <doctest/doctest.h>
17 
18 #include <set>
19 
20 using namespace armnn;
21 
22 TEST_SUITE("Fp16Support")
23 {
24 TEST_CASE("Fp16DataTypeSupport")
25 {
26     Graph graph;
27 
28     Layer* const inputLayer1 = graph.AddLayer<InputLayer>(1, "input1");
29     Layer* const inputLayer2 = graph.AddLayer<InputLayer>(2, "input2");
30 
31     Layer* const additionLayer = graph.AddLayer<ElementwiseBinaryLayer>(BinaryOperation::Add, "addition");
32     Layer* const outputLayer = graph.AddLayer<armnn::OutputLayer>(0, "output");
33 
34     TensorInfo fp16TensorInfo({1, 2, 3, 5}, armnn::DataType::Float16);
35     inputLayer1->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(0));
36     inputLayer2->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(1));
37     additionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
38 
39     inputLayer1->GetOutputSlot().SetTensorInfo(fp16TensorInfo);
40     inputLayer2->GetOutputSlot().SetTensorInfo(fp16TensorInfo);
41     additionLayer->GetOutputSlot().SetTensorInfo(fp16TensorInfo);
42 
43     CHECK(inputLayer1->GetOutputSlot(0).GetTensorInfo().GetDataType() == armnn::DataType::Float16);
44     CHECK(inputLayer2->GetOutputSlot(0).GetTensorInfo().GetDataType() == armnn::DataType::Float16);
45     CHECK(additionLayer->GetOutputSlot(0).GetTensorInfo().GetDataType() == armnn::DataType::Float16);
46 }
47 
48 TEST_CASE("Fp16AdditionTest")
49 {
50    using namespace half_float::literal;
51    // Create runtime in which test will run
52    IRuntime::CreationOptions options;
53    IRuntimePtr runtime(IRuntime::Create(options));
54 
55    // Builds up the structure of the network.
56    INetworkPtr net(INetwork::Create());
57 
58    IConnectableLayer* inputLayer1 = net->AddInputLayer(0);
59    IConnectableLayer* inputLayer2 = net->AddInputLayer(1);
60    IConnectableLayer* additionLayer = net->AddElementwiseBinaryLayer(BinaryOperation::Add);
61    IConnectableLayer* outputLayer = net->AddOutputLayer(0);
62 
63    inputLayer1->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(0));
64    inputLayer2->GetOutputSlot(0).Connect(additionLayer->GetInputSlot(1));
65    additionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
66 
67    //change to float16
68    TensorInfo fp16TensorInfo(TensorShape({4}), DataType::Float16);
69    inputLayer1->GetOutputSlot(0).SetTensorInfo(fp16TensorInfo);
70    inputLayer2->GetOutputSlot(0).SetTensorInfo(fp16TensorInfo);
71    additionLayer->GetOutputSlot(0).SetTensorInfo(fp16TensorInfo);
72 
73    // optimize the network
74    std::vector<BackendId> backends = {Compute::GpuAcc};
75    IOptimizedNetworkPtr optNet = Optimize(*net, backends, runtime->GetDeviceSpec());
76 
77    // Loads it into the runtime.
78    NetworkId netId;
79    runtime->LoadNetwork(netId, std::move(optNet));
80 
81    std::vector<Half> input1Data
82    {
83        1.0_h, 2.0_h, 3.0_h, 4.0_h
84    };
85 
86    std::vector<Half> input2Data
87    {
88        100.0_h, 200.0_h, 300.0_h, 400.0_h
89    };
90 
91    TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(netId, 0);
92    inputTensorInfo.SetConstant(true);
93    InputTensors inputTensors
94    {
95        {0,ConstTensor(inputTensorInfo, input1Data.data())},
96        {1,ConstTensor(inputTensorInfo, input2Data.data())}
97    };
98 
99    std::vector<Half> outputData(input1Data.size());
100    OutputTensors outputTensors
101    {
102        {0,Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
103    };
104 
105    // Does the inference.
106    runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
107 
108    // Checks the results.
109    CHECK(outputData == std::vector<Half>({ 101.0_h, 202.0_h, 303.0_h, 404.0_h})); // Add
110 }
111 
112 }
113