xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/DepthConv.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_DepthConv")
10 {
11 struct SimpleDepthConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
SimpleDepthConv2DFixtureSimpleDepthConv2DFixture13     SimpleDepthConv2DFixture()
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input"
25                         type {
26                           tensor_type {
27                             elem_type: 1
28                             shape {
29                               dim {
30                                 dim_value: 1
31                               }
32                               dim {
33                                 dim_value: 3
34                               }
35                               dim {
36                                 dim_value: 2
37                               }
38                               dim {
39                                 dim_value: 2
40                               }
41                             }
42                           }
43                         }
44                       }
45                       input {
46                         name: "Weight"
47                         type {
48                           tensor_type {
49                             elem_type: 1
50                             shape {
51                               dim {
52                                 dim_value: 3
53                               }
54                               dim {
55                                 dim_value: 1
56                               }
57                               dim {
58                                 dim_value: 2
59                               }
60                               dim {
61                                 dim_value: 2
62                               }
63                             }
64                           }
65                         }
66                       }
67                       initializer {
68                           dims: 3
69                           dims: 1
70                           dims: 2
71                           dims: 2
72                           data_type: 1
73                           float_data: 1
74                           float_data: 1
75                           float_data: 1
76                           float_data: 1
77                           float_data: 2
78                           float_data: 2
79                           float_data: 2
80                           float_data: 2
81                           float_data: 3
82                           float_data: 3
83                           float_data: 3
84                           float_data: 3
85                           name: "Weight"
86                         }
87                       node {
88                          input: "Input"
89                          input: "Weight"
90                          output: "Output"
91                          name: "Convolution"
92                          op_type: "Conv"
93                          attribute {
94                            name: "kernel_shape"
95                            ints: 2
96                            ints: 2
97                            type: INTS
98                          }
99                          attribute {
100                            name: "strides"
101                            ints: 1
102                            ints: 1
103                            type: INTS
104                          }
105                          attribute {
106                            name: "auto_pad"
107                            s: "VALID"
108                            type: STRING
109                          }
110                          attribute {
111                            name: "group"
112                            i: 3
113                            type: INT
114                          }
115                          attribute {
116                            name: "dilations"
117                            ints: 1
118                            ints: 1
119                            type: INTS
120                          }
121                          doc_string: ""
122                          domain: ""
123                        }
124                       output {
125                           name: "Output"
126                           type {
127                              tensor_type {
128                                elem_type: 1
129                                shape {
130                                    dim {
131                                        dim_value: 1
132                                    }
133                                    dim {
134                                        dim_value: 3
135                                    }
136                                    dim {
137                                        dim_value: 1
138                                    }
139                                    dim {
140                                        dim_value: 1
141                                    }
142                                }
143                             }
144                         }
145                         }
146                     }
147                    opset_import {
148                       version: 7
149                     })";
150         Setup();
151     }
152 };
153 
154 
155 TEST_CASE_FIXTURE(SimpleDepthConv2DFixture, "ValidDepthConvTest")
156 {
157     RunTest<4>({{"Input", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}},
158                {{"Output", { 10, 52, 126 }}});
159 }
160 
161 }
162