// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

using Int32x4 = int32x4_t;
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;
using Int8x8 = int8x8_t;

template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int16x8,
      typename std::conditional<ScalarCount >= 4, Int16x4,
                                std::int16_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Uint8x8,
      typename std::conditional<ScalarCount >= 4, std::uint32_t,
                                std::uint8_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::int8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int8x8,
      typename std::conditional<ScalarCount >= 4, std::int32_t,
                                std::int8_t>::type>::type;
};

inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }

inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
  vst1q_s32(dst, value);
}

inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
  vst1_s16(dst, value);
}

inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
  vst1q_s16(dst, value);
}
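// Lane-indexed helpers. The lane index is a template parameter, so each
// switch below is over a compile-time constant and should fold to a single
// lane instruction; the static_assert in the default case rejects
// out-of-range lanes at instantiation time.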
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
  return vgetq_lane_s32(value, Lane);
}

template <int Lane>
Int32x4 DupLane(Int32x4 value) {
  switch (Lane) {
    case 0:
      return vdupq_lane_s32(vget_low_s32(value), 0);
    case 1:
      return vdupq_lane_s32(vget_low_s32(value), 1);
    case 2:
      return vdupq_lane_s32(vget_high_s32(value), 0);
    case 3:
      return vdupq_lane_s32(vget_high_s32(value), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); }

inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, std::int32_t b) {
  return vmaxq_s32(a, vdupq_n_s32(b));
}

inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
  return vqrdmulhq_n_s32(a, b);
}

template <int Lane>
Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
  switch (Lane) {
    case 0:
      return vmulq_lane_s32(a, vget_low_s32(b), 0);
    case 1:
      return vmulq_lane_s32(a, vget_low_s32(b), 1);
    case 2:
      return vmulq_lane_s32(a, vget_high_s32(b), 0);
    case 3:
      return vmulq_lane_s32(a, vget_high_s32(b), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  *acc = vmlaq_s32(*acc, lhs, rhs);
}

inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
  *acc = vmlaq_n_s32(*acc, lhs, rhs);
}

template <int Lane>
inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  switch (Lane) {
    case 0:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0);
      break;
    case 1:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1);
      break;
    case 2:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0);
      break;
    case 3:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1);
      break;
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
  }
}
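// Contiguous loads of whole register blocks: eight 8-lane loads for the
// 8-bit and 16-bit 8x8 blocks below, sixteen 4-lane loads for the 32-bit
// one.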
template <>
struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
  static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
    RegBlockInt16<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1q_s16(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
    RegBlockUint8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_u8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt8<8, 8>> {
  static RegBlockInt8<8, 8> Run(const std::int8_t* src) {
    RegBlockInt8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_s8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
    RegBlockInt32<8, 8> result;
    for (int i = 0; i < 16; i++) {
      result.buf.reg[i] = vld1q_s32(src + 4 * i);
    }
    return result;
  }
};
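// Specializations of the generic BroadcastShiftLeftImpl template declared in
// simd_wrappers.h. Register blocks are stored column-major, four int32 lanes
// per register: a 4xN block keeps column c in reg[c], and an 8xN block keeps
// rows 0-3 of column c in reg[2 * c] and rows 4-7 in reg[2 * c + 1]. A 1xN
// rhs is therefore broadcast with one DupLane per column, while an Nx1 rhs
// register is reused across all columns. For example (acc and amounts are
// illustrative local variables):
//
//   RegBlockInt32<4, 4> shifted =
//       BroadcastShiftLeftImpl<RegBlockInt32<4, 4>,
//                              RegBlockInt32<1, 4>>::Run(acc, amounts);
//
// In the shape comments below, << stands for ShiftLeft.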
// 4x1 := 4x1 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 1x4 := 1x4 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x1 := 4x1 << 4x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 1x4 := 1x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 4x4 := 4x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x4 := 4x4 << 4x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[0]);
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[0]);
    return result;
  }
};

// 8x1 := 8x1 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], p);
    }
    return result;
  }
};

// 8x1 := 8x1 << 8x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = ShiftLeft(lhs.buf.reg[i], rhs.buf.reg[i]);
    }
    return result;
  }
};

// 8x4 := 8x4 << 1x4
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 8x4 := 8x4 << 8x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
    result.buf.reg[2] = ShiftLeft(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = ShiftLeft(lhs.buf.reg[3], rhs.buf.reg[1]);
    result.buf.reg[4] = ShiftLeft(lhs.buf.reg[4], rhs.buf.reg[0]);
    result.buf.reg[5] = ShiftLeft(lhs.buf.reg[5], rhs.buf.reg[1]);
    result.buf.reg[6] = ShiftLeft(lhs.buf.reg[6], rhs.buf.reg[0]);
    result.buf.reg[7] = ShiftLeft(lhs.buf.reg[7], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 << 1x8
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 8>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 << 1x1
template <>
struct BroadcastShiftLeftImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = ShiftLeft(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    result.buf.reg[1] = ShiftLeft(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};
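// Specializations of BroadcastRoundingDivideByPOTImpl from simd_wrappers.h,
// covering the same block shapes and broadcasting patterns as the shift
// specializations above. RoundingDivideByPOT divides by a power of two with
// rounding, taking the exponents from rhs; >> in the shape comments below
// stands for this rounding divide.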
// 4x1 := 4x1 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 1x4 := 1x4 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x1 := 4x1 >> 4x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 1>,
                                        RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 1> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 1x4 := 1x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<1, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    return result;
  }
};

// 4x4 := 4x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[2] =
        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[3] =
        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 4x4 := 4x4 >> 4x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<4, 4>,
                                        RegBlockInt32<4, 1>> {
  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
                                 const RegBlockInt32<4, 1>& rhs) {
    RegBlockInt32<4, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[0]);
    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[0]);
    return result;
  }
};

// 8x1 := 8x1 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], p);
    }
    return result;
  }
};

// 8x1 := 8x1 >> 8x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 1>,
                                        RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 1> result;
    for (int i = 0; i < 2; i++) {
      result.buf.reg[i] = RoundingDivideByPOT(lhs.buf.reg[i], rhs.buf.reg[i]);
    }
    return result;
  }
};

// 8x4 := 8x4 >> 1x4
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
                                        RegBlockInt32<1, 4>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<1, 4>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
    result.buf.reg[2] =
        RoundingDivideByPOT(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[3] =
        RoundingDivideByPOT(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
    result.buf.reg[4] =
        RoundingDivideByPOT(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[5] =
        RoundingDivideByPOT(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
    result.buf.reg[6] =
        RoundingDivideByPOT(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
    result.buf.reg[7] =
        RoundingDivideByPOT(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
    return result;
  }
};

// 8x4 := 8x4 >> 8x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<8, 4>,
                                        RegBlockInt32<8, 1>> {
  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
                                 const RegBlockInt32<8, 1>& rhs) {
    RegBlockInt32<8, 4> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
    result.buf.reg[2] = RoundingDivideByPOT(lhs.buf.reg[2], rhs.buf.reg[0]);
    result.buf.reg[3] = RoundingDivideByPOT(lhs.buf.reg[3], rhs.buf.reg[1]);
    result.buf.reg[4] = RoundingDivideByPOT(lhs.buf.reg[4], rhs.buf.reg[0]);
    result.buf.reg[5] = RoundingDivideByPOT(lhs.buf.reg[5], rhs.buf.reg[1]);
    result.buf.reg[6] = RoundingDivideByPOT(lhs.buf.reg[6], rhs.buf.reg[0]);
    result.buf.reg[7] = RoundingDivideByPOT(lhs.buf.reg[7], rhs.buf.reg[1]);
    return result;
  }
};
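// Wrapper implementations common to the NEON and SSE paths live in
// simd_wrappers_common_neon_sse.h, included at the end of this header so
// that they can use the NEON types and helpers defined above.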
// 1x8 := 1x8 >> 1x8
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
                                        RegBlockInt32<1, 8>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 8>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] = RoundingDivideByPOT(lhs.buf.reg[0], rhs.buf.reg[0]);
    result.buf.reg[1] = RoundingDivideByPOT(lhs.buf.reg[1], rhs.buf.reg[1]);
    return result;
  }
};

// 1x8 := 1x8 >> 1x1
template <>
struct BroadcastRoundingDivideByPOTImpl<RegBlockInt32<1, 8>,
                                        RegBlockInt32<1, 1>> {
  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
                                 const RegBlockInt32<1, 1>& rhs) {
    RegBlockInt32<1, 8> result;
    result.buf.reg[0] =
        RoundingDivideByPOT(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
    result.buf.reg[1] =
        RoundingDivideByPOT(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
    return result;
  }
};

}  // end namespace gemmlowp

#include "simd_wrappers_common_neon_sse.h"

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_