//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_LoadScopeDynamicTensor")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker struct DynamicBatchTensorFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13*89c4ff92SAndroid Build Coastguard Worker {
DynamicBatchTensorFixtureDynamicBatchTensorFixture14*89c4ff92SAndroid Build Coastguard Worker     DynamicBatchTensorFixture()
15*89c4ff92SAndroid Build Coastguard Worker     {
16*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
17*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
18*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
19*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
20*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
21*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
22*89c4ff92SAndroid Build Coastguard Worker                    graph {
23*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
24*89c4ff92SAndroid Build Coastguard Worker                      input {
25*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
26*89c4ff92SAndroid Build Coastguard Worker                         type {
27*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
28*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
29*89c4ff92SAndroid Build Coastguard Worker                             shape {
30*89c4ff92SAndroid Build Coastguard Worker                               dim {
31*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 0
32*89c4ff92SAndroid Build Coastguard Worker                               }
33*89c4ff92SAndroid Build Coastguard Worker                               dim {
34*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
35*89c4ff92SAndroid Build Coastguard Worker                               }
36*89c4ff92SAndroid Build Coastguard Worker                               dim {
37*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
38*89c4ff92SAndroid Build Coastguard Worker                               }
39*89c4ff92SAndroid Build Coastguard Worker                               dim {
40*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
41*89c4ff92SAndroid Build Coastguard Worker                               }
42*89c4ff92SAndroid Build Coastguard Worker                             }
43*89c4ff92SAndroid Build Coastguard Worker                           }
44*89c4ff92SAndroid Build Coastguard Worker                         }
45*89c4ff92SAndroid Build Coastguard Worker                       }
46*89c4ff92SAndroid Build Coastguard Worker                       input {
47*89c4ff92SAndroid Build Coastguard Worker                         name: "Weight"
48*89c4ff92SAndroid Build Coastguard Worker                         type {
49*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
50*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
51*89c4ff92SAndroid Build Coastguard Worker                             shape {
52*89c4ff92SAndroid Build Coastguard Worker                               dim {
53*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
54*89c4ff92SAndroid Build Coastguard Worker                               }
55*89c4ff92SAndroid Build Coastguard Worker                               dim {
56*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
57*89c4ff92SAndroid Build Coastguard Worker                               }
58*89c4ff92SAndroid Build Coastguard Worker                               dim {
59*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
60*89c4ff92SAndroid Build Coastguard Worker                               }
61*89c4ff92SAndroid Build Coastguard Worker                               dim {
62*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
63*89c4ff92SAndroid Build Coastguard Worker                               }
64*89c4ff92SAndroid Build Coastguard Worker                             }
65*89c4ff92SAndroid Build Coastguard Worker                           }
66*89c4ff92SAndroid Build Coastguard Worker                         }
67*89c4ff92SAndroid Build Coastguard Worker                       }
68*89c4ff92SAndroid Build Coastguard Worker                       initializer {
69*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
70*89c4ff92SAndroid Build Coastguard Worker                           dims: 1
71*89c4ff92SAndroid Build Coastguard Worker                           dims: 3
72*89c4ff92SAndroid Build Coastguard Worker                           dims: 3
73*89c4ff92SAndroid Build Coastguard Worker                           data_type: 1
74*89c4ff92SAndroid Build Coastguard Worker                           float_data: 2
75*89c4ff92SAndroid Build Coastguard Worker                           float_data: 1
76*89c4ff92SAndroid Build Coastguard Worker                           float_data: 0
77*89c4ff92SAndroid Build Coastguard Worker                           float_data: 6
78*89c4ff92SAndroid Build Coastguard Worker                           float_data: 2
79*89c4ff92SAndroid Build Coastguard Worker                           float_data: 1
80*89c4ff92SAndroid Build Coastguard Worker                           float_data: 4
81*89c4ff92SAndroid Build Coastguard Worker                           float_data: 1
82*89c4ff92SAndroid Build Coastguard Worker                           float_data: 2
83*89c4ff92SAndroid Build Coastguard Worker                           name: "Weight"
84*89c4ff92SAndroid Build Coastguard Worker                         }
85*89c4ff92SAndroid Build Coastguard Worker                       node {
86*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
87*89c4ff92SAndroid Build Coastguard Worker                          input: "Weight"
88*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
89*89c4ff92SAndroid Build Coastguard Worker                          name: "Convolution"
90*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Conv"
91*89c4ff92SAndroid Build Coastguard Worker                          attribute {
92*89c4ff92SAndroid Build Coastguard Worker                            name: "kernel_shape"
93*89c4ff92SAndroid Build Coastguard Worker                            ints: 3
94*89c4ff92SAndroid Build Coastguard Worker                            ints: 3
95*89c4ff92SAndroid Build Coastguard Worker                            type: INTS
96*89c4ff92SAndroid Build Coastguard Worker                          }
97*89c4ff92SAndroid Build Coastguard Worker                          attribute {
98*89c4ff92SAndroid Build Coastguard Worker                            name: "strides"
99*89c4ff92SAndroid Build Coastguard Worker                            ints: 1
100*89c4ff92SAndroid Build Coastguard Worker                            ints: 1
101*89c4ff92SAndroid Build Coastguard Worker                            type: INTS
102*89c4ff92SAndroid Build Coastguard Worker                          }
103*89c4ff92SAndroid Build Coastguard Worker                          attribute {
104*89c4ff92SAndroid Build Coastguard Worker                            name: "auto_pad"
105*89c4ff92SAndroid Build Coastguard Worker                            s: "VALID"
106*89c4ff92SAndroid Build Coastguard Worker                            type: STRING
107*89c4ff92SAndroid Build Coastguard Worker                          }
108*89c4ff92SAndroid Build Coastguard Worker                          attribute {
109*89c4ff92SAndroid Build Coastguard Worker                            name: "group"
110*89c4ff92SAndroid Build Coastguard Worker                            i: 1
111*89c4ff92SAndroid Build Coastguard Worker                            type: INT
112*89c4ff92SAndroid Build Coastguard Worker                          }
113*89c4ff92SAndroid Build Coastguard Worker                          attribute {
114*89c4ff92SAndroid Build Coastguard Worker                            name: "dilations"
115*89c4ff92SAndroid Build Coastguard Worker                            ints: 1
116*89c4ff92SAndroid Build Coastguard Worker                            ints: 1
117*89c4ff92SAndroid Build Coastguard Worker                            type: INTS
118*89c4ff92SAndroid Build Coastguard Worker                          }
119*89c4ff92SAndroid Build Coastguard Worker                          doc_string: ""
120*89c4ff92SAndroid Build Coastguard Worker                          domain: ""
121*89c4ff92SAndroid Build Coastguard Worker                        }
122*89c4ff92SAndroid Build Coastguard Worker                       output {
123*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
124*89c4ff92SAndroid Build Coastguard Worker                           type {
125*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
126*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
127*89c4ff92SAndroid Build Coastguard Worker                                shape {
128*89c4ff92SAndroid Build Coastguard Worker                                    dim {
129*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 0
130*89c4ff92SAndroid Build Coastguard Worker                                    }
131*89c4ff92SAndroid Build Coastguard Worker                                    dim {
132*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
133*89c4ff92SAndroid Build Coastguard Worker                                    }
134*89c4ff92SAndroid Build Coastguard Worker                                    dim {
135*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
136*89c4ff92SAndroid Build Coastguard Worker                                    }
137*89c4ff92SAndroid Build Coastguard Worker                                    dim {
138*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 1
139*89c4ff92SAndroid Build Coastguard Worker                                    }
140*89c4ff92SAndroid Build Coastguard Worker                                }
141*89c4ff92SAndroid Build Coastguard Worker                             }
142*89c4ff92SAndroid Build Coastguard Worker                         }
143*89c4ff92SAndroid Build Coastguard Worker                         }
144*89c4ff92SAndroid Build Coastguard Worker                     }
145*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
146*89c4ff92SAndroid Build Coastguard Worker                       version: 7
147*89c4ff92SAndroid Build Coastguard Worker                     })";
148*89c4ff92SAndroid Build Coastguard Worker     }
149*89c4ff92SAndroid Build Coastguard Worker };
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "DynamicBatchTensorTest")
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker     Setup({{"Input", armnn::TensorShape({1, 1, 3, 3})}});
154*89c4ff92SAndroid Build Coastguard Worker     RunTest<4>({{"Input", {1.0, 2.0, 3.0,
155*89c4ff92SAndroid Build Coastguard Worker                            4.0, 5.0, 6.0,
156*89c4ff92SAndroid Build Coastguard Worker                            7.0, 8.0, 9.0}}},
157*89c4ff92SAndroid Build Coastguard Worker               {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
158*89c4ff92SAndroid Build Coastguard Worker                            4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
159*89c4ff92SAndroid Build Coastguard Worker                            7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "TensorShapeNotSpecifiedTest")
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectInputNameTest")
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup({{"Incorrect", armnn::TensorShape({1, 1, 3, 3})}}), armnn::ParseException);
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectBatchTensorTest")
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker     Setup({{"Input", armnn::TensorShape({2, 1, 3, 3}) }});
175*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(RunTest<4>({{"Input", { 1.0, 2.0, 3.0,
176*89c4ff92SAndroid Build Coastguard Worker                                             4.0, 5.0, 6.0,
177*89c4ff92SAndroid Build Coastguard Worker                                             7.0, 8.0, 9.0 }}},
178*89c4ff92SAndroid Build Coastguard Worker                                {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
179*89c4ff92SAndroid Build Coastguard Worker                                             4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
180*89c4ff92SAndroid Build Coastguard Worker                                             7.0 * 4 + 8.0 * 1 + 9.0 * 2 }}}), armnn::Exception);
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker }
185