1 #pragma once
2
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
5 #include <ATen/cpu/vec/vec_base.h>
6 #include <c10/util/irange.h>
7
8 namespace at {
9 namespace vec {
10 // See Note [CPU_CAPABILITY namespace]
11 inline namespace CPU_CAPABILITY {
12
convert_bfloat16_float(const Vectorized<BFloat16> & a)13 inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(
14 const Vectorized<BFloat16>& a) {
15 constexpr int64_t K = Vectorized<BFloat16>::size();
16 __at_align__ float arr[K];
17 __at_align__ BFloat16 arr2[K];
18 a.store(arr2);
19 convert(arr2, arr, K);
20 return std::make_tuple(
21 Vectorized<float>::loadu(arr),
22 Vectorized<float>::loadu(arr + Vectorized<float>::size()));
23 }
24
// Packs two Vectorized<float> into a single Vectorized<BFloat16>.
//
// `a` is written to the start of an aligned float scratch buffer and `b`
// immediately after it (offset by Vectorized<float>::size()). The whole
// buffer is then converted element-wise with the scalar `convert` helper and
// reloaded as one BFloat16 vector, so lanes from `a` precede lanes from `b`.
inline Vectorized<BFloat16> convert_float_bfloat16(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  constexpr int64_t kBf16Lanes = Vectorized<BFloat16>::size();
  constexpr int64_t kFloatLanes = Vectorized<float>::size();

  __at_align__ float unpacked[kBf16Lanes];
  __at_align__ BFloat16 narrowed[kBf16Lanes];

  // Lay the two float vectors out back to back.
  a.store(unpacked);
  b.store(unpacked + kFloatLanes);

  convert(unpacked, narrowed, kBf16Lanes);
  return Vectorized<BFloat16>::loadu(narrowed);
}
36
load_fp32_from_bf16(const c10::BFloat16 * data,Vectorized<float> & out)37 inline void load_fp32_from_bf16(const c10::BFloat16* data, Vectorized<float>& out) {
38 __at_align__ float values[Vectorized<float>::size()];
39 for (const auto k : c10::irange(Vectorized<float>::size())) {
40 values[k] = data[k];
41 }
42 out = Vectorized<float>::loadu(values);
43 }
44
load_fp32_from_bf16(const c10::BFloat16 * data,Vectorized<float> & out1,Vectorized<float> & out2)45 inline void load_fp32_from_bf16(
46 const c10::BFloat16* data,
47 Vectorized<float>& out1,
48 Vectorized<float>& out2) {
49 load_fp32_from_bf16(data, out1);
50 data += Vectorized<float>::size();
51 load_fp32_from_bf16(data, out2);
52 }
53
load_fp32_from_fp16(const c10::Half * data,Vectorized<float> & out)54 inline void load_fp32_from_fp16(const c10::Half* data, Vectorized<float>& out) {
55 __at_align__ float values[Vectorized<float>::size()];
56 for (const auto k : c10::irange(Vectorized<float>::size())) {
57 values[k] = data[k];
58 }
59 out = Vectorized<float>::loadu(values);
60 }
61
load_fp32_from_fp16(const c10::Half * data,Vectorized<float> & out1,Vectorized<float> & out2)62 inline void load_fp32_from_fp16(
63 const c10::Half* data,
64 Vectorized<float>& out1,
65 Vectorized<float>& out2) {
66 load_fp32_from_fp16(data, out1);
67 data += Vectorized<float>::size();
68 load_fp32_from_fp16(data, out2);
69 }
70
71 } // namespace
72 } // namespace vec
73 } // namespace at
74