xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/LoadScopeDynamicTensor.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_LoadScopeDynamicTensor")
10 {
11 
12 struct DynamicBatchTensorFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
DynamicBatchTensorFixtureDynamicBatchTensorFixture14     DynamicBatchTensorFixture()
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: 1
29                             shape {
30                               dim {
31                                 dim_value: 0
32                               }
33                               dim {
34                                 dim_value: 1
35                               }
36                               dim {
37                                 dim_value: 3
38                               }
39                               dim {
40                                 dim_value: 3
41                               }
42                             }
43                           }
44                         }
45                       }
46                       input {
47                         name: "Weight"
48                         type {
49                           tensor_type {
50                             elem_type: 1
51                             shape {
52                               dim {
53                                 dim_value: 1
54                               }
55                               dim {
56                                 dim_value: 1
57                               }
58                               dim {
59                                 dim_value: 3
60                               }
61                               dim {
62                                 dim_value: 3
63                               }
64                             }
65                           }
66                         }
67                       }
68                       initializer {
69                           dims: 1
70                           dims: 1
71                           dims: 3
72                           dims: 3
73                           data_type: 1
74                           float_data: 2
75                           float_data: 1
76                           float_data: 0
77                           float_data: 6
78                           float_data: 2
79                           float_data: 1
80                           float_data: 4
81                           float_data: 1
82                           float_data: 2
83                           name: "Weight"
84                         }
85                       node {
86                          input: "Input"
87                          input: "Weight"
88                          output: "Output"
89                          name: "Convolution"
90                          op_type: "Conv"
91                          attribute {
92                            name: "kernel_shape"
93                            ints: 3
94                            ints: 3
95                            type: INTS
96                          }
97                          attribute {
98                            name: "strides"
99                            ints: 1
100                            ints: 1
101                            type: INTS
102                          }
103                          attribute {
104                            name: "auto_pad"
105                            s: "VALID"
106                            type: STRING
107                          }
108                          attribute {
109                            name: "group"
110                            i: 1
111                            type: INT
112                          }
113                          attribute {
114                            name: "dilations"
115                            ints: 1
116                            ints: 1
117                            type: INTS
118                          }
119                          doc_string: ""
120                          domain: ""
121                        }
122                       output {
123                           name: "Output"
124                           type {
125                              tensor_type {
126                                elem_type: 1
127                                shape {
128                                    dim {
129                                        dim_value: 0
130                                    }
131                                    dim {
132                                        dim_value: 1
133                                    }
134                                    dim {
135                                        dim_value: 1
136                                    }
137                                    dim {
138                                        dim_value: 1
139                                    }
140                                }
141                             }
142                         }
143                         }
144                     }
145                    opset_import {
146                       version: 7
147                     })";
148     }
149 };
150 
151 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "DynamicBatchTensorTest")
152 {
153     Setup({{"Input", armnn::TensorShape({1, 1, 3, 3})}});
154     RunTest<4>({{"Input", {1.0, 2.0, 3.0,
155                            4.0, 5.0, 6.0,
156                            7.0, 8.0, 9.0}}},
157               {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
158                            4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
159                            7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
160 }
161 
162 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "TensorShapeNotSpecifiedTest")
163 {
164     CHECK_THROWS_AS(Setup(), armnn::ParseException);
165 }
166 
167 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectInputNameTest")
168 {
169     CHECK_THROWS_AS(Setup({{"Incorrect", armnn::TensorShape({1, 1, 3, 3})}}), armnn::ParseException);
170 }
171 
172 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectBatchTensorTest")
173 {
174     Setup({{"Input", armnn::TensorShape({2, 1, 3, 3}) }});
175     CHECK_THROWS_AS(RunTest<4>({{"Input", { 1.0, 2.0, 3.0,
176                                             4.0, 5.0, 6.0,
177                                             7.0, 8.0, 9.0 }}},
178                                {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
179                                             4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
180                                             7.0 * 4 + 8.0 * 1 + 9.0 * 2 }}}), armnn::Exception);
181 
182 }
183 
184 }
185