1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_LoadScopeDynamicTensor") 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker struct DynamicBatchTensorFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 13*89c4ff92SAndroid Build Coastguard Worker { DynamicBatchTensorFixtureDynamicBatchTensorFixture14*89c4ff92SAndroid Build Coastguard Worker DynamicBatchTensorFixture() 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 17*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 18*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 19*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 20*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 21*89c4ff92SAndroid Build Coastguard Worker model_version: 1 22*89c4ff92SAndroid Build Coastguard Worker graph { 23*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 24*89c4ff92SAndroid Build Coastguard Worker input { 25*89c4ff92SAndroid Build Coastguard Worker name: "Input" 26*89c4ff92SAndroid Build Coastguard Worker type { 27*89c4ff92SAndroid Build Coastguard Worker tensor_type { 28*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 29*89c4ff92SAndroid Build Coastguard Worker shape { 30*89c4ff92SAndroid Build Coastguard Worker dim { 31*89c4ff92SAndroid Build Coastguard Worker dim_value: 0 32*89c4ff92SAndroid Build Coastguard Worker } 33*89c4ff92SAndroid Build Coastguard Worker dim { 34*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker dim { 37*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker dim { 40*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 41*89c4ff92SAndroid Build Coastguard Worker } 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker } 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker input { 47*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 48*89c4ff92SAndroid Build Coastguard Worker type { 49*89c4ff92SAndroid Build Coastguard Worker tensor_type { 50*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 51*89c4ff92SAndroid Build Coastguard Worker shape { 52*89c4ff92SAndroid Build Coastguard Worker dim { 53*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker dim { 56*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker dim { 59*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 60*89c4ff92SAndroid Build Coastguard Worker } 61*89c4ff92SAndroid Build Coastguard Worker dim { 62*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 63*89c4ff92SAndroid Build Coastguard Worker } 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker } 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker initializer { 69*89c4ff92SAndroid Build Coastguard Worker dims: 1 70*89c4ff92SAndroid Build Coastguard Worker dims: 1 71*89c4ff92SAndroid Build Coastguard Worker dims: 3 72*89c4ff92SAndroid Build Coastguard Worker dims: 3 73*89c4ff92SAndroid Build Coastguard Worker data_type: 1 74*89c4ff92SAndroid Build Coastguard Worker float_data: 2 75*89c4ff92SAndroid Build Coastguard Worker float_data: 1 76*89c4ff92SAndroid Build Coastguard Worker float_data: 0 77*89c4ff92SAndroid Build Coastguard Worker float_data: 6 78*89c4ff92SAndroid Build Coastguard Worker float_data: 2 79*89c4ff92SAndroid Build Coastguard Worker float_data: 1 80*89c4ff92SAndroid Build Coastguard Worker float_data: 4 81*89c4ff92SAndroid Build Coastguard Worker float_data: 1 82*89c4ff92SAndroid Build Coastguard Worker float_data: 2 83*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 84*89c4ff92SAndroid Build Coastguard Worker } 85*89c4ff92SAndroid Build Coastguard Worker node { 86*89c4ff92SAndroid Build Coastguard Worker input: "Input" 87*89c4ff92SAndroid Build Coastguard Worker input: "Weight" 88*89c4ff92SAndroid Build Coastguard Worker output: "Output" 89*89c4ff92SAndroid Build Coastguard Worker name: "Convolution" 90*89c4ff92SAndroid Build Coastguard Worker op_type: "Conv" 91*89c4ff92SAndroid Build Coastguard Worker attribute { 92*89c4ff92SAndroid Build Coastguard Worker name: "kernel_shape" 93*89c4ff92SAndroid Build Coastguard Worker ints: 3 94*89c4ff92SAndroid Build Coastguard Worker ints: 3 95*89c4ff92SAndroid Build Coastguard Worker type: INTS 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker attribute { 98*89c4ff92SAndroid Build Coastguard Worker name: "strides" 99*89c4ff92SAndroid Build Coastguard Worker ints: 1 100*89c4ff92SAndroid Build Coastguard Worker ints: 1 101*89c4ff92SAndroid Build Coastguard Worker type: INTS 102*89c4ff92SAndroid Build Coastguard Worker } 103*89c4ff92SAndroid Build Coastguard Worker attribute { 104*89c4ff92SAndroid Build Coastguard Worker name: "auto_pad" 105*89c4ff92SAndroid Build Coastguard Worker s: "VALID" 106*89c4ff92SAndroid Build Coastguard Worker type: STRING 107*89c4ff92SAndroid Build Coastguard Worker } 108*89c4ff92SAndroid Build Coastguard Worker attribute { 109*89c4ff92SAndroid Build Coastguard Worker name: "group" 110*89c4ff92SAndroid Build Coastguard Worker i: 1 111*89c4ff92SAndroid Build Coastguard Worker type: INT 112*89c4ff92SAndroid Build Coastguard Worker } 113*89c4ff92SAndroid Build Coastguard Worker attribute { 114*89c4ff92SAndroid Build Coastguard Worker name: "dilations" 115*89c4ff92SAndroid Build Coastguard Worker ints: 1 116*89c4ff92SAndroid Build Coastguard Worker ints: 1 117*89c4ff92SAndroid Build Coastguard Worker type: INTS 118*89c4ff92SAndroid Build Coastguard Worker } 119*89c4ff92SAndroid Build Coastguard Worker doc_string: "" 120*89c4ff92SAndroid Build Coastguard Worker domain: "" 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker output { 123*89c4ff92SAndroid Build Coastguard Worker name: "Output" 124*89c4ff92SAndroid Build Coastguard Worker type { 125*89c4ff92SAndroid Build Coastguard Worker tensor_type { 126*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 127*89c4ff92SAndroid Build Coastguard Worker shape { 128*89c4ff92SAndroid Build Coastguard Worker dim { 129*89c4ff92SAndroid Build Coastguard Worker dim_value: 0 130*89c4ff92SAndroid Build Coastguard Worker } 131*89c4ff92SAndroid Build Coastguard Worker dim { 132*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker dim { 135*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 136*89c4ff92SAndroid Build Coastguard Worker } 137*89c4ff92SAndroid Build Coastguard Worker dim { 138*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 139*89c4ff92SAndroid Build Coastguard Worker } 140*89c4ff92SAndroid Build Coastguard Worker } 141*89c4ff92SAndroid Build Coastguard Worker } 142*89c4ff92SAndroid Build Coastguard Worker } 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker opset_import { 146*89c4ff92SAndroid Build Coastguard Worker version: 7 147*89c4ff92SAndroid Build Coastguard Worker })"; 148*89c4ff92SAndroid Build Coastguard Worker } 149*89c4ff92SAndroid Build Coastguard Worker }; 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "DynamicBatchTensorTest") 152*89c4ff92SAndroid Build Coastguard Worker { 153*89c4ff92SAndroid Build Coastguard Worker Setup({{"Input", armnn::TensorShape({1, 1, 3, 3})}}); 154*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1.0, 2.0, 3.0, 155*89c4ff92SAndroid Build Coastguard Worker 4.0, 5.0, 6.0, 156*89c4ff92SAndroid Build Coastguard Worker 7.0, 8.0, 9.0}}}, 157*89c4ff92SAndroid Build Coastguard Worker {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 158*89c4ff92SAndroid Build Coastguard Worker 4.0 * 6 + 5.0 * 2 + 6.0 * 1 + 159*89c4ff92SAndroid Build Coastguard Worker 7.0 * 4 + 8.0 * 1 + 9.0 * 2}}}); 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker 162*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "TensorShapeNotSpecifiedTest") 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 165*89c4ff92SAndroid Build Coastguard Worker } 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectInputNameTest") 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup({{"Incorrect", armnn::TensorShape({1, 1, 3, 3})}}), armnn::ParseException); 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicBatchTensorFixture, "IncorrectBatchTensorTest") 173*89c4ff92SAndroid Build Coastguard Worker { 174*89c4ff92SAndroid Build Coastguard Worker Setup({{"Input", armnn::TensorShape({2, 1, 3, 3}) }}); 175*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(RunTest<4>({{"Input", { 1.0, 2.0, 3.0, 176*89c4ff92SAndroid Build Coastguard Worker 4.0, 5.0, 6.0, 177*89c4ff92SAndroid Build Coastguard Worker 7.0, 8.0, 9.0 }}}, 178*89c4ff92SAndroid Build Coastguard Worker {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 179*89c4ff92SAndroid Build Coastguard Worker 4.0 * 6 + 5.0 * 2 + 6.0 * 1 + 180*89c4ff92SAndroid Build Coastguard Worker 7.0 * 4 + 8.0 * 1 + 9.0 * 2 }}}), armnn::Exception); 181*89c4ff92SAndroid Build Coastguard Worker 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker 184*89c4ff92SAndroid Build Coastguard Worker } 185