1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_LoadScopeDynamicTensor") 10 { 11 12 struct DynamicBatchTensorFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 13 { DynamicBatchTensorFixtureDynamicBatchTensorFixture14 DynamicBatchTensorFixture() 15 { 16 m_Prototext = R"( 17 ir_version: 3 18 producer_name: "CNTK" 19 producer_version: "2.5.1" 20 domain: "ai.cntk" 21 model_version: 1 22 graph { 23 name: "CNTKGraph" 24 input { 25 name: "Input" 26 type { 27 tensor_type { 28 elem_type: 1 29 shape { 30 dim { 31 dim_value: 0 32 } 33 dim { 34 dim_value: 1 35 } 36 dim { 37 dim_value: 3 38 } 39 dim { 40 dim_value: 3 41 } 42 } 43 } 44 } 45 } 46 input { 47 name: "Weight" 48 type { 49 tensor_type { 50 elem_type: 1 51 shape { 52 dim { 53 dim_value: 1 54 } 55 dim { 56 dim_value: 1 57 } 58 dim { 59 dim_value: 3 60 } 61 dim { 62 dim_value: 3 63 } 64 } 65 } 66 } 67 } 68 initializer { 69 dims: 1 70 dims: 1 71 dims: 3 72 dims: 3 73 data_type: 1 74 float_data: 2 75 float_data: 1 76 float_data: 0 77 float_data: 6 78 float_data: 2 79 float_data: 1 80 float_data: 4 81 float_data: 1 82 float_data: 2 83 name: "Weight" 84 } 85 node { 86 input: "Input" 87 input: "Weight" 88 output: "Output" 89 name: "Convolution" 90 op_type: "Conv" 91 attribute { 92 name: "kernel_shape" 93 ints: 3 94 ints: 3 95 type: INTS 96 } 97 attribute { 98 name: "strides" 99 ints: 1 100 ints: 1 101 type: INTS 102 } 103 attribute { 104 name: "auto_pad" 105 s: "VALID" 106 type: STRING 107 } 108 attribute { 109 name: "group" 110 i: 1 111 type: INT 112 } 113 attribute { 114 name: "dilations" 115 ints: 1 116 ints: 1 117 type: INTS 118 } 119 doc_string: "" 120 domain: "" 121 } 122 output { 123 name: "Output" 124 type { 125 tensor_type { 126 elem_type: 1 127 shape { 128 dim { 129 dim_value: 0 130 } 131 dim { 132 dim_value: 1 133 } 134 dim { 135 dim_value: 1 136 } 137 dim { 138 dim_value: 1 139 } 140 } 141 } 142 } 143 } 144 } 145 opset_import { 146 version: 7 147 })"; 148 } 149 }; 150 151 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "DynamicBatchTensorTest") 152 { 153 Setup({{"Input", armnn::TensorShape({1, 1, 3, 3})}}); 154 RunTest<4>({{"Input", {1.0, 2.0, 3.0, 155 4.0, 5.0, 6.0, 156 7.0, 8.0, 9.0}}}, 157 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 158 4.0 * 6 + 5.0 * 2 + 6.0 * 1 + 159 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}}); 160 } 161 162 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "TensorShapeNotSpecifiedTest") 163 { 164 CHECK_THROWS_AS(Setup(), armnn::ParseException); 165 } 166 167 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectInputNameTest") 168 { 169 CHECK_THROWS_AS(Setup({{"Incorrect", armnn::TensorShape({1, 1, 3, 3})}}), armnn::ParseException); 170 } 171 172 TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectBatchTensorTest") 173 { 174 Setup({{"Input", armnn::TensorShape({2, 1, 3, 3}) }}); 175 CHECK_THROWS_AS(RunTest<4>({{"Input", { 1.0, 2.0, 3.0, 176 4.0, 5.0, 6.0, 177 7.0, 8.0, 9.0 }}}, 178 {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 179 4.0 * 6 + 5.0 * 2 + 6.0 * 1 + 180 7.0 * 4 + 8.0 * 1 + 9.0 * 2 }}}), armnn::Exception); 181 182 } 183 184 } 185