1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/optimized/blas/BlasKernel.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #ifdef __aarch64__
12*523fa7a6SAndroid Build Coastguard Worker #include <arm_neon.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <cpuinfo.h>
14*523fa7a6SAndroid Build Coastguard Worker #endif
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker using torch::executor::BFloat16;
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
19*523fa7a6SAndroid Build Coastguard Worker namespace cpublas {
20*523fa7a6SAndroid Build Coastguard Worker namespace internal {
21*523fa7a6SAndroid Build Coastguard Worker #ifdef __aarch64__
f32_fma(float32x4_t a,float32x4_t b,float32x4_t c)22*523fa7a6SAndroid Build Coastguard Worker static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
23*523fa7a6SAndroid Build Coastguard Worker #ifdef __ARM_FEATURE_FMA
24*523fa7a6SAndroid Build Coastguard Worker return vfmaq_f32(a, b, c);
25*523fa7a6SAndroid Build Coastguard Worker #else
26*523fa7a6SAndroid Build Coastguard Worker return vaddq_f32(a, vmulq_f32(b, c));
27*523fa7a6SAndroid Build Coastguard Worker #endif // __ARM_FEATURE_FMA
28*523fa7a6SAndroid Build Coastguard Worker }
29*523fa7a6SAndroid Build Coastguard Worker
// The reduce() overload and dot_with_fp32_arith() below are adapted from
// llama.cpp's ggml_vec_dot_f32 and its surrounding utility functions. See
// NOTE [ GGML Copyright Notice ] for the required notice (it is not
// reproduced in this file; TODO confirm it is carried in BlasKernel.h or
// another file shipped alongside this one).
34*523fa7a6SAndroid Build Coastguard Worker
// Register-blocking geometry for the fp32-accumulating dot product. The
// shifts are kept next to the counts because reduce() folds the accumulator
// registers in kF32RegistersPerIterationShift halving steps (8 -> 4 -> 2 -> 1).

// Elements consumed per main-loop iteration: 1 << 5 == 32.
static constexpr int kF32ElementsPerIterationShift = 5;
static constexpr int kF32ElementsPerIteration =
    1 << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

// fp32 lanes per float32x4_t register: 1 << 2 == 4.
static constexpr int kF32ElementsPerRegisterShift = 2;
static constexpr int kF32ElementsPerRegister = 1
    << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

// Accumulators are addressed in pairs (one 8-lane bf16 load feeds two fp32
// registers on the non-BFDOT path), so 4 pairs == 8 registers per iteration.
static constexpr int kF32RegisterPairsPerIteration = 4;
static constexpr int kF32RegistersPerIteration =
    2 * kF32RegisterPairsPerIteration;
// log2(registers) == log2(elements per iteration) - log2(elements per register).
static constexpr int kF32RegistersPerIterationShift =
    kF32ElementsPerIterationShift - kF32ElementsPerRegisterShift;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
54*523fa7a6SAndroid Build Coastguard Worker
reduce(float32x4_t x[kF32RegistersPerIteration])55*523fa7a6SAndroid Build Coastguard Worker static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
56*523fa7a6SAndroid Build Coastguard Worker int offset = kF32RegistersPerIteration;
57*523fa7a6SAndroid Build Coastguard Worker utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
58*523fa7a6SAndroid Build Coastguard Worker [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
59*523fa7a6SAndroid Build Coastguard Worker offset /= 2;
60*523fa7a6SAndroid Build Coastguard Worker for (int i = 0; i < offset; ++i) {
61*523fa7a6SAndroid Build Coastguard Worker x[i] = vaddq_f32(x[i], x[offset + i]);
62*523fa7a6SAndroid Build Coastguard Worker }
63*523fa7a6SAndroid Build Coastguard Worker });
64*523fa7a6SAndroid Build Coastguard Worker return vaddvq_f32(x[0]);
65*523fa7a6SAndroid Build Coastguard Worker }
66*523fa7a6SAndroid Build Coastguard Worker
to_bfloat16(uint16x4_t u16)67*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
68*523fa7a6SAndroid Build Coastguard Worker int32x4_t shift = vdupq_n_s32(16);
69*523fa7a6SAndroid Build Coastguard Worker return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
70*523fa7a6SAndroid Build Coastguard Worker }
71*523fa7a6SAndroid Build Coastguard Worker
72*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a,uint16x4_t b,uint16x4_t c)73*523fa7a6SAndroid Build Coastguard Worker f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
74*523fa7a6SAndroid Build Coastguard Worker return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
75*523fa7a6SAndroid Build Coastguard Worker }
76*523fa7a6SAndroid Build Coastguard Worker
77*523fa7a6SAndroid Build Coastguard Worker #define ET_TARGET_ARM_BF16_ATTRIBUTE \
78*523fa7a6SAndroid Build Coastguard Worker __attribute__((target("arch=armv8.2-a+bf16")))
79*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a,bfloat16x8_t b,bfloat16x8_t c)80*523fa7a6SAndroid Build Coastguard Worker f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
81*523fa7a6SAndroid Build Coastguard Worker return vbfdotq_f32(a, b, c);
82*523fa7a6SAndroid Build Coastguard Worker }
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)85*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop_bfdot(
86*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec1,
87*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec2,
88*523fa7a6SAndroid Build Coastguard Worker float32x4_t sum[kF32RegistersPerIteration],
89*523fa7a6SAndroid Build Coastguard Worker int registerPairIndex) {
90*523fa7a6SAndroid Build Coastguard Worker const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast<const __bf16*>(
91*523fa7a6SAndroid Build Coastguard Worker &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
92*523fa7a6SAndroid Build Coastguard Worker const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast<const __bf16*>(
93*523fa7a6SAndroid Build Coastguard Worker &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
94*523fa7a6SAndroid Build Coastguard Worker sum[registerPairIndex] =
95*523fa7a6SAndroid Build Coastguard Worker f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
96*523fa7a6SAndroid Build Coastguard Worker }
97*523fa7a6SAndroid Build Coastguard Worker
dot_with_fp32_arith_main_inner_loop_no_bfdot(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)98*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
99*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec1,
100*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec2,
101*523fa7a6SAndroid Build Coastguard Worker float32x4_t sum[kF32RegistersPerIteration],
102*523fa7a6SAndroid Build Coastguard Worker int registerPairIndex) {
103*523fa7a6SAndroid Build Coastguard Worker const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
104*523fa7a6SAndroid Build Coastguard Worker &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
105*523fa7a6SAndroid Build Coastguard Worker const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
106*523fa7a6SAndroid Build Coastguard Worker &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
107*523fa7a6SAndroid Build Coastguard Worker
108*523fa7a6SAndroid Build Coastguard Worker sum[2 * registerPairIndex] = f32_fma_bf16(
109*523fa7a6SAndroid Build Coastguard Worker sum[2 * registerPairIndex],
110*523fa7a6SAndroid Build Coastguard Worker vget_low_u16(temp_vec1),
111*523fa7a6SAndroid Build Coastguard Worker vget_low_u16(temp_vec2));
112*523fa7a6SAndroid Build Coastguard Worker sum[2 * registerPairIndex + 1] = f32_fma_bf16(
113*523fa7a6SAndroid Build Coastguard Worker sum[2 * registerPairIndex + 1],
114*523fa7a6SAndroid Build Coastguard Worker vget_high_u16(temp_vec1),
115*523fa7a6SAndroid Build Coastguard Worker vget_high_u16(temp_vec2));
116*523fa7a6SAndroid Build Coastguard Worker }
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker template <bool useBfdot>
119*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t sum[kF32RegistersPerIteration],int registerPairIndex)120*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop(
121*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec1,
122*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec2,
123*523fa7a6SAndroid Build Coastguard Worker float32x4_t sum[kF32RegistersPerIteration],
124*523fa7a6SAndroid Build Coastguard Worker int registerPairIndex) {
125*523fa7a6SAndroid Build Coastguard Worker if constexpr (useBfdot) {
126*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop_bfdot(
127*523fa7a6SAndroid Build Coastguard Worker vec1, vec2, sum, registerPairIndex);
128*523fa7a6SAndroid Build Coastguard Worker } else {
129*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop_no_bfdot(
130*523fa7a6SAndroid Build Coastguard Worker vec1, vec2, sum, registerPairIndex);
131*523fa7a6SAndroid Build Coastguard Worker }
132*523fa7a6SAndroid Build Coastguard Worker }
133*523fa7a6SAndroid Build Coastguard Worker
dot_with_fp32_arith_vectorized_tail_inner_loop(const BFloat16 * vec1,const BFloat16 * vec2,float32x4_t * tailSum,int idx)134*523fa7a6SAndroid Build Coastguard Worker static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
135*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec1,
136*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec2,
137*523fa7a6SAndroid Build Coastguard Worker float32x4_t* tailSum,
138*523fa7a6SAndroid Build Coastguard Worker int idx) {
139*523fa7a6SAndroid Build Coastguard Worker const auto temp_vec1 =
140*523fa7a6SAndroid Build Coastguard Worker vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
141*523fa7a6SAndroid Build Coastguard Worker const auto temp_vec2 =
142*523fa7a6SAndroid Build Coastguard Worker vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
143*523fa7a6SAndroid Build Coastguard Worker *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
144*523fa7a6SAndroid Build Coastguard Worker }
145*523fa7a6SAndroid Build Coastguard Worker
146*523fa7a6SAndroid Build Coastguard Worker namespace {
147*523fa7a6SAndroid Build Coastguard Worker template <int n>
148*523fa7a6SAndroid Build Coastguard Worker struct ForcedUnrollTargetBFloat16 {
149*523fa7a6SAndroid Build Coastguard Worker template <typename Func>
operator ()executorch::cpublas::internal::__anon6f41bdbc0211::ForcedUnrollTargetBFloat16150*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
151*523fa7a6SAndroid Build Coastguard Worker ForcedUnrollTargetBFloat16<n - 1>{}(f);
152*523fa7a6SAndroid Build Coastguard Worker f(n - 1);
153*523fa7a6SAndroid Build Coastguard Worker }
154*523fa7a6SAndroid Build Coastguard Worker };
155*523fa7a6SAndroid Build Coastguard Worker
156*523fa7a6SAndroid Build Coastguard Worker template <>
157*523fa7a6SAndroid Build Coastguard Worker struct ForcedUnrollTargetBFloat16<1> {
158*523fa7a6SAndroid Build Coastguard Worker template <typename Func>
operator ()executorch::cpublas::internal::__anon6f41bdbc0211::ForcedUnrollTargetBFloat16159*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
160*523fa7a6SAndroid Build Coastguard Worker f(0);
161*523fa7a6SAndroid Build Coastguard Worker }
162*523fa7a6SAndroid Build Coastguard Worker };
163*523fa7a6SAndroid Build Coastguard Worker
164*523fa7a6SAndroid Build Coastguard Worker } // namespace
165*523fa7a6SAndroid Build Coastguard Worker
166*523fa7a6SAndroid Build Coastguard Worker template <typename T, bool useBFloat16Dot>
167*523fa7a6SAndroid Build Coastguard Worker ET_TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith(const T * vec1,const T * vec2,int64_t len)168*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
169*523fa7a6SAndroid Build Coastguard Worker float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
170*523fa7a6SAndroid Build Coastguard Worker const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
171*523fa7a6SAndroid Build Coastguard Worker for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
172*523fa7a6SAndroid Build Coastguard Worker const auto* vec1_ = vec1 + j;
173*523fa7a6SAndroid Build Coastguard Worker const auto* vec2_ = vec2 + j;
174*523fa7a6SAndroid Build Coastguard Worker ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
175*523fa7a6SAndroid Build Coastguard Worker [vec1_, vec2_, &sum](auto k)
176*523fa7a6SAndroid Build Coastguard Worker ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
177*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_main_inner_loop<useBFloat16Dot>(
178*523fa7a6SAndroid Build Coastguard Worker vec1_, vec2_, sum, k);
179*523fa7a6SAndroid Build Coastguard Worker });
180*523fa7a6SAndroid Build Coastguard Worker }
181*523fa7a6SAndroid Build Coastguard Worker auto reducedSum = reduce(sum);
182*523fa7a6SAndroid Build Coastguard Worker
183*523fa7a6SAndroid Build Coastguard Worker // First-tier tail fixup: make sure we handle workloads that can
184*523fa7a6SAndroid Build Coastguard Worker // benefit from vectorization, but don't fit into our fully unrolled
185*523fa7a6SAndroid Build Coastguard Worker // loop above.
186*523fa7a6SAndroid Build Coastguard Worker float32x4_t tailSum = vdupq_n_f32(0);
187*523fa7a6SAndroid Build Coastguard Worker const auto len_aligned_4 = len & ~3;
188*523fa7a6SAndroid Build Coastguard Worker for (int j = len_aligned; j < len_aligned_4; j += 4) {
189*523fa7a6SAndroid Build Coastguard Worker dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
190*523fa7a6SAndroid Build Coastguard Worker }
191*523fa7a6SAndroid Build Coastguard Worker auto reducedTail = vpaddq_f32(tailSum, tailSum);
192*523fa7a6SAndroid Build Coastguard Worker reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
193*523fa7a6SAndroid Build Coastguard Worker
194*523fa7a6SAndroid Build Coastguard Worker // Second-tier tail fixup: handle all workloads.
195*523fa7a6SAndroid Build Coastguard Worker for (int j = len_aligned_4; j < len; ++j) {
196*523fa7a6SAndroid Build Coastguard Worker reducedSum += vec1[j] * vec2[j];
197*523fa7a6SAndroid Build Coastguard Worker }
198*523fa7a6SAndroid Build Coastguard Worker return reducedSum;
199*523fa7a6SAndroid Build Coastguard Worker }
200*523fa7a6SAndroid Build Coastguard Worker
bf16_dot_with_fp32_arith(const BFloat16 * vec1,const BFloat16 * vec2,int64_t len)201*523fa7a6SAndroid Build Coastguard Worker float bf16_dot_with_fp32_arith(
202*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec1,
203*523fa7a6SAndroid Build Coastguard Worker const BFloat16* vec2,
204*523fa7a6SAndroid Build Coastguard Worker int64_t len) {
205*523fa7a6SAndroid Build Coastguard Worker if (cpuinfo_has_arm_bf16()) {
206*523fa7a6SAndroid Build Coastguard Worker return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
207*523fa7a6SAndroid Build Coastguard Worker } else {
208*523fa7a6SAndroid Build Coastguard Worker return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
209*523fa7a6SAndroid Build Coastguard Worker }
210*523fa7a6SAndroid Build Coastguard Worker }
211*523fa7a6SAndroid Build Coastguard Worker #endif // __aarch64__
212*523fa7a6SAndroid Build Coastguard Worker } // namespace internal
213*523fa7a6SAndroid Build Coastguard Worker } // namespace cpublas
214*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
215