xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
5 #include <ATen/cpu/vec/vec_base.h>
6 #include <c10/util/irange.h>
7 
8 namespace at {
9 namespace vec {
10 // See Note [CPU_CAPABILITY namespace]
11 inline namespace CPU_CAPABILITY {
12 
convert_bfloat16_float(const Vectorized<BFloat16> & a)13 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
14     const Vectorized<BFloat16>& a) {
15   constexpr int64_t K = Vectorized<BFloat16>::size();
16   __at_align__ float arr[K];
17   __at_align__ BFloat16 arr2[K];
18   a.store(arr2);
19   convert(arr2, arr, K);
20   return std::make_tuple(
21       Vectorized<float>::loadu(arr),
22       Vectorized<float>::loadu(arr + Vectorized<float>::size()));
23 }
24 
convert_float_bfloat16(const Vectorized<float> & a,const Vectorized<float> & b)25 inline Vectorized<BFloat16> convert_float_bfloat16(
26     const Vectorized<float>& a,
27     const Vectorized<float>& b) {
28   constexpr int64_t K = Vectorized<BFloat16>::size();
29   __at_align__ float arr[K];
30   __at_align__ BFloat16 arr2[K];
31   a.store(arr);
32   b.store(arr + Vectorized<float>::size());
33   convert(arr, arr2, K);
34   return Vectorized<BFloat16>::loadu(arr2);
35 }
36 
load_fp32_from_bf16(const c10::BFloat16 * data,Vectorized<float> & out)37 inline void load_fp32_from_bf16(const c10::BFloat16* data, Vectorized<float>& out) {
38   __at_align__ float values[Vectorized<float>::size()];
39   for (const auto k : c10::irange(Vectorized<float>::size())) {
40     values[k] = data[k];
41   }
42   out = Vectorized<float>::loadu(values);
43 }
44 
load_fp32_from_bf16(const c10::BFloat16 * data,Vectorized<float> & out1,Vectorized<float> & out2)45 inline void load_fp32_from_bf16(
46     const c10::BFloat16* data,
47     Vectorized<float>& out1,
48     Vectorized<float>& out2) {
49   load_fp32_from_bf16(data, out1);
50   data += Vectorized<float>::size();
51   load_fp32_from_bf16(data, out2);
52 }
53 
load_fp32_from_fp16(const c10::Half * data,Vectorized<float> & out)54 inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
55   __at_align__ float values[Vectorized<float>::size()];
56   for (const auto k : c10::irange(Vectorized<float>::size())) {
57     values[k] = data[k];
58   }
59   out = Vectorized<float>::loadu(values);
60 }
61 
load_fp32_from_fp16(const c10::Half * data,Vectorized<float> & out1,Vectorized<float> & out2)62 inline void load_fp32_from_fp16(
63     const c10::Half* data,
64     Vectorized<float>& out1,
65     Vectorized<float>& out2) {
66   load_fp32_from_fp16(data, out1);
67   data += Vectorized<float>::size();
68   load_fp32_from_fp16(data, out2);
69 }
70 
71 } // namespace
72 } // namespace vec
73 } // namespace at
74