/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <executorch/kernels/optimized/vec/intrinsics.h>

#include <executorch/kernels/optimized/vec/vec_base.h>
#if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
#include <executorch/kernels/optimized/vec/vec256/vec256_float.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_float_neon.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_double.h>
#include <executorch/kernels/optimized/vec/vec256/vec256_int.h>
#endif

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <ostream>

namespace executorch {
namespace vec {

// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `executorch::vec`.
inline namespace CPU_CAPABILITY {

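// Debugging helper: stores the vector into a stack buffer and streams it as
// "vec[e0, e1, ..., eN-1]".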
template <typename T>
std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
  T buf[Vectorized<T>::size()];
  vec.store(buf);
  stream << "vec[";
  for (size_t i = 0; i != Vectorized<T>::size(); i++) {
    if (i != 0) {
      stream << ", ";
    }
    stream << buf[i];
  }
  stream << "]";
  return stream;
}


#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

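// These casts reinterpret the underlying bits; they do not convert values.
// _mm256_castpd_ps / _mm256_castps_pd generate no instructions.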
template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  return _mm256_castpd_ps(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  return _mm256_castps_pd(src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

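// Element i of the result is loaded from base_addr offset by vindex[i] * scale
// bytes; scale must be a compile-time constant of 1, 2, 4, or 8.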
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
  return _mm256_i64gather_pd(base_addr, vindex, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
  return _mm256_i32gather_ps(base_addr, vindex, scale);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

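// Like gather, but element i is loaded from memory only when the sign bit of
// mask[i] is set; otherwise it is copied from src.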
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
  return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
  return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
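// Adding 2^52 + 2^51 (0x0018000000000000) shifts the value into a range where
// the double's mantissa holds the integer directly; reinterpreting the bits as
// int64 and subtracting the constant's bit pattern recovers the integer.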
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
  return _mm256_sub_epi64(
      _mm256_castpd_si256(x),
      _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
  );
}

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return _mm256_cvttps_epi32(src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}

  // swap lanes:
  //   a_swapped = {a0, a1, b0, b1}
  //   b_swapped = {a2, a3, b2, b3}
  auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000);  // 0, 2.   4 bits apart
  auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001);  // 1, 3.   4 bits apart

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000),  // 0, 2, 1, 3
                        _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}

  // swap lanes:
  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000);  // 0, 2.   4 bits apart
  auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001);  // 1, 3.   4 bits apart

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, b0, b1}
  //   b_grouped = {a2, a3, b2, b3}
  auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000);  // 0, 2, 1, 3
  auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000);  // 0, 2, 1, 3

  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
                        _mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
  auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);

  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
                        _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

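// flip reverses the order of the elements in the vector:
//   {v0, v1, ..., vN-1} -> {vN-1, ..., v1, v0}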
template<>
inline Vectorized<float> flip(const Vectorized<float> & v) {
  const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  return _mm256_permutevar8x32_ps(v, mask_float);
}

template<>
inline Vectorized<double> flip(const Vectorized<double> & v) {
  return _mm256_permute4x64_pd(v, 27);  // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}

template<>
inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
  return _mm256_permute4x64_epi64(v, 27);  // 27 == _MM_SHUFFLE(0, 1, 2, 3)
}

template<>
inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
  const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  return _mm256_permutevar8x32_epi32(v, mask_int32);
}

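// The 16-bit and 8-bit flips work in two steps: _mm256_shuffle_epi8 reverses
// the elements within each 128-bit lane, then _mm256_permute2x128_si256 swaps
// the two lanes.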
template<>
inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
  const __m256i mask = _mm256_set_epi8(
    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
    1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
  );
  auto reversed = _mm256_shuffle_epi8(v, mask);
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}

inline __m256i flip8(const __m256i & v) {
  const __m256i mask_int8 = _mm256_set_epi8(
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  );
  auto reversed = _mm256_shuffle_epi8(v, mask_int8);
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}

template<>
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
  return flip8(v);
}

template<>
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  return flip8(v);
}

#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace executorch