xref: /aosp_15_r20/external/executorch/kernels/optimized/blas/BlasKernel.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/BlasKernel.h>
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #ifdef __aarch64__
12*523fa7a6SAndroid Build Coastguard Worker #include <arm_neon.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <cpuinfo.h>
14*523fa7a6SAndroid Build Coastguard Worker #endif
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker using torch::executor::BFloat16;
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
19*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
20*523fa7a6SAndroid Build Coastguard Worker namespace internal {
21*523fa7a6SAndroid Build Coastguard Worker #ifdef __aarch64__
f32_fma(float32x4_t a,float32x4_t b,float32x4_t c)22*523fa7a6SAndroid Build Coastguard Worker static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
23*523fa7a6SAndroid Build Coastguard Worker #ifdef __ARM_FEATURE_FMA
24*523fa7a6SAndroid Build Coastguard Worker   return vfmaq_f32(a, b, c);
25*523fa7a6SAndroid Build Coastguard Worker #else
26*523fa7a6SAndroid Build Coastguard Worker   return vaddq_f32(a, vmulq_f32(b, c));
27*523fa7a6SAndroid Build Coastguard Worker #endif // __ARM_FEATURE_FMA
28*523fa7a6SAndroid Build Coastguard Worker }
29*523fa7a6SAndroid Build Coastguard Worker 
// The reduce() overload and dot_with_fp32_arith() below are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] for the required
// notice (the note is not reproduced in this file; TODO(review):
// confirm where it lives).
34*523fa7a6SAndroid Build Coastguard Worker 
// Loop-shape constants for the float32 accumulation kernels. reduce() halves
// the number of live accumulator registers once per step, so the register
// count is also kept as a base-2 shift.

// Elements consumed per main-loop iteration: 2^5 == 32.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration =
    1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

// float32 lanes held by one 128-bit NEON register: 2^2 == 4.
static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister =
    1 << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

// Accumulator registers per iteration, processed as pairs by the inner loops.
static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
    2 * kF32RegisterPairsPerIteration;
static constexpr auto kF32RegistersPerIterationShift =
    kF32ElementsPerIterationShift - kF32ElementsPerRegisterShift;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
54*523fa7a6SAndroid Build Coastguard Worker 
reduce(float32x4_t x[kF32RegistersPerIteration])55*523fa7a6SAndroid Build Coastguard Worker static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
56*523fa7a6SAndroid Build Coastguard Worker   int offset = kF32RegistersPerIteration;
57*523fa7a6SAndroid Build Coastguard Worker   utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
58*523fa7a6SAndroid Build Coastguard Worker       [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
59*523fa7a6SAndroid Build Coastguard Worker         offset /= 2;
60*523fa7a6SAndroid Build Coastguard Worker         for (int i = 0; i < offset; ++i) {
61*523fa7a6SAndroid Build Coastguard Worker           x[i] = vaddq_f32(x[i], x[offset + i]);
62*523fa7a6SAndroid Build Coastguard Worker         }
63*523fa7a6SAndroid Build Coastguard Worker       });
64*523fa7a6SAndroid Build Coastguard Worker   return vaddvq_f32(x[0]);
65*523fa7a6SAndroid Build Coastguard Worker }
66*523fa7a6SAndroid Build Coastguard Worker 
to_bfloat16(uint16x4_t u16)67*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
68*523fa7a6SAndroid Build Coastguard Worker   int32x4_t shift = vdupq_n_s32(16);
69*523fa7a6SAndroid Build Coastguard Worker   return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
70*523fa7a6SAndroid Build Coastguard Worker }
71*523fa7a6SAndroid Build Coastguard Worker 
72*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a,uint16x4_t b,uint16x4_t c)73*523fa7a6SAndroid Build Coastguard Worker f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
74*523fa7a6SAndroid Build Coastguard Worker   return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
75*523fa7a6SAndroid Build Coastguard Worker }
76*523fa7a6SAndroid Build Coastguard Worker 
// Function attribute that allows the compiler to emit Armv8.2-A BF16
// instructions (such as those behind vbfdotq_f32) inside the annotated
// function, independent of the translation unit's baseline target.
#define ET_TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
79*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a,bfloat16x8_t b,bfloat16x8_t c)80*523fa7a6SAndroid Build Coastguard Worker f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
81*523fa7a6SAndroid Build Coastguard Worker   return vbfdotq_f32(a, b, c);
82*523fa7a6SAndroid Build Coastguard Worker }
83*523fa7a6SAndroid Build Coastguard Worker 
84*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)85*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop_bfdot(
86*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec1,
87*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec2,
88*523fa7a6SAndroid Build Coastguard Worker     float32x4_t sum[kF32RegistersPerIteration],
89*523fa7a6SAndroid Build Coastguard Worker     int registerPairIndex) {
90*523fa7a6SAndroid Build Coastguard Worker   const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
91*523fa7a6SAndroid Build Coastguard Worker       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
92*523fa7a6SAndroid Build Coastguard Worker   const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
93*523fa7a6SAndroid Build Coastguard Worker       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
94*523fa7a6SAndroid Build Coastguard Worker   sum[registerPairIndex] =
95*523fa7a6SAndroid Build Coastguard Worker       f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
96*523fa7a6SAndroid Build Coastguard Worker }
97*523fa7a6SAndroid Build Coastguard Worker 
dot_with_fp32_arith_main_inner_loop_no_bfdot(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)98*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
99*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec1,
100*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec2,
101*523fa7a6SAndroid Build Coastguard Worker     float32x4_t sum[kF32RegistersPerIteration],
102*523fa7a6SAndroid Build Coastguard Worker     int registerPairIndex) {
103*523fa7a6SAndroid Build Coastguard Worker   const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
104*523fa7a6SAndroid Build Coastguard Worker       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
105*523fa7a6SAndroid Build Coastguard Worker   const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
106*523fa7a6SAndroid Build Coastguard Worker       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
107*523fa7a6SAndroid Build Coastguard Worker 
108*523fa7a6SAndroid Build Coastguard Worker   sum[2 * registerPairIndex] = f32_fma_bf16(
109*523fa7a6SAndroid Build Coastguard Worker       sum[2 * registerPairIndex],
110*523fa7a6SAndroid Build Coastguard Worker       vget_low_u16(temp_vec1),
111*523fa7a6SAndroid Build Coastguard Worker       vget_low_u16(temp_vec2));
112*523fa7a6SAndroid Build Coastguard Worker   sum[2 * registerPairIndex + 1] = f32_fma_bf16(
113*523fa7a6SAndroid Build Coastguard Worker       sum[2 * registerPairIndex + 1],
114*523fa7a6SAndroid Build Coastguard Worker       vget_high_u16(temp_vec1),
115*523fa7a6SAndroid Build Coastguard Worker       vget_high_u16(temp_vec2));
116*523fa7a6SAndroid Build Coastguard Worker }
117*523fa7a6SAndroid Build Coastguard Worker 
118*523fa7a6SAndroid Build Coastguard Worker template <bool useBfdot>
119*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)120*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop(
121*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec1,
122*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec2,
123*523fa7a6SAndroid Build Coastguard Worker     float32x4_t sum[kF32RegistersPerIteration],
124*523fa7a6SAndroid Build Coastguard Worker     int registerPairIndex) {
125*523fa7a6SAndroid Build Coastguard Worker   if constexpr (useBfdot) {
126*523fa7a6SAndroid Build Coastguard Worker     dot_with_fp32_arith_main_inner_loop_bfdot(
127*523fa7a6SAndroid Build Coastguard Worker         vec1, vec2, sum, registerPairIndex);
128*523fa7a6SAndroid Build Coastguard Worker   } else {
129*523fa7a6SAndroid Build Coastguard Worker     dot_with_fp32_arith_main_inner_loop_no_bfdot(
130*523fa7a6SAndroid Build Coastguard Worker         vec1, vec2, sum, registerPairIndex);
131*523fa7a6SAndroid Build Coastguard Worker   }
132*523fa7a6SAndroid Build Coastguard Worker }
133*523fa7a6SAndroid Build Coastguard Worker 
dot_with_fp32_arith_vectorized_tail_inner_loop(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t * tailSum,int idx)134*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
135*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec1,
136*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec2,
137*523fa7a6SAndroid Build Coastguard Worker     float32x4_t* tailSum,
138*523fa7a6SAndroid Build Coastguard Worker     int idx) {
139*523fa7a6SAndroid Build Coastguard Worker   const auto temp_vec1 =
140*523fa7a6SAndroid Build Coastguard Worker       vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
141*523fa7a6SAndroid Build Coastguard Worker   const auto temp_vec2 =
142*523fa7a6SAndroid Build Coastguard Worker       vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
143*523fa7a6SAndroid Build Coastguard Worker   *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
144*523fa7a6SAndroid Build Coastguard Worker }
145*523fa7a6SAndroid Build Coastguard Worker 
146*523fa7a6SAndroid Build Coastguard Worker namespace {
147*523fa7a6SAndroid Build Coastguard Worker template <int n>
148*523fa7a6SAndroid Build Coastguard Worker struct ForcedUnrollTargetBFloat16 {
149*523fa7a6SAndroid Build Coastguard Worker   template <typename Func>
operator ()executorch::cpublas::internal::__anon6f41bdbc0211::ForcedUnrollTargetBFloat16150*523fa7a6SAndroid Build Coastguard Worker   ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
151*523fa7a6SAndroid Build Coastguard Worker     ForcedUnrollTargetBFloat16<n - 1>{}(f);
152*523fa7a6SAndroid Build Coastguard Worker     f(n - 1);
153*523fa7a6SAndroid Build Coastguard Worker   }
154*523fa7a6SAndroid Build Coastguard Worker };
155*523fa7a6SAndroid Build Coastguard Worker 
156*523fa7a6SAndroid Build Coastguard Worker template <>
157*523fa7a6SAndroid Build Coastguard Worker struct ForcedUnrollTargetBFloat16<1> {
158*523fa7a6SAndroid Build Coastguard Worker   template <typename Func>
operator ()executorch::cpublas::internal::__anon6f41bdbc0211::ForcedUnrollTargetBFloat16159*523fa7a6SAndroid Build Coastguard Worker   ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
160*523fa7a6SAndroid Build Coastguard Worker     f(0);
161*523fa7a6SAndroid Build Coastguard Worker   }
162*523fa7a6SAndroid Build Coastguard Worker };
163*523fa7a6SAndroid Build Coastguard Worker 
164*523fa7a6SAndroid Build Coastguard Worker } // namespace
165*523fa7a6SAndroid Build Coastguard Worker 
166*523fa7a6SAndroid Build Coastguard Worker template <typename T, bool useBFloat16Dot>
167*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith(const T * vec1,const T * vec2,int64_t len)168*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
169*523fa7a6SAndroid Build Coastguard Worker   float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
170*523fa7a6SAndroid Build Coastguard Worker   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
171*523fa7a6SAndroid Build Coastguard Worker   for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
172*523fa7a6SAndroid Build Coastguard Worker     const auto* vec1_ = vec1 + j;
173*523fa7a6SAndroid Build Coastguard Worker     const auto* vec2_ = vec2 + j;
174*523fa7a6SAndroid Build Coastguard Worker     ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
175*523fa7a6SAndroid Build Coastguard Worker         [vec1_, vec2_, &sum](auto k)
176*523fa7a6SAndroid Build Coastguard Worker             ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
177*523fa7a6SAndroid Build Coastguard Worker               dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
178*523fa7a6SAndroid Build Coastguard Worker                   vec1_, vec2_, sum, k);
179*523fa7a6SAndroid Build Coastguard Worker             });
180*523fa7a6SAndroid Build Coastguard Worker   }
181*523fa7a6SAndroid Build Coastguard Worker   auto reducedSum = reduce(sum);
182*523fa7a6SAndroid Build Coastguard Worker 
183*523fa7a6SAndroid Build Coastguard Worker   // First-tier tail fixup: make sure we handle workloads that can
184*523fa7a6SAndroid Build Coastguard Worker   // benefit from vectorization, but don't fit into our fully unrolled
185*523fa7a6SAndroid Build Coastguard Worker   // loop above.
186*523fa7a6SAndroid Build Coastguard Worker   float32x4_t tailSum = vdupq_n_f32(0);
187*523fa7a6SAndroid Build Coastguard Worker   const auto len_aligned_4 = len & ~3;
188*523fa7a6SAndroid Build Coastguard Worker   for (int j = len_aligned; j < len_aligned_4; j += 4) {
189*523fa7a6SAndroid Build Coastguard Worker     dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
190*523fa7a6SAndroid Build Coastguard Worker   }
191*523fa7a6SAndroid Build Coastguard Worker   auto reducedTail = vpaddq_f32(tailSum, tailSum);
192*523fa7a6SAndroid Build Coastguard Worker   reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
193*523fa7a6SAndroid Build Coastguard Worker 
194*523fa7a6SAndroid Build Coastguard Worker   // Second-tier tail fixup: handle all workloads.
195*523fa7a6SAndroid Build Coastguard Worker   for (int j = len_aligned_4; j < len; ++j) {
196*523fa7a6SAndroid Build Coastguard Worker     reducedSum += vec1[j] * vec2[j];
197*523fa7a6SAndroid Build Coastguard Worker   }
198*523fa7a6SAndroid Build Coastguard Worker   return reducedSum;
199*523fa7a6SAndroid Build Coastguard Worker }
200*523fa7a6SAndroid Build Coastguard Worker 
bf16_dot_with_fp32_arith(const BFloat16 * vec1,const BFloat16 * vec2,int64_t len)201*523fa7a6SAndroid Build Coastguard Worker float bf16_dot_with_fp32_arith(
202*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec1,
203*523fa7a6SAndroid Build Coastguard Worker     const BFloat16* vec2,
204*523fa7a6SAndroid Build Coastguard Worker     int64_t len) {
205*523fa7a6SAndroid Build Coastguard Worker   if (cpuinfo_has_arm_bf16()) {
206*523fa7a6SAndroid Build Coastguard Worker     return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
207*523fa7a6SAndroid Build Coastguard Worker   } else {
208*523fa7a6SAndroid Build Coastguard Worker     return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
209*523fa7a6SAndroid Build Coastguard Worker   }
210*523fa7a6SAndroid Build Coastguard Worker }
211*523fa7a6SAndroid Build Coastguard Worker #endif // __aarch64__
212*523fa7a6SAndroid Build Coastguard Worker } // namespace internal
213*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
214*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
215