#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>

// Note: header order is important here
#include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h>

namespace at {
namespace vec {

inline namespace CPU_CAPABILITY {

DEFINE_CLAMP_FUNCS(c10::quint8)
DEFINE_CLAMP_FUNCS(c10::qint8)
DEFINE_CLAMP_FUNCS(c10::qint32)
DEFINE_CLAMP_FUNCS(int16_t)
DEFINE_CLAMP_FUNCS(int32_t)
DEFINE_CLAMP_FUNCS(int64_t)
DEFINE_CLAMP_FUNCS(float)
DEFINE_CLAMP_FUNCS(double)
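// DEFINE_CLAMP_FUNCS (from vsx_helpers.h) is expected to emit the
// clamp/clamp_min/clamp_max specializations for each element type listed
// above. Illustrative usage sketch (not part of this header):
//   auto v   = Vectorized<float>(3.0f);
//   auto lo  = Vectorized<float>(0.0f);
//   auto hi  = Vectorized<float>(1.0f);
//   auto out = clamp(v, lo, hi);  // elementwise min(max(v, lo), hi)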

template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return Vectorized<double>{
      vec_madd(a.vec0(), b.vec0(), c.vec0()),
      vec_madd(a.vec1(), b.vec1(), c.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b,
    const Vectorized<int64_t>& c) {
  return Vectorized<int64_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b,
    const Vectorized<int32_t>& c) {
  return Vectorized<int32_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b,
    const Vectorized<int16_t>& c) {
  return Vectorized<int16_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
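// The fmadd specializations above compute a * b + c elementwise. The double
// version maps onto the VSX fused multiply-add intrinsic (vec_madd), while
// the integer versions fall back to a separate multiply and add on each
// 128-bit half. Usage sketch (illustrative only):
//   Vectorized<double> r = fmadd(x, y, z);  // r[i] = x[i] * y[i] + z[i]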

DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
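// DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS presumably instantiates the
// bit-preserving cast<Dst>() specializations from each listed element type
// to every other one, reusing the two underlying 128-bit VSX registers
// without changing any bits, e.g. (illustrative sketch only):
//   Vectorized<int32_t> bits = cast<int32_t>(Vectorized<float>(1.0f));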

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
  return Vectorized<int64_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}

template <>
Vectorized<int32_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<float>(const Vectorized<float>& src) {
  return Vectorized<int32_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}
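// vec_signed converts each lane to a signed integer of the same width
// (float -> int32_t, double -> int64_t), truncating toward zero, so both
// 128-bit halves of the source map one-to-one onto the result.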

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  // int32_t and float have the same size
  int64_t i;
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    const int32_t* src_a = src + i;
    float* dst_a = dst + i;
    vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast<const vint32*>(src_a));
    vint32 input_vec1 =
        vec_vsx_ld(offset16, reinterpret_cast<const vint32*>(src_a));
    vfloat32 c0 = vec_float(input_vec0);
    vfloat32 c1 = vec_float(input_vec1);
    vec_vsx_st(c0, offset0, dst_a);
    vec_vsx_st(c1, offset16, dst_a);
  }

  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
inline void convert(const int64_t* src, double* dst, int64_t n) {
  int64_t i;
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    const int64_t* src_a = src + i;
    double* dst_a = dst + i;
    vint64 input_vec0 =
        vec_vsx_ld(offset0, reinterpret_cast<const vint64*>(src_a));
    vint64 input_vec1 =
        vec_vsx_ld(offset16, reinterpret_cast<const vint64*>(src_a));
    vfloat64 c0 = vec_double(input_vec0);
    vfloat64 c1 = vec_double(input_vec1);
    vec_vsx_st(c0, offset0, reinterpret_cast<double*>(dst_a));
    vec_vsx_st(c1, offset16, reinterpret_cast<double*>(dst_a));
  }
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}
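// The two convert() specializations above share the same structure: a
// vectorized main loop that loads two 16-byte chunks per iteration (at byte
// offsets offset0/offset16), converts them with vec_float/vec_double, and
// stores the results, followed by a scalar tail loop for any leftover
// elements. Usage sketch (illustrative only):
//   std::vector<int32_t> src(n);
//   std::vector<float> dst(n);
//   at::vec::convert(src.data(), dst.data(), n);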

// Generic implementation to fix a compiler error.
// TODO: add an optimized version for ppc64.
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
    const Vectorized<Half>& a) {
  constexpr int64_t K = Vectorized<Half>::size();
  __at_align__ float arr[K];
  __at_align__ Half arr2[K];
  a.store(arr2);
  convert(arr2, arr, K);
  return std::make_tuple(
       Vectorized<float>::loadu(arr),
       Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}

inline Vectorized<Half> convert_float_half(
    const Vectorized<float>& a, const Vectorized<float>& b) {
  constexpr int64_t K = Vectorized<Half>::size();
  __at_align__ float arr[K];
  __at_align__ Half arr2[K];
  a.store(arr);
  b.store(arr + Vectorized<float>::size());
  convert(arr, arr2, K);
  return Vectorized<Half>::loadu(arr2);
}
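// As the note above says, convert_half_float/convert_float_half are generic
// fallbacks: they spill to aligned scratch buffers and convert element by
// element via convert(), rather than using a dedicated VSX sequence.
// Usage sketch (illustrative only):
//   auto [lo, hi] = convert_half_float(h);        // 16 Half -> 2 x 8 float
//   Vectorized<Half> back = convert_float_half(lo, hi);  // and back again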

template <>
std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // inputs:
  //   a      = {a0, a1, a2, a3}
  //   b      = {b0, b1, b2, b3}

  vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0);
  vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3);
  vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0);
  vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3);
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(
      Vectorized<double>{ab00, ab11}, Vectorized<double>{ab2_00, ab2_11});
}
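// vec_xxpermdi selects one doubleword from each operand: selector 0 yields
// {a[0], b[0]} and selector 3 yields {a[1], b[1]}, which is what builds the
// {a0, b0} / {a1, b1} pairs above (sketch of the intended semantics; see the
// xxpermdi documentation for the exact selector encoding).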

template <>
std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}
  vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0);
  vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0);

  vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3);
  vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3);

  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(
      Vectorized<double>{aa01, aa23}, Vectorized<double>{bb_01, bb_23});
}

template <>
std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3,, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3,, b4, b5, b6, b7}

  vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0());
  vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0());

  vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
  vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1());
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1,, a2, b2, a3, b3}
  //          {a4, b4, a5, b5,, a6, b6, a7, b7}

  return std::make_pair(
      Vectorized<float>{ab0011, ab2233}, Vectorized<float>{ab2_0011, ab2_2233});
}

template <>
std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1,, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5,, a6, b6, a7, b7}

  // {a0,a2,b0,b2} {a1,a3,b1,b3}
  vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
  vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());

  vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
  vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);

  vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
  vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());

  vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
  vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);

  // This could also be done with vec_perm.
  // swap lanes:
  //   return {a0, a1, a2, a3,, a4, a5, a6, a7}
  //          {b0, b1, b2, b3,, b4, b5, b6, b7}

  return std::make_pair(
      Vectorized<float>{aa0123, aa0123_2}, Vectorized<float>{bb0123, bb0123_2});
}
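// Usage sketch for the pairwise helpers above (illustrative only):
//   auto [ab_lo, ab_hi] = interleave2(a, b);        // {a0,b0,a1,b1,...}
//   auto [a2, b2] = deinterleave2(ab_lo, ab_hi);    // recovers a and b
// interleave2 followed by deinterleave2 round-trips the original vectors.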

} // namespace
} // namespace vec
} // namespace at