#pragma once

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>

// Note: header order is important here
#include <ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h>
#include <ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h>

#include <ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h>

namespace at {
namespace vec {

inline namespace CPU_CAPABILITY {

DEFINE_CLAMP_FUNCS(c10::quint8)
DEFINE_CLAMP_FUNCS(c10::qint8)
DEFINE_CLAMP_FUNCS(c10::qint32)
DEFINE_CLAMP_FUNCS(int16_t)
DEFINE_CLAMP_FUNCS(int32_t)
DEFINE_CLAMP_FUNCS(int64_t)
DEFINE_CLAMP_FUNCS(float)
DEFINE_CLAMP_FUNCS(double)
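
// Illustrative sketch (not part of this header): assuming DEFINE_CLAMP_FUNCS
// expands to clamp/clamp_min/clamp_max specializations (see vsx_helpers.h),
// the generated functions are used as:
//   Vectorized<float> v(1.5f), lo(0.0f), hi(1.0f);
//   Vectorized<float> r = clamp(v, lo, hi);   // every lane becomes 1.0f
//   Vectorized<float> r2 = clamp_min(v, hi);  // every lane stays 1.5f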

template <>
Vectorized<double> C10_ALWAYS_INLINE fmadd(
    const Vectorized<double>& a,
    const Vectorized<double>& b,
    const Vectorized<double>& c) {
  return Vectorized<double>{
      vec_madd(a.vec0(), b.vec0(), c.vec0()),
      vec_madd(a.vec1(), b.vec1(), c.vec1())};
}

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int64_t>& a,
    const Vectorized<int64_t>& b,
    const Vectorized<int64_t>& c) {
  return Vectorized<int64_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int32_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int32_t>& a,
    const Vectorized<int32_t>& b,
    const Vectorized<int32_t>& c) {
  return Vectorized<int32_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
template <>
Vectorized<int16_t> C10_ALWAYS_INLINE fmadd(
    const Vectorized<int16_t>& a,
    const Vectorized<int16_t>& b,
    const Vectorized<int16_t>& c) {
  return Vectorized<int16_t>{
      a.vec0() * b.vec0() + c.vec0(), a.vec1() * b.vec1() + c.vec1()};
}
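
// Illustrative sketch (not part of this header): each fmadd specialization
// above computes a * b + c lane-wise, e.g.
//   Vectorized<double> a(2.0), b(3.0), c(1.0);
//   Vectorized<double> r = fmadd(a, b, c);  // every lane holds 2.0 * 3.0 + 1.0 = 7.0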

DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(float)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(double)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int64_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int32_t)
DEFINE_REINTERPRET_CAST_TO_ALL_FUNCS(int16_t)
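// Illustrative sketch (not part of this header): assuming the macro expands to
// cast<dst_t, src_t> specializations (see vsx_helpers.h), these casts
// reinterpret bits rather than convert values, e.g.
//   Vectorized<float> f(1.0f);
//   Vectorized<int32_t> bits = cast<int32_t, float>(f);  // lanes hold 0x3f800000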

template <>
Vectorized<int64_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<double>(const Vectorized<double>& src) {
  return Vectorized<int64_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}

template <>
Vectorized<int32_t> C10_ALWAYS_INLINE
convert_to_int_of_same_size<float>(
    const Vectorized<float>& src) {
  return Vectorized<int32_t>{vec_signed(src.vec0()), vec_signed(src.vec1())};
}
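
// Illustrative sketch (not part of this header): vec_signed truncates toward
// zero, matching scalar static_cast behavior, e.g.
//   Vectorized<float> f(-1.7f);
//   Vectorized<int32_t> i = convert_to_int_of_same_size(f);  // every lane holds -1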

template <>
inline void convert(const int32_t* src, float* dst, int64_t n) {
  // int32_t and float have the same size
  int64_t i;
  for (i = 0; i <= (n - Vectorized<float>::size()); i += Vectorized<float>::size()) {
    const int32_t* src_a = src + i;
    float* dst_a = dst + i;
    vint32 input_vec0 = vec_vsx_ld(offset0, reinterpret_cast<const vint32*>(src_a));
    vint32 input_vec1 =
        vec_vsx_ld(offset16, reinterpret_cast<const vint32*>(src_a));
    vfloat32 c0 = vec_float(input_vec0);
    vfloat32 c1 = vec_float(input_vec1);
    vec_vsx_st(c0, offset0, dst_a);
    vec_vsx_st(c1, offset16, dst_a);
  }

  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}
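
// Illustrative sketch (not part of this header): the specialization above handles
// the bulk of the buffer eight elements per iteration and finishes the remainder
// with a scalar tail loop, e.g.
//   int32_t src[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
//   float dst[10];
//   convert(src, dst, 10);  // 8 elements via VSX, last 2 via the tail loop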

template <>
inline void convert(const int64_t* src, double* dst, int64_t n) {
  int64_t i;
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    const int64_t* src_a = src + i;
    double* dst_a = dst + i;
    vint64 input_vec0 =
        vec_vsx_ld(offset0, reinterpret_cast<const vint64*>(src_a));
    vint64 input_vec1 =
        vec_vsx_ld(offset16, reinterpret_cast<const vint64*>(src_a));
    vfloat64 c0 = vec_double(input_vec0);
    vfloat64 c1 = vec_double(input_vec1);
    vec_vsx_st(c0, offset0, reinterpret_cast<double*>(dst_a));
    vec_vsx_st(c1, offset16, reinterpret_cast<double*>(dst_a));
  }
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}
// Generic implementation to fix compiler error
// TODO: Add an optimized version for ppc64
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_half_float(
    const Vectorized<Half>& a) {
  constexpr int64_t K = Vectorized<Half>::size();
  __at_align__ float arr[K];
  __at_align__ Half arr2[K];
  a.store(arr2);
  convert(arr2, arr, K);
  return std::make_tuple(
      Vectorized<float>::loadu(arr),
      Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}

inline Vectorized<Half> convert_float_half(
    const Vectorized<float>& a, const Vectorized<float>& b) {
  constexpr int64_t K = Vectorized<Half>::size();
  __at_align__ float arr[K];
  __at_align__ Half arr2[K];
  a.store(arr);
  b.store(arr + Vectorized<float>::size());
  convert(arr, arr2, K);
  return Vectorized<Half>::loadu(arr2);
}
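
// Illustrative sketch (not part of this header): the two helpers above round-trip
// through aligned scratch buffers, so
//   Vectorized<Half> h = Vectorized<Half>::loadu(ptr);  // ptr: some Half buffer
//   auto [lo, hi] = convert_half_float(h);
//   Vectorized<Half> h2 = convert_float_half(lo, hi);
// should recover the original lanes (half-to-float widening is exact).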

template <>
std::pair<Vectorized<double>, Vectorized<double>> inline interleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}

  vfloat64 ab00 = vec_xxpermdi(a.vec0(), b.vec0(), 0);
  vfloat64 ab11 = vec_xxpermdi(a.vec0(), b.vec0(), 3);
  vfloat64 ab2_00 = vec_xxpermdi(a.vec1(), b.vec1(), 0);
  vfloat64 ab2_11 = vec_xxpermdi(a.vec1(), b.vec1(), 3);
  // return {a0, b0, a1, b1}
  //        {a2, b2, a3, b3}
  return std::make_pair(
      Vectorized<double>{ab00, ab11}, Vectorized<double>{ab2_00, ab2_11});
}

template <>
std::pair<Vectorized<double>, Vectorized<double>> inline deinterleave2<double>(
    const Vectorized<double>& a,
    const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}
  vfloat64 aa01 = vec_xxpermdi(a.vec0(), a.vec1(), 0);
  vfloat64 aa23 = vec_xxpermdi(b.vec0(), b.vec1(), 0);

  vfloat64 bb_01 = vec_xxpermdi(a.vec0(), a.vec1(), 3);
  vfloat64 bb_23 = vec_xxpermdi(b.vec0(), b.vec1(), 3);

  // swap lanes:
  // return {a0, a1, a2, a3}
  //        {b0, b1, b2, b3}
  return std::make_pair(
      Vectorized<double>{aa01, aa23}, Vectorized<double>{bb_01, bb_23});
}

template <>
std::pair<Vectorized<float>, Vectorized<float>> inline interleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3,, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3,, b4, b5, b6, b7}

  vfloat32 ab0011 = vec_mergeh(a.vec0(), b.vec0());
  vfloat32 ab2233 = vec_mergel(a.vec0(), b.vec0());

  vfloat32 ab2_0011 = vec_mergeh(a.vec1(), b.vec1());
  vfloat32 ab2_2233 = vec_mergel(a.vec1(), b.vec1());
  // group cols crossing lanes:
  // return {a0, b0, a1, b1,, a2, b2, a3, b3}
  //        {a4, b4, a5, b5,, a6, b6, a7, b7}

  return std::make_pair(
      Vectorized<float>{ab0011, ab2233}, Vectorized<float>{ab2_0011, ab2_2233});
}

template <>
std::pair<Vectorized<float>, Vectorized<float>> inline deinterleave2<float>(
    const Vectorized<float>& a,
    const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1,, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5,, a6, b6, a7, b7}

  // {a0, a2, b0, b2} and {a1, a3, b1, b3}
  vfloat32 a0a2b0b2 = vec_mergeh(a.vec0(), a.vec1());
  vfloat32 a1a3b1b3 = vec_mergel(a.vec0(), a.vec1());

  vfloat32 aa0123 = vec_mergeh(a0a2b0b2, a1a3b1b3);
  vfloat32 bb0123 = vec_mergel(a0a2b0b2, a1a3b1b3);

  vfloat32 a0a2b0b2_2 = vec_mergeh(b.vec0(), b.vec1());
  vfloat32 a1a3b1b3_2 = vec_mergel(b.vec0(), b.vec1());

  vfloat32 aa0123_2 = vec_mergeh(a0a2b0b2_2, a1a3b1b3_2);
  vfloat32 bb0123_2 = vec_mergel(a0a2b0b2_2, a1a3b1b3_2);

  // it could be done with vec_perm, too
  // swap lanes:
  // return {a0, a1, a2, a3,, a4, a5, a6, a7}
  //        {b0, b1, b2, b3,, b4, b5, b6, b7}

  return std::make_pair(
      Vectorized<float>{aa0123, aa0123_2}, Vectorized<float>{bb0123, bb0123_2});
}
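
// Illustrative sketch (not part of this header): interleave2 and deinterleave2
// are inverses of each other, e.g. for float
//   Vectorized<float> a, b;                        // a = {a0..a7}, b = {b0..b7}
//   auto ab = interleave2(a, b);                   // {a0,b0,...,a3,b3}, {a4,b4,...,a7,b7}
//   auto ba = deinterleave2(ab.first, ab.second);  // back to a and b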

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at