/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/blas/BlasKernel.h>

#ifdef __aarch64__
#include <arm_neon.h>
#include <cpuinfo.h>
#endif

using torch::executor::BFloat16;

namespace executorch {
namespace cpublas {
namespace internal {
#ifdef __aarch64__
static inline float32x4_t
f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
  return vfmaq_f32(a, b, c);
#else
  return vaddq_f32(a, vmulq_f32(b, c));
#endif // __ARM_FEATURE_FMA
}

// The below reduce overload and bf16_dot_with_fp32_arith are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1
    << kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1
    << kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
    kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(
    kF32RegistersPerIteration ==
    kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(
    kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

// Pairwise tree reduction of the partial-sum registers, followed by a
// horizontal add of the surviving register.
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
  int offset = kF32RegistersPerIteration;
  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
      [&offset, &x](auto idx) ET_INLINE_ATTRIBUTE {
        offset /= 2;
        for (int i = 0; i < offset; ++i) {
          x[i] = vaddq_f32(x[i], x[offset + i]);
        }
      });
  return vaddvq_f32(x[0]);
}

// Widen 4 bfloat16 bit patterns to float32: bf16 is the upper 16 bits of
// fp32, so zero-extending and shifting left by 16 reconstructs the exact
// float32 value.
static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
  int32x4_t shift = vdupq_n_s32(16);
  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

#define ET_TARGET_ARM_BF16_ATTRIBUTE \
  __attribute__((target("arch=armv8.2-a+bf16")))
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE float32x4_t
f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) {
  return vbfdotq_f32(a, b, c);
}

// Fast path: uses the ARMv8.2 BF16 dot-product instruction (BFDOT).
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const bfloat16x8_t temp_vec1 = vld1q_bf16(
      reinterpret_cast<const bfloat16_t*>(
          &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const bfloat16x8_t temp_vec2 = vld1q_bf16(
      reinterpret_cast<const bfloat16_t*>(
          &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
  sum[registerPairIndex] =
      f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2);
}

// Fallback path for CPUs without BF16 dot-product support: widen each bf16
// half-register to fp32 and accumulate with FMA.
static ET_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

  sum[2 * registerPairIndex] = f32_fma_bf16(
      sum[2 * registerPairIndex],
      vget_low_u16(temp_vec1),
      vget_low_u16(temp_vec2));
  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
      sum[2 * registerPairIndex + 1],
      vget_high_u16(temp_vec1),
      vget_high_u16(temp_vec2));
}
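
// Reference sketch (illustrative only; not called by anything in this file,
// and the function name is made up for the illustration). With
// kF32ElementsPerRegister == 4 and p = registerPairIndex, one call of the
// no-bfdot inner loop above reads bf16 elements [8p, 8p + 8) of each input,
// accumulating elements [8p, 8p + 4) into sum[2p] and [8p + 4, 8p + 8) into
// sum[2p + 1]. The scalar code below computes the same total contribution in
// fp32 arithmetic.
#if 0
static float scalar_inner_loop_reference(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int registerPairIndex) {
  float acc = 0.0f;
  for (int i = 0; i < 2 * kF32ElementsPerRegister; ++i) {
    const int idx = registerPairIndex * 2 * kF32ElementsPerRegister + i;
    // BFloat16 -> float places the 16 stored bits in the upper half of an
    // fp32 word, which is exactly what to_bfloat16() does with NEON.
    acc += static_cast<float>(vec1[idx]) * static_cast<float>(vec2[idx]);
  }
  return acc;
}
#endif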
template <bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE static ET_INLINE void
dot_with_fp32_arith_main_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t sum[kF32RegistersPerIteration],
    int registerPairIndex) {
  if constexpr (useBfdot) {
    dot_with_fp32_arith_main_inner_loop_bfdot(
        vec1, vec2, sum, registerPairIndex);
  } else {
    dot_with_fp32_arith_main_inner_loop_no_bfdot(
        vec1, vec2, sum, registerPairIndex);
  }
}

static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
    const BFloat16* vec1,
    const BFloat16* vec2,
    float32x4_t* tailSum,
    int idx) {
  const auto temp_vec1 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
  const auto temp_vec2 =
      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

namespace {
// Compile-time unrolling helper whose operator() carries
// ET_TARGET_ARM_BF16_ATTRIBUTE so it can invoke the bf16-targeted lambda in
// dot_with_fp32_arith below.
template <int n>
struct ForcedUnrollTargetBFloat16 {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    ForcedUnrollTargetBFloat16<n - 1>{}(f);
    f(n - 1);
  }
};

template <>
struct ForcedUnrollTargetBFloat16<1> {
  template <typename Func>
  ET_TARGET_ARM_BF16_ATTRIBUTE ET_INLINE void operator()(const Func& f) const {
    f(0);
  }
};
} // namespace

template <typename T, bool useBfdot>
ET_TARGET_ARM_BF16_ATTRIBUTE float
dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
    const auto* vec1_ = vec1 + j;
    const auto* vec2_ = vec2 + j;
    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
        [vec1_, vec2_, &sum](auto k)
            ET_INLINE_ATTRIBUTE ET_TARGET_ARM_BF16_ATTRIBUTE {
              dot_with_fp32_arith_main_inner_loop<useBfdot>(
                  vec1_, vec2_, sum, k);
            });
  }
  auto reducedSum = reduce(sum);

  // First-tier tail fixup: make sure we handle workloads that can
  // benefit from vectorization, but don't fit into our fully unrolled
  // loop above.
  float32x4_t tailSum = vdupq_n_f32(0);
  const auto len_aligned_4 = len & ~3;
  for (int j = len_aligned; j < len_aligned_4; j += 4) {
    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
  }
  auto reducedTail = vpaddq_f32(tailSum, tailSum);
  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

  // Second-tier tail fixup: handle all workloads.
  for (int j = len_aligned_4; j < len; ++j) {
    reducedSum += vec1[j] * vec2[j];
  }
  return reducedSum;
}

float bf16_dot_with_fp32_arith(
    const BFloat16* vec1,
    const BFloat16* vec2,
    int64_t len) {
  if (cpuinfo_has_arm_bf16()) {
    return dot_with_fp32_arith<BFloat16, true>(vec1, vec2, len);
  } else {
    return dot_with_fp32_arith<BFloat16, false>(vec1, vec2, len);
  }
}
#endif // __aarch64__
} // namespace internal
} // namespace cpublas
} // namespace executorch
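
// Usage sketch (hypothetical caller; assumes a declaration of
// bf16_dot_with_fp32_arith is visible, e.g. via the BlasKernel header
// included above, and that cpuinfo has been initialized). At runtime the
// function dispatches to the BFDOT kernel when the CPU reports BF16 support
// and to the shift-and-FMA fallback otherwise:
//
//   std::vector<torch::executor::BFloat16> a(len), b(len);
//   // ... fill a and b ...
//   float dot = executorch::cpublas::internal::bf16_dot_with_fp32_arith(
//       a.data(), b.data(), static_cast<int64_t>(len));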