/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>

#include <executorch/kernels/optimized/vec/vec_base.h>
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
#include <executorch/kernels/optimized/vec/vec256/vec256_float.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_float_neon.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_double.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_int.h>
#endif

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ostream>

namespace executorch {
namespace vec {

// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `executorch::vec`.
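//
// For instance, if one object file is built with CPU_CAPABILITY expanding
// to AVX2 and another with it expanding to DEFAULT, the same functions get
// distinct mangled names, so both files can be linked into a single binary
// without ODR violations.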
inline namespace CPU_CAPABILITY {

template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
  T buf[Vectorized<T>::size()];
  vec.store(buf);
  stream << "vec[";
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    if (i != 0) {
      stream << ", ";
    }
    stream << buf[i];
  }
  stream << "]";
  return stream;
}
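
// Example: a broadcast vector prints one value per lane; with 8-lane AVX2
// floats,
//   std::cout << Vectorized<float>(1.f);
// writes "vec[1, 1, 1, 1, 1, 1, 1, 1]" to the stream.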

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  return _mm256_castpd_ps(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  return _mm256_castps_pd(src);
}
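
// Note: these casts reinterpret the register bits in place (zero cost);
// they do not convert values between float and double.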

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
  return _mm256_i64gather_pd(base_addr, vindex, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
  return _mm256_i32gather_ps(base_addr, vindex, scale);
}
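
// In terms of the underlying intrinsics, lane i of the result is
//   result[i] = *(const T*)((const char*)base_addr + vindex[i] * scale);
// so with scale == sizeof(T) the indices are ordinary element indices,
// e.g. gather<8>(arr, vindex) with vindex = {3, 2, 1, 0} yields
// {arr[3], arr[2], arr[1], arr[0]} for a double array arr.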

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
  return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
  return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}
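
// Lanes whose mask element has its sign bit set are gathered from memory
// exactly as above; all other lanes are passed through from src. This is
// useful for tails, where reading past the end of a buffer must be avoided.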

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
  return _mm256_sub_epi64(
      _mm256_castpd_si256(x),
      _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
  );
}
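
// Why the trick above works: the magic constant 0x0018000000000000, read as
// a double, is 1.5 * 2^52 (= 6755399441055744.0). Adding it to any x in
// [-2^51, 2^51] yields a double whose low 52 mantissa bits hold x in two's
// complement, so reinterpreting the sum's bits as int64 and subtracting the
// magic's bit pattern recovers x exactly in each lane.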

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return _mm256_cvttps_epi32(src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}

  // swap lanes:
  //   a_swapped = {a0, a1, b0, b1}
  //   b_swapped = {a2, a3, b2, b3}
  auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000); // 0, 2. 4 bits apart
  auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001); // 1, 3. 4 bits apart

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000),  // 0, 2, 1, 3
                        _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}

  // swap lanes:
  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000); // 0, 2. 4 bits apart
  auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001); // 1, 3. 4 bits apart

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}
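
// A typical use is zipping two separate channels into interleaved layout,
// e.g. (sketch; re, im, and out are hypothetical float pointers, loadu and
// store come from vec_base.h):
//   auto ab = interleave2(Vectorized<float>::loadu(re),
//                         Vectorized<float>::loadu(im));
//   ab.first.store(out);
//   ab.second.store(out + Vectorized<float>::size());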

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, b0, b1}
  //   b_grouped = {a2, a3, b2, b3}
  auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000); // 0, 2, 1, 3
  auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000); // 0, 2, 1, 3

  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000),  // 0, 2. 4 bits apart
                        _mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
  auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);

  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000),  // 0, 2. 4 bits apart
                        _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart
}
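
// deinterleave2 inverts interleave2 above: feeding the pair produced by
// interleave2(a, b) back through deinterleave2 returns the original a and b.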

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> flip(const Vectorized<float> & v) {
  const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  return _mm256_permutevar8x32_ps(v, mask_float);
}

template<>
inline Vectorized<double> flip(const Vectorized<double> & v) {
  return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}

template<>
inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
  return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}

template<>
inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
  const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  return _mm256_permutevar8x32_epi32(v, mask_int32);
}

template<>
inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
  const __m256i mask = _mm256_set_epi8(
    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
  );
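  // _mm256_shuffle_epi8 shuffles bytes only within each 128-bit lane, so
  // reverse the 16-bit elements inside each lane first, then swap the two
  // lanes to finish the full reversal.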
  auto reversed = _mm256_shuffle_epi8(v, mask);
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}

inline __m256i flip8(const __m256i & v) {
  const __m256i mask_int8 = _mm256_set_epi8(
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  );
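  // Same two-step approach as flip<int16_t>, at byte granularity: reverse
  // within each 128-bit lane, then swap the lanes.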
  auto reversed = _mm256_shuffle_epi8(v, mask_int8);
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}

template<>
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
  return flip8(v);
}

template<>
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  return flip8(v);
}

#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace executorch