xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeCast.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <armnnUtils/QuantizeHelper.hpp>
10 #include <ResolveType.hpp>
11 
12 #include <string>
13 
14 TEST_SUITE("Deserializer_Cast")
15 {
16 struct CastFixture : public ParserFlatbuffersSerializeFixture
17 {
CastFixtureCastFixture18     explicit CastFixture(const std::string& inputShape,
19                          const std::string& outputShape,
20                          const std::string& inputDataType,
21                          const std::string& outputDataType)
22     {
23         m_JsonString = R"(
24             {
25                 inputIds: [0],
26                 outputIds: [2],
27                 layers: [
28                     {
29                         layer_type: "InputLayer",
30                         layer: {
31                             base: {
32                                 layerBindingId: 0,
33                                 base: {
34                                     index: 0,
35                                     layerName: "inputTensor",
36                                     layerType: "Input",
37                                     inputSlots: [{
38                                         index: 0,
39                                         connection: { sourceLayerIndex:0, outputSlotIndex:0 },
40                                     }],
41                                     outputSlots: [{
42                                         index: 0,
43                                         tensorInfo: {
44                                             dimensions: )" + inputShape + R"(,
45                                             dataType: )" + inputDataType + R"(
46                                         }
47                                     }]
48                                 }
49                             }
50                         }
51                     },
52                     {
53                         layer_type: "CastLayer",
54                         layer: {
55                             base: {
56                                  index:1,
57                                  layerName: "CastLayer",
58                                  layerType: "Cast",
59                                  inputSlots: [{
60                                      index: 0,
61                                      connection: { sourceLayerIndex:0, outputSlotIndex:0 },
62                                  }],
63                                  outputSlots: [{
64                                      index: 0,
65                                      tensorInfo: {
66                                          dimensions: )" + outputShape + R"(,
67                                          dataType: )" + outputDataType + R"(
68                                      },
69                                  }],
70                             },
71                         },
72                     },
73                     {
74                         layer_type: "OutputLayer",
75                         layer: {
76                             base:{
77                                 layerBindingId: 2,
78                                 base: {
79                                     index: 2,
80                                     layerName: "outputTensor",
81                                     layerType: "Output",
82                                     inputSlots: [{
83                                         index: 0,
84                                         connection: { sourceLayerIndex:1, outputSlotIndex:0 },
85                                     }],
86                                     outputSlots: [{
87                                         index: 0,
88                                         tensorInfo: {
89                                             dimensions: )" + outputShape + R"(,
90                                             dataType: )" + outputDataType + R"(
91                                         },
92                                     }],
93                                 }
94                             }
95                         },
96                     }
97                 ]
98             }
99         )";
100         Setup();
101     }
102 };
103 
104 struct SimpleCastFixture : CastFixture
105 {
SimpleCastFixtureSimpleCastFixture106     SimpleCastFixture() : CastFixture("[ 1, 6 ]",
107                                       "[ 1, 6 ]",
108                                       "Signed32",
109                                       "Float32") {}
110 };
111 
112 TEST_CASE_FIXTURE(SimpleCastFixture, "SimpleCast")
113 {
114     RunTest<2, armnn::DataType::Signed32 , armnn::DataType::Float32>(
115         0,
116         {{"inputTensor",  { 0,   -1,   5,   -100,   200,   -255 }}},
117         {{"outputTensor", { 0.0f, -1.0f, 5.0f, -100.0f, 200.0f, -255.0f }}});
118 }
119 
120 }
121