xref: /aosp_15_r20/external/gemmlowp/internal/simd_wrappers.h (revision 5f39d1b313f0528e11bae88b3029b54b9e1033e7)
1*5f39d1b3SJooyung Han // Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2*5f39d1b3SJooyung Han //
3*5f39d1b3SJooyung Han // Licensed under the Apache License, Version 2.0 (the "License");
4*5f39d1b3SJooyung Han // you may not use this file except in compliance with the License.
5*5f39d1b3SJooyung Han // You may obtain a copy of the License at
6*5f39d1b3SJooyung Han //
7*5f39d1b3SJooyung Han //     http://www.apache.org/licenses/LICENSE-2.0
8*5f39d1b3SJooyung Han //
9*5f39d1b3SJooyung Han // Unless required by applicable law or agreed to in writing, software
10*5f39d1b3SJooyung Han // distributed under the License is distributed on an "AS IS" BASIS,
11*5f39d1b3SJooyung Han // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*5f39d1b3SJooyung Han // See the License for the specific language governing permissions and
13*5f39d1b3SJooyung Han // limitations under the License.
14*5f39d1b3SJooyung Han 
15*5f39d1b3SJooyung Han // simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
16*5f39d1b3SJooyung Han // extending the set of such functions from fixedpoint.h.
17*5f39d1b3SJooyung Han 
18*5f39d1b3SJooyung Han #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
19*5f39d1b3SJooyung Han #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
20*5f39d1b3SJooyung Han 
21*5f39d1b3SJooyung Han #include <algorithm>
22*5f39d1b3SJooyung Han #include <type_traits>
23*5f39d1b3SJooyung Han #include "../fixedpoint/fixedpoint.h"
24*5f39d1b3SJooyung Han 
25*5f39d1b3SJooyung Han namespace gemmlowp {
26*5f39d1b3SJooyung Han 
27*5f39d1b3SJooyung Han template <typename ScalarType, int ScalarCount>
28*5f39d1b3SJooyung Han struct RegisterType {
29*5f39d1b3SJooyung Han   using Type = ScalarType;
30*5f39d1b3SJooyung Han };
31*5f39d1b3SJooyung Han 
Min(std::int32_t a,std::int32_t b)32*5f39d1b3SJooyung Han inline std::int32_t Min(std::int32_t a, std::int32_t b) {
33*5f39d1b3SJooyung Han   return std::min(a, b);
34*5f39d1b3SJooyung Han }
35*5f39d1b3SJooyung Han 
Max(std::int32_t a,std::int32_t b)36*5f39d1b3SJooyung Han inline std::int32_t Max(std::int32_t a, std::int32_t b) {
37*5f39d1b3SJooyung Han   return std::max(a, b);
38*5f39d1b3SJooyung Han }
39*5f39d1b3SJooyung Han 
MulAdd(std::int32_t lhs,std::int32_t rhs,std::int32_t * acc)40*5f39d1b3SJooyung Han inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
41*5f39d1b3SJooyung Han   *acc += lhs * rhs;
42*5f39d1b3SJooyung Han }
43*5f39d1b3SJooyung Han 
44*5f39d1b3SJooyung Han template <typename tScalarType, int tScalarCount>
45*5f39d1b3SJooyung Han struct RegisterBuffer {
46*5f39d1b3SJooyung Han   using ScalarType = tScalarType;
47*5f39d1b3SJooyung Han   static constexpr int kScalarCount = tScalarCount;
48*5f39d1b3SJooyung Han   using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
49*5f39d1b3SJooyung Han   static_assert((kScalarCount & (kScalarCount - 1)) == 0,
50*5f39d1b3SJooyung Han                 "kScalarCount must be a power of two");
51*5f39d1b3SJooyung Han   static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
52*5f39d1b3SJooyung Han   static constexpr int kRegisterLanes =
53*5f39d1b3SJooyung Han       sizeof(RegisterType) / sizeof(ScalarType);
54*5f39d1b3SJooyung Han   static constexpr int kRegisterCount =
55*5f39d1b3SJooyung Han       (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) /
56*5f39d1b3SJooyung Han       sizeof(RegisterType);
57*5f39d1b3SJooyung Han 
58*5f39d1b3SJooyung Han   RegisterType reg[kRegisterCount];
59*5f39d1b3SJooyung Han };
60*5f39d1b3SJooyung Han 
61*5f39d1b3SJooyung Han template <typename tScalarType, int tRows, int tCols>
62*5f39d1b3SJooyung Han struct RegisterBlock {
63*5f39d1b3SJooyung Han   using ScalarType = tScalarType;
64*5f39d1b3SJooyung Han   static constexpr int kRows = tRows;
65*5f39d1b3SJooyung Han   static constexpr int kCols = tCols;
66*5f39d1b3SJooyung Han   static constexpr int kScalarCount = kRows * kCols;
67*5f39d1b3SJooyung Han   using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
68*5f39d1b3SJooyung Han   using RegisterType = typename BufferType::RegisterType;
69*5f39d1b3SJooyung Han   static constexpr int kRegisterCount = BufferType::kRegisterCount;
70*5f39d1b3SJooyung Han   static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
71*5f39d1b3SJooyung Han 
72*5f39d1b3SJooyung Han   BufferType buf;
73*5f39d1b3SJooyung Han };
74*5f39d1b3SJooyung Han 
75*5f39d1b3SJooyung Han template <typename RegisterBlockType>
76*5f39d1b3SJooyung Han struct RegisterBlockAddImpl {
RunRegisterBlockAddImpl77*5f39d1b3SJooyung Han   static RegisterBlockType Run(const RegisterBlockType& lhs,
78*5f39d1b3SJooyung Han                                const RegisterBlockType& rhs) {
79*5f39d1b3SJooyung Han     RegisterBlockType result;
80*5f39d1b3SJooyung Han     for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
81*5f39d1b3SJooyung Han       result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
82*5f39d1b3SJooyung Han     }
83*5f39d1b3SJooyung Han     return result;
84*5f39d1b3SJooyung Han   }
85*5f39d1b3SJooyung Han };
86*5f39d1b3SJooyung Han 
87*5f39d1b3SJooyung Han template <typename RegisterBlockType>
RegisterBlockAdd(const RegisterBlockType & lhs,const RegisterBlockType & rhs)88*5f39d1b3SJooyung Han RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
89*5f39d1b3SJooyung Han                                    const RegisterBlockType& rhs) {
90*5f39d1b3SJooyung Han   return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
91*5f39d1b3SJooyung Han }
92*5f39d1b3SJooyung Han 
93*5f39d1b3SJooyung Han template <typename LhsType, typename RhsType>
94*5f39d1b3SJooyung Han struct ShouldFlipLhsRhs {
95*5f39d1b3SJooyung Han   static constexpr bool kValue =
96*5f39d1b3SJooyung Han       (LhsType::kScalarCount < RhsType::kScalarCount) ||
97*5f39d1b3SJooyung Han       (LhsType::kScalarCount == RhsType::kScalarCount &&
98*5f39d1b3SJooyung Han        (LhsType::kRows < RhsType::kRows));
99*5f39d1b3SJooyung Han };
100*5f39d1b3SJooyung Han 
101*5f39d1b3SJooyung Han template <typename LhsType, typename RhsType,
102*5f39d1b3SJooyung Han           bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
103*5f39d1b3SJooyung Han struct FlipLhsRhs {
104*5f39d1b3SJooyung Han   using FlippedLhsType = LhsType;
105*5f39d1b3SJooyung Han   using FlippedRhsType = RhsType;
FlippedLhsFlipLhsRhs106*5f39d1b3SJooyung Han   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
107*5f39d1b3SJooyung Han                                           const RhsType& rhs) {
108*5f39d1b3SJooyung Han     (void)rhs;
109*5f39d1b3SJooyung Han     return lhs;
110*5f39d1b3SJooyung Han   }
FlippedRhsFlipLhsRhs111*5f39d1b3SJooyung Han   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
112*5f39d1b3SJooyung Han                                           const RhsType& rhs) {
113*5f39d1b3SJooyung Han     (void)lhs;
114*5f39d1b3SJooyung Han     return rhs;
115*5f39d1b3SJooyung Han   }
116*5f39d1b3SJooyung Han };
117*5f39d1b3SJooyung Han 
118*5f39d1b3SJooyung Han template <typename LhsType, typename RhsType>
119*5f39d1b3SJooyung Han struct FlipLhsRhs<LhsType, RhsType, true> {
120*5f39d1b3SJooyung Han   using FlippedLhsType = RhsType;
121*5f39d1b3SJooyung Han   using FlippedRhsType = LhsType;
122*5f39d1b3SJooyung Han   static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
123*5f39d1b3SJooyung Han                                           const RhsType& rhs) {
124*5f39d1b3SJooyung Han     (void)lhs;
125*5f39d1b3SJooyung Han     return rhs;
126*5f39d1b3SJooyung Han   }
127*5f39d1b3SJooyung Han   static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
128*5f39d1b3SJooyung Han                                           const RhsType& rhs) {
129*5f39d1b3SJooyung Han     (void)rhs;
130*5f39d1b3SJooyung Han     return lhs;
131*5f39d1b3SJooyung Han   }
132*5f39d1b3SJooyung Han };
133*5f39d1b3SJooyung Han 
134*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
135*5f39d1b3SJooyung Han struct BroadcastBinaryOpShape {
136*5f39d1b3SJooyung Han   static constexpr int kRows =
137*5f39d1b3SJooyung Han       Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows;
138*5f39d1b3SJooyung Han   static constexpr int kCols =
139*5f39d1b3SJooyung Han       Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols;
140*5f39d1b3SJooyung Han };
141*5f39d1b3SJooyung Han 
142*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
143*5f39d1b3SJooyung Han struct BroadcastBinaryOpRegisterBlock {
144*5f39d1b3SJooyung Han   using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
145*5f39d1b3SJooyung Han   using ScalarType = typename Lhs::ScalarType;
146*5f39d1b3SJooyung Han   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
147*5f39d1b3SJooyung Han };
148*5f39d1b3SJooyung Han 
149*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
150*5f39d1b3SJooyung Han struct BroadcastAddImpl {
151*5f39d1b3SJooyung Han   using ResultBlockType =
152*5f39d1b3SJooyung Han       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
153*5f39d1b3SJooyung Han   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
154*5f39d1b3SJooyung Han     ResultBlockType result;
155*5f39d1b3SJooyung Han     static constexpr int Rows = ResultBlockType::kRows;
156*5f39d1b3SJooyung Han     static constexpr int Cols = ResultBlockType::kCols;
157*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
158*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
159*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
160*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
161*5f39d1b3SJooyung Han 
162*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
163*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
164*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
165*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
166*5f39d1b3SJooyung Han     static_assert(ResultBlockType::kRegisterLanes == 1,
167*5f39d1b3SJooyung Han                   "This path is only for scalar values");
168*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
169*5f39d1b3SJooyung Han                   "This path is only for scalar values");
170*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
171*5f39d1b3SJooyung Han                   "This path is only for scalar values");
172*5f39d1b3SJooyung Han 
173*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
174*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
175*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
176*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
177*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
178*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
179*5f39d1b3SJooyung Han         result.buf.reg[r + c * Rows] =
180*5f39d1b3SJooyung Han             Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
181*5f39d1b3SJooyung Han                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
182*5f39d1b3SJooyung Han       }
183*5f39d1b3SJooyung Han     }
184*5f39d1b3SJooyung Han     return result;
185*5f39d1b3SJooyung Han   }
186*5f39d1b3SJooyung Han };
187*5f39d1b3SJooyung Han 
188*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
189*5f39d1b3SJooyung Han typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
190*5f39d1b3SJooyung Han     const Lhs& lhs, const Rhs& rhs) {
191*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
192*5f39d1b3SJooyung Han   return BroadcastAddImpl<
193*5f39d1b3SJooyung Han       typename Flip::FlippedLhsType,
194*5f39d1b3SJooyung Han       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
195*5f39d1b3SJooyung Han                                           Flip::FlippedRhs(lhs, rhs));
196*5f39d1b3SJooyung Han }
197*5f39d1b3SJooyung Han 
198*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
199*5f39d1b3SJooyung Han struct BroadcastShiftLeftImpl {
200*5f39d1b3SJooyung Han   using ResultBlockType =
201*5f39d1b3SJooyung Han       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
202*5f39d1b3SJooyung Han   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
203*5f39d1b3SJooyung Han     ResultBlockType result;
204*5f39d1b3SJooyung Han     static constexpr int Rows = ResultBlockType::kRows;
205*5f39d1b3SJooyung Han     static constexpr int Cols = ResultBlockType::kCols;
206*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
207*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
208*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
209*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
210*5f39d1b3SJooyung Han 
211*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
212*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
213*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
214*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
215*5f39d1b3SJooyung Han     static_assert(ResultBlockType::kRegisterLanes == 1,
216*5f39d1b3SJooyung Han                   "This path is only for scalar values");
217*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
218*5f39d1b3SJooyung Han                   "This path is only for scalar values");
219*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
220*5f39d1b3SJooyung Han                   "This path is only for scalar values");
221*5f39d1b3SJooyung Han 
222*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
223*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
224*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
225*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
226*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
227*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
228*5f39d1b3SJooyung Han         result.buf.reg[r + c * Rows] =
229*5f39d1b3SJooyung Han             ShiftLeft(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
230*5f39d1b3SJooyung Han                       rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
231*5f39d1b3SJooyung Han       }
232*5f39d1b3SJooyung Han     }
233*5f39d1b3SJooyung Han     return result;
234*5f39d1b3SJooyung Han   }
235*5f39d1b3SJooyung Han };
236*5f39d1b3SJooyung Han 
237*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
238*5f39d1b3SJooyung Han typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastShiftLeft(
239*5f39d1b3SJooyung Han     const Lhs& lhs, const Rhs& rhs) {
240*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
241*5f39d1b3SJooyung Han   return BroadcastShiftLeftImpl<
242*5f39d1b3SJooyung Han       typename Flip::FlippedLhsType,
243*5f39d1b3SJooyung Han       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
244*5f39d1b3SJooyung Han                                           Flip::FlippedRhs(lhs, rhs));
245*5f39d1b3SJooyung Han }
246*5f39d1b3SJooyung Han 
247*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
248*5f39d1b3SJooyung Han struct BroadcastSaturatingRoundingDoublingHighMulImpl {
249*5f39d1b3SJooyung Han   using ResultBlockType =
250*5f39d1b3SJooyung Han       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
251*5f39d1b3SJooyung Han   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
252*5f39d1b3SJooyung Han     ResultBlockType result;
253*5f39d1b3SJooyung Han     static constexpr int Rows = ResultBlockType::kRows;
254*5f39d1b3SJooyung Han     static constexpr int Cols = ResultBlockType::kCols;
255*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
256*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
257*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
258*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
259*5f39d1b3SJooyung Han 
260*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
261*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
262*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
263*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
264*5f39d1b3SJooyung Han     static_assert(ResultBlockType::kRegisterLanes == 1,
265*5f39d1b3SJooyung Han                   "This path is only for scalar values");
266*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
267*5f39d1b3SJooyung Han                   "This path is only for scalar values");
268*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
269*5f39d1b3SJooyung Han                   "This path is only for scalar values");
270*5f39d1b3SJooyung Han 
271*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
272*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
273*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
274*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
275*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
276*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
277*5f39d1b3SJooyung Han         result.buf.reg[r + c * Rows] = SaturatingRoundingDoublingHighMul(
278*5f39d1b3SJooyung Han             lhs.buf.reg[lhs_r + lhs_c * LhsRows],
279*5f39d1b3SJooyung Han             rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
280*5f39d1b3SJooyung Han       }
281*5f39d1b3SJooyung Han     }
282*5f39d1b3SJooyung Han     return result;
283*5f39d1b3SJooyung Han   }
284*5f39d1b3SJooyung Han };
285*5f39d1b3SJooyung Han 
286*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
287*5f39d1b3SJooyung Han typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
288*5f39d1b3SJooyung Han BroadcastSaturatingRoundingDoublingHighMul(const Lhs& lhs, const Rhs& rhs) {
289*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
290*5f39d1b3SJooyung Han   return BroadcastSaturatingRoundingDoublingHighMulImpl<
291*5f39d1b3SJooyung Han       typename Flip::FlippedLhsType,
292*5f39d1b3SJooyung Han       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
293*5f39d1b3SJooyung Han                                           Flip::FlippedRhs(lhs, rhs));
294*5f39d1b3SJooyung Han }
295*5f39d1b3SJooyung Han 
296*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
297*5f39d1b3SJooyung Han struct BroadcastRoundingDivideByPOTImpl {
298*5f39d1b3SJooyung Han   using ResultBlockType =
299*5f39d1b3SJooyung Han       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
300*5f39d1b3SJooyung Han   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
301*5f39d1b3SJooyung Han     ResultBlockType result;
302*5f39d1b3SJooyung Han     static constexpr int Rows = ResultBlockType::kRows;
303*5f39d1b3SJooyung Han     static constexpr int Cols = ResultBlockType::kCols;
304*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
305*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
306*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
307*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
308*5f39d1b3SJooyung Han 
309*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
310*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
311*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
312*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
313*5f39d1b3SJooyung Han     static_assert(ResultBlockType::kRegisterLanes == 1,
314*5f39d1b3SJooyung Han                   "This path is only for scalar values");
315*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
316*5f39d1b3SJooyung Han                   "This path is only for scalar values");
317*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
318*5f39d1b3SJooyung Han                   "This path is only for scalar values");
319*5f39d1b3SJooyung Han 
320*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
321*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
322*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
323*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
324*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
325*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
326*5f39d1b3SJooyung Han         result.buf.reg[r + c * Rows] =
327*5f39d1b3SJooyung Han             RoundingDivideByPOT(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
328*5f39d1b3SJooyung Han                                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
329*5f39d1b3SJooyung Han       }
330*5f39d1b3SJooyung Han     }
331*5f39d1b3SJooyung Han     return result;
332*5f39d1b3SJooyung Han   }
333*5f39d1b3SJooyung Han };
334*5f39d1b3SJooyung Han 
335*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
336*5f39d1b3SJooyung Han typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type
337*5f39d1b3SJooyung Han BroadcastRoundingDivideByPOT(const Lhs& lhs, const Rhs& rhs) {
338*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
339*5f39d1b3SJooyung Han   return BroadcastRoundingDivideByPOTImpl<
340*5f39d1b3SJooyung Han       typename Flip::FlippedLhsType,
341*5f39d1b3SJooyung Han       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
342*5f39d1b3SJooyung Han                                           Flip::FlippedRhs(lhs, rhs));
343*5f39d1b3SJooyung Han }
344*5f39d1b3SJooyung Han 
345*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
346*5f39d1b3SJooyung Han struct BroadcastMulImpl {
347*5f39d1b3SJooyung Han   using ResultBlockType =
348*5f39d1b3SJooyung Han       typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
349*5f39d1b3SJooyung Han   static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
350*5f39d1b3SJooyung Han     ResultBlockType result;
351*5f39d1b3SJooyung Han     static constexpr int Rows = ResultBlockType::kRows;
352*5f39d1b3SJooyung Han     static constexpr int Cols = ResultBlockType::kCols;
353*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
354*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
355*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
356*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
357*5f39d1b3SJooyung Han     static_assert(ResultBlockType::kRegisterLanes == 1,
358*5f39d1b3SJooyung Han                   "This path is only for scalar values");
359*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
360*5f39d1b3SJooyung Han                   "This path is only for scalar values");
361*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
362*5f39d1b3SJooyung Han                   "This path is only for scalar values");
363*5f39d1b3SJooyung Han 
364*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
365*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
366*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
367*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
368*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
369*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
370*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
371*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
372*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
373*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
374*5f39d1b3SJooyung Han         result.buf.reg[r + c * Rows] =
375*5f39d1b3SJooyung Han             Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
376*5f39d1b3SJooyung Han                 rhs.buf.reg[rhs_r + rhs_c * RhsRows]);
377*5f39d1b3SJooyung Han       }
378*5f39d1b3SJooyung Han     }
379*5f39d1b3SJooyung Han     return result;
380*5f39d1b3SJooyung Han   }
381*5f39d1b3SJooyung Han };
382*5f39d1b3SJooyung Han 
383*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs>
384*5f39d1b3SJooyung Han typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
385*5f39d1b3SJooyung Han     const Lhs& lhs, const Rhs& rhs) {
386*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
387*5f39d1b3SJooyung Han   return BroadcastMulImpl<
388*5f39d1b3SJooyung Han       typename Flip::FlippedLhsType,
389*5f39d1b3SJooyung Han       typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
390*5f39d1b3SJooyung Han                                           Flip::FlippedRhs(lhs, rhs));
391*5f39d1b3SJooyung Han }
392*5f39d1b3SJooyung Han 
393*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs, typename Acc>
394*5f39d1b3SJooyung Han struct BroadcastMulAddImpl {
395*5f39d1b3SJooyung Han   static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
396*5f39d1b3SJooyung Han     static constexpr int Rows = Acc::kRows;
397*5f39d1b3SJooyung Han     static constexpr int Cols = Acc::kCols;
398*5f39d1b3SJooyung Han     static constexpr int LhsRows = Lhs::kRows;
399*5f39d1b3SJooyung Han     static constexpr int LhsCols = Lhs::kCols;
400*5f39d1b3SJooyung Han     static constexpr int RhsRows = Rhs::kRows;
401*5f39d1b3SJooyung Han     static constexpr int RhsCols = Rhs::kCols;
402*5f39d1b3SJooyung Han     static_assert(Acc::kRegisterLanes == 1,
403*5f39d1b3SJooyung Han                   "This path is only for scalar values");
404*5f39d1b3SJooyung Han     static_assert(Lhs::kRegisterLanes == 1,
405*5f39d1b3SJooyung Han                   "This path is only for scalar values");
406*5f39d1b3SJooyung Han     static_assert(Rhs::kRegisterLanes == 1,
407*5f39d1b3SJooyung Han                   "This path is only for scalar values");
408*5f39d1b3SJooyung Han 
409*5f39d1b3SJooyung Han     static_assert(LhsRows == Rows || LhsRows == 1, "");
410*5f39d1b3SJooyung Han     static_assert(RhsRows == Rows || RhsRows == 1, "");
411*5f39d1b3SJooyung Han     static_assert(LhsCols == Cols || LhsCols == 1, "");
412*5f39d1b3SJooyung Han     static_assert(RhsCols == Cols || RhsCols == 1, "");
413*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
414*5f39d1b3SJooyung Han       const int lhs_c = LhsCols == Cols ? c : 0;
415*5f39d1b3SJooyung Han       const int rhs_c = RhsCols == Cols ? c : 0;
416*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
417*5f39d1b3SJooyung Han         const int lhs_r = LhsRows == Rows ? r : 0;
418*5f39d1b3SJooyung Han         const int rhs_r = RhsRows == Rows ? r : 0;
419*5f39d1b3SJooyung Han         MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows],
420*5f39d1b3SJooyung Han                rhs.buf.reg[rhs_r + rhs_c * RhsRows],
421*5f39d1b3SJooyung Han                &acc->buf.reg[r + c * Rows]);
422*5f39d1b3SJooyung Han       }
423*5f39d1b3SJooyung Han     }
424*5f39d1b3SJooyung Han   }
425*5f39d1b3SJooyung Han };
426*5f39d1b3SJooyung Han 
427*5f39d1b3SJooyung Han template <typename Lhs, typename Rhs, typename Acc>
428*5f39d1b3SJooyung Han void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
429*5f39d1b3SJooyung Han   using Flip = FlipLhsRhs<Lhs, Rhs>;
430*5f39d1b3SJooyung Han   BroadcastMulAddImpl<typename Flip::FlippedLhsType,
431*5f39d1b3SJooyung Han                       typename Flip::FlippedRhsType,
432*5f39d1b3SJooyung Han                       Acc>::Run(Flip::FlippedLhs(lhs, rhs),
433*5f39d1b3SJooyung Han                                 Flip::FlippedRhs(lhs, rhs), acc);
434*5f39d1b3SJooyung Han }
435*5f39d1b3SJooyung Han 
436*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
437*5f39d1b3SJooyung Han struct LoadImpl {
438*5f39d1b3SJooyung Han   static_assert(std::is_same<SrcObjectType, void>::value,
439*5f39d1b3SJooyung Han                 "This generic impl should never be hit");
440*5f39d1b3SJooyung Han };
441*5f39d1b3SJooyung Han 
442*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
443*5f39d1b3SJooyung Han struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
444*5f39d1b3SJooyung Han                 MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
445*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
446*5f39d1b3SJooyung Han   using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
447*5f39d1b3SJooyung Han   static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
448*5f39d1b3SJooyung Han     RegisterBlockType result;
449*5f39d1b3SJooyung Han     int i = 0;
450*5f39d1b3SJooyung Han     for (int c = 0; c < Cols; c++) {
451*5f39d1b3SJooyung Han       const ScalarType* src_ptr = src.data(row, col + c);
452*5f39d1b3SJooyung Han       for (int r = 0; r < Rows; r++) {
453*5f39d1b3SJooyung Han         result.buf.reg[i++] = *src_ptr++;
454*5f39d1b3SJooyung Han       }
455*5f39d1b3SJooyung Han     }
456*5f39d1b3SJooyung Han     return result;
457*5f39d1b3SJooyung Han   }
458*5f39d1b3SJooyung Han };
459*5f39d1b3SJooyung Han 
460*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
461*5f39d1b3SJooyung Han           VectorShape Shape>
462*5f39d1b3SJooyung Han struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
463*5f39d1b3SJooyung Han                 VectorMap<SrcScalarType, Shape>> {
464*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
465*5f39d1b3SJooyung Han   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
466*5f39d1b3SJooyung Han   static RegisterBlockType Run(const SrcObjectType& src, int pos) {
467*5f39d1b3SJooyung Han     static_assert(Shape == VectorShape::Col || Rows == 1, "");
468*5f39d1b3SJooyung Han     static_assert(Shape == VectorShape::Row || Cols == 1, "");
469*5f39d1b3SJooyung Han     RegisterBlockType result;
470*5f39d1b3SJooyung Han     for (int i = 0; i < Rows * Cols; i++) {
471*5f39d1b3SJooyung Han       result.buf.reg[i] = src(pos + i);
472*5f39d1b3SJooyung Han     }
473*5f39d1b3SJooyung Han     return result;
474*5f39d1b3SJooyung Han   }
475*5f39d1b3SJooyung Han };
476*5f39d1b3SJooyung Han 
477*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
478*5f39d1b3SJooyung Han           VectorShape Shape>
479*5f39d1b3SJooyung Han struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
480*5f39d1b3SJooyung Han                 VectorDup<SrcScalarType, Shape>> {
481*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
482*5f39d1b3SJooyung Han   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
483*5f39d1b3SJooyung Han   static RegisterBlockType Run(const SrcObjectType& src, int) {
484*5f39d1b3SJooyung Han     static_assert(Shape == VectorShape::Col || Rows == 1, "");
485*5f39d1b3SJooyung Han     static_assert(Shape == VectorShape::Row || Cols == 1, "");
486*5f39d1b3SJooyung Han     RegisterBlockType result;
487*5f39d1b3SJooyung Han     for (int i = 0; i < Rows * Cols; i++) {
488*5f39d1b3SJooyung Han       result.buf.reg[i] = src(0);
489*5f39d1b3SJooyung Han     }
490*5f39d1b3SJooyung Han     return result;
491*5f39d1b3SJooyung Han   }
492*5f39d1b3SJooyung Han };
493*5f39d1b3SJooyung Han 
494*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
495*5f39d1b3SJooyung Han RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
496*5f39d1b3SJooyung Han   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
497*5f39d1b3SJooyung Han }
498*5f39d1b3SJooyung Han 
499*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
500*5f39d1b3SJooyung Han RegisterBlockType Load(const SrcObjectType& src, int pos) {
501*5f39d1b3SJooyung Han   return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
502*5f39d1b3SJooyung Han }
503*5f39d1b3SJooyung Han 
504*5f39d1b3SJooyung Han template <typename RegisterBlockType>
505*5f39d1b3SJooyung Han struct LoadContiguousImpl {
506*5f39d1b3SJooyung Han   using ScalarType = typename RegisterBlockType::ScalarType;
507*5f39d1b3SJooyung Han   static_assert(RegisterBlockType::kRegisterLanes == 1,
508*5f39d1b3SJooyung Han                 "This path is only for scalar values");
509*5f39d1b3SJooyung Han   static RegisterBlockType Run(const ScalarType* src) {
510*5f39d1b3SJooyung Han     RegisterBlockType result;
511*5f39d1b3SJooyung Han     for (int i = 0; i < RegisterBlockType::kScalarCount; i++) {
512*5f39d1b3SJooyung Han       result.buf.reg[i] = src[i];
513*5f39d1b3SJooyung Han     }
514*5f39d1b3SJooyung Han     return result;
515*5f39d1b3SJooyung Han   }
516*5f39d1b3SJooyung Han };
517*5f39d1b3SJooyung Han 
518*5f39d1b3SJooyung Han template <typename RegisterBlockType>
519*5f39d1b3SJooyung Han RegisterBlockType LoadContiguous(
520*5f39d1b3SJooyung Han     const typename RegisterBlockType::ScalarType* src) {
521*5f39d1b3SJooyung Han   return LoadContiguousImpl<RegisterBlockType>::Run(src);
522*5f39d1b3SJooyung Han }
523*5f39d1b3SJooyung Han 
524*5f39d1b3SJooyung Han template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
525*5f39d1b3SJooyung Han struct LoadForBroadcastingShape {};
526*5f39d1b3SJooyung Han 
527*5f39d1b3SJooyung Han template <int BroadcastRows, int BroadcastCols, typename ScalarType,
528*5f39d1b3SJooyung Han           VectorShape Shape>
529*5f39d1b3SJooyung Han struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
530*5f39d1b3SJooyung Han                                 VectorMap<ScalarType, Shape>> {
531*5f39d1b3SJooyung Han   static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1;
532*5f39d1b3SJooyung Han   static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1;
533*5f39d1b3SJooyung Han };
534*5f39d1b3SJooyung Han 
535*5f39d1b3SJooyung Han template <int BroadcastRows, int BroadcastCols, typename ScalarType,
536*5f39d1b3SJooyung Han           VectorShape Shape>
537*5f39d1b3SJooyung Han struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
538*5f39d1b3SJooyung Han                                 VectorDup<ScalarType, Shape>> {
539*5f39d1b3SJooyung Han   static constexpr int kRows = 1;
540*5f39d1b3SJooyung Han   static constexpr int kCols = 1;
541*5f39d1b3SJooyung Han };
542*5f39d1b3SJooyung Han 
543*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
544*5f39d1b3SJooyung Han struct LoadForBroadcastingRegisterBlock {
545*5f39d1b3SJooyung Han   using Shape =
546*5f39d1b3SJooyung Han       LoadForBroadcastingShape<RegisterBlockType::kRows,
547*5f39d1b3SJooyung Han                                RegisterBlockType::kCols, SrcObjectType>;
548*5f39d1b3SJooyung Han   using ScalarType = typename RegisterBlockType::ScalarType;
549*5f39d1b3SJooyung Han   using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
550*5f39d1b3SJooyung Han };
551*5f39d1b3SJooyung Han 
552*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
553*5f39d1b3SJooyung Han struct LoadForBroadcastingImpl {
554*5f39d1b3SJooyung Han   static_assert(std::is_same<SrcObjectType, void>::value,
555*5f39d1b3SJooyung Han                 "This generic impl should never be hit");
556*5f39d1b3SJooyung Han };
557*5f39d1b3SJooyung Han 
558*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
559*5f39d1b3SJooyung Han           VectorShape Shape>
560*5f39d1b3SJooyung Han struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
561*5f39d1b3SJooyung Han                                VectorMap<SrcScalarType, Shape>> {
562*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
563*5f39d1b3SJooyung Han   using SrcObjectType = VectorMap<SrcScalarType, Shape>;
564*5f39d1b3SJooyung Han   using ResultBlockType =
565*5f39d1b3SJooyung Han       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
566*5f39d1b3SJooyung Han                                                 SrcObjectType>::Type;
567*5f39d1b3SJooyung Han   static_assert(ResultBlockType::kRegisterLanes == 1,
568*5f39d1b3SJooyung Han                 "This path is only for scalar values");
569*5f39d1b3SJooyung Han   static ResultBlockType Run(const SrcObjectType& src, int pos) {
570*5f39d1b3SJooyung Han     ResultBlockType result;
571*5f39d1b3SJooyung Han     for (int c = 0; c < ResultBlockType::kCols; c++) {
572*5f39d1b3SJooyung Han       for (int r = 0; r < ResultBlockType::kRows; r++) {
573*5f39d1b3SJooyung Han         const int i = Shape == VectorShape::Col ? r : c;
574*5f39d1b3SJooyung Han         result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
575*5f39d1b3SJooyung Han       }
576*5f39d1b3SJooyung Han     }
577*5f39d1b3SJooyung Han     return result;
578*5f39d1b3SJooyung Han   }
579*5f39d1b3SJooyung Han };
580*5f39d1b3SJooyung Han 
581*5f39d1b3SJooyung Han template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
582*5f39d1b3SJooyung Han           VectorShape Shape>
583*5f39d1b3SJooyung Han struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
584*5f39d1b3SJooyung Han                                VectorDup<SrcScalarType, Shape>> {
585*5f39d1b3SJooyung Han   using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
586*5f39d1b3SJooyung Han   using SrcObjectType = VectorDup<SrcScalarType, Shape>;
587*5f39d1b3SJooyung Han   using ResultBlockType =
588*5f39d1b3SJooyung Han       typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
589*5f39d1b3SJooyung Han                                                 SrcObjectType>::Type;
590*5f39d1b3SJooyung Han   static_assert(ResultBlockType::kRegisterLanes == 1,
591*5f39d1b3SJooyung Han                 "This path is only for scalar values");
592*5f39d1b3SJooyung Han   static ResultBlockType Run(const SrcObjectType& src, int) {
593*5f39d1b3SJooyung Han     ResultBlockType result;
594*5f39d1b3SJooyung Han     for (int c = 0; c < ResultBlockType::kCols; c++) {
595*5f39d1b3SJooyung Han       for (int r = 0; r < ResultBlockType::kRows; r++) {
596*5f39d1b3SJooyung Han         result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
597*5f39d1b3SJooyung Han       }
598*5f39d1b3SJooyung Han     }
599*5f39d1b3SJooyung Han     return result;
600*5f39d1b3SJooyung Han   }
601*5f39d1b3SJooyung Han };
602*5f39d1b3SJooyung Han 
603*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
604*5f39d1b3SJooyung Han typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
605*5f39d1b3SJooyung Han                                           SrcObjectType>::Type
606*5f39d1b3SJooyung Han LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
607*5f39d1b3SJooyung Han   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
608*5f39d1b3SJooyung Han       src, row, col);
609*5f39d1b3SJooyung Han }
610*5f39d1b3SJooyung Han 
611*5f39d1b3SJooyung Han template <typename RegisterBlockType, typename SrcObjectType>
612*5f39d1b3SJooyung Han typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
613*5f39d1b3SJooyung Han                                           SrcObjectType>::Type
614*5f39d1b3SJooyung Han LoadForBroadcasting(const SrcObjectType& src, int pos) {
615*5f39d1b3SJooyung Han   return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
616*5f39d1b3SJooyung Han                                                                         pos);
617*5f39d1b3SJooyung Han }
618*5f39d1b3SJooyung Han 
619*5f39d1b3SJooyung Han template <int ConstantValue, typename RegisterBlockType>
620*5f39d1b3SJooyung Han struct AddConstantImpl {
621*5f39d1b3SJooyung Han   static void Run(RegisterBlockType* block) {
622*5f39d1b3SJooyung Han     using RegisterType = typename RegisterBlockType::RegisterType;
623*5f39d1b3SJooyung Han     const RegisterType dup = Dup<RegisterType>(ConstantValue);
624*5f39d1b3SJooyung Han     for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
625*5f39d1b3SJooyung Han       block->buf.reg[i] = Add(block->buf.reg[i], dup);
626*5f39d1b3SJooyung Han     }
627*5f39d1b3SJooyung Han   }
628*5f39d1b3SJooyung Han };
629*5f39d1b3SJooyung Han 
630*5f39d1b3SJooyung Han template <typename RegisterBlockType>
631*5f39d1b3SJooyung Han struct AddConstantImpl<0, RegisterBlockType> {
632*5f39d1b3SJooyung Han   static void Run(RegisterBlockType*) {
633*5f39d1b3SJooyung Han     // This is a no-op.
634*5f39d1b3SJooyung Han   }
635*5f39d1b3SJooyung Han };
636*5f39d1b3SJooyung Han 
637*5f39d1b3SJooyung Han template <int ConstantValue, typename RegisterBlockType>
638*5f39d1b3SJooyung Han void AddConstant(RegisterBlockType* block) {
639*5f39d1b3SJooyung Han   AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
640*5f39d1b3SJooyung Han }
641*5f39d1b3SJooyung Han 
642*5f39d1b3SJooyung Han template <int N>
643*5f39d1b3SJooyung Han using RegBufferInt32 = RegisterBuffer<std::int32_t, N>;
644*5f39d1b3SJooyung Han template <int N>
645*5f39d1b3SJooyung Han using RegBufferInt16 = RegisterBuffer<std::int16_t, N>;
646*5f39d1b3SJooyung Han template <int N>
647*5f39d1b3SJooyung Han using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>;
648*5f39d1b3SJooyung Han template <int N>
649*5f39d1b3SJooyung Han using RegBufferInt8 = RegisterBuffer<std::int8_t, N>;
650*5f39d1b3SJooyung Han template <int R, int C>
651*5f39d1b3SJooyung Han using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>;
652*5f39d1b3SJooyung Han template <int R, int C>
653*5f39d1b3SJooyung Han using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>;
654*5f39d1b3SJooyung Han template <int R, int C>
655*5f39d1b3SJooyung Han using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>;
656*5f39d1b3SJooyung Han template <int R, int C>
657*5f39d1b3SJooyung Han using RegBlockInt8 = RegisterBlock<std::int8_t, R, C>;
658*5f39d1b3SJooyung Han 
659*5f39d1b3SJooyung Han }  // end namespace gemmlowp
660*5f39d1b3SJooyung Han 
661*5f39d1b3SJooyung Han #if defined GEMMLOWP_NEON
662*5f39d1b3SJooyung Han #include "simd_wrappers_neon.h"
663*5f39d1b3SJooyung Han #elif defined GEMMLOWP_SSE4
664*5f39d1b3SJooyung Han #include "simd_wrappers_sse.h"
665*5f39d1b3SJooyung Han #elif defined GEMMLOWP_MSA
666*5f39d1b3SJooyung Han #include "simd_wrappers_msa.h"
667*5f39d1b3SJooyung Han #endif
668*5f39d1b3SJooyung Han 
669*5f39d1b3SJooyung Han #endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
670