1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <doctest/doctest.h> 10 11 #include <string> 12 13 TEST_SUITE("Deserializer_BatchMatMul") 14 { 15 struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture 16 { BatchMatMulFixtureBatchMatMulFixture17 explicit BatchMatMulFixture(const std::string& inputXShape, 18 const std::string& inputYShape, 19 const std::string& outputShape, 20 const std::string& dataType) 21 { 22 m_JsonString = R"( 23 { 24 inputIds:[ 25 0, 26 1 27 ], 28 outputIds:[ 29 3 30 ], 31 layers:[ 32 { 33 layer_type:"InputLayer", 34 layer:{ 35 base:{ 36 layerBindingId:0, 37 base:{ 38 index:0, 39 layerName:"InputXLayer", 40 layerType:"Input", 41 inputSlots:[ 42 { 43 index:0, 44 connection:{ 45 sourceLayerIndex:0, 46 outputSlotIndex:0 47 }, 48 49 } 50 ], 51 outputSlots:[ 52 { 53 index:0, 54 tensorInfo:{ 55 dimensions:)" + inputXShape + R"(, 56 dataType:)" + dataType + R"( 57 }, 58 59 } 60 ], 61 62 }, 63 64 } 65 }, 66 67 }, 68 { 69 layer_type:"InputLayer", 70 layer:{ 71 base:{ 72 layerBindingId:1, 73 base:{ 74 index:1, 75 layerName:"InputYLayer", 76 layerType:"Input", 77 inputSlots:[ 78 { 79 index:0, 80 connection:{ 81 sourceLayerIndex:0, 82 outputSlotIndex:0 83 }, 84 85 } 86 ], 87 outputSlots:[ 88 { 89 index:0, 90 tensorInfo:{ 91 dimensions:)" + inputYShape + R"(, 92 dataType:)" + dataType + R"( 93 }, 94 95 } 96 ], 97 98 }, 99 100 } 101 }, 102 103 }, 104 { 105 layer_type:"BatchMatMulLayer", 106 layer:{ 107 base:{ 108 index:2, 109 layerName:"BatchMatMulLayer", 110 layerType:"BatchMatMul", 111 inputSlots:[ 112 { 113 index:0, 114 connection:{ 115 sourceLayerIndex:0, 116 outputSlotIndex:0 117 }, 118 119 }, 120 { 121 index:1, 122 connection:{ 123 sourceLayerIndex:1, 124 outputSlotIndex:0 125 }, 126 127 } 128 ], 129 outputSlots:[ 130 { 131 index:0, 132 tensorInfo:{ 133 dimensions:)" + outputShape + R"(, 134 dataType:)" + dataType + R"( 135 }, 136 137 } 138 ], 139 140 }, 141 descriptor:{ 142 transposeX:false, 143 transposeY:false, 144 adjointX:false, 145 adjointY:false, 146 dataLayoutX:NHWC, 147 dataLayoutY:NHWC 148 } 149 }, 150 151 }, 152 { 153 layer_type:"OutputLayer", 154 layer:{ 155 base:{ 156 layerBindingId:0, 157 base:{ 158 index:3, 159 layerName:"OutputLayer", 160 layerType:"Output", 161 inputSlots:[ 162 { 163 index:0, 164 connection:{ 165 sourceLayerIndex:2, 166 outputSlotIndex:0 167 }, 168 169 } 170 ], 171 outputSlots:[ 172 { 173 index:0, 174 tensorInfo:{ 175 dimensions:)" + outputShape + R"(, 176 dataType:)" + dataType + R"( 177 }, 178 179 } 180 ], 181 182 } 183 } 184 }, 185 186 } 187 ] 188 } 189 )"; 190 Setup(); 191 } 192 }; 193 194 struct SimpleBatchMatMulFixture : BatchMatMulFixture 195 { SimpleBatchMatMulFixtureSimpleBatchMatMulFixture196 SimpleBatchMatMulFixture() 197 : BatchMatMulFixture("[ 1, 2, 2, 1 ]", 198 "[ 1, 2, 2, 1 ]", 199 "[ 1, 2, 2, 1 ]", 200 "Float32") 201 {} 202 }; 203 204 TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest") 205 { 206 RunTest<4, armnn::DataType::Float32>( 207 0, 208 {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }}, 209 {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}}, 210 {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}}); 211 } 212 213 }