xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cpu/vec/intrinsics.h>
4 #include <ATen/cpu/vec/vec_base.h>
5 #include <ATen/cpu/vec/vec256/vsx/vsx_helpers.h>
6 
7 #include <c10/util/irange.h>
8 #include <c10/util/quint8.h>
9 #include <array>
10 
11 // This file defines Vectorized<> for the quantized types.
12 //
13 //
14 // Currently, we simply use these classes as efficient converters between
15 // the quantized types and Vectorized<float>, usually in bandwidth-bound cases
16 // where doing the arithmetic in full-precision is acceptable (e.g.
17 // elementwise operators).
18 //
19 //
20 // Conversions are as follows:
21 //  Vectorized<quint8> -> 4x Vectorized<float>
22 //
23 // The size of the returned float vector is specified by the special
24 // constexpr function float_num_vecs. The type of the value returned
25 // from dequantize (and expected as an argument to quantize) is
26 // specified by float_vec_return_type.
27 //
28 // When writing kernels with these vectors, it is expected that floating-
29 // point operations will be carried out in a loop over Vectorized<T>::float_num_vecs
30 // iterations.
31 
32 namespace at {
33 namespace vec {
34 inline namespace CPU_CAPABILITY {
35 
36 const vint16 mask_unsigned = vec_splats((short int)0xFF);
37 template <>
38 struct Vectorized<c10::quint8> {
39  private:
40   union {
41     struct {
42       vuint8 _vec0;
43       vuint8 _vec1;
44     };
45     struct {
46       vbool8 _vecb0;
47       vbool8 _vecb1;
48     };
49 
50   } __attribute__((__may_alias__));
51 
52  public:
53   Vectorized() {}
54   using size_type = int;
55   static constexpr size_type size() {
56     return 32;
57   }
58 
59   static constexpr size_t float_num_vecs() {
60     return 4;
61   }
62   static constexpr int int_num_vecs() {
63     return 4;
64   }
65   using float_vec_return_type = std::array<Vectorized<float>, 4>;
66   using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
67   using value_type = typename c10::quint8::underlying;
68   using vec_internal_type = vuint8;
69   using vec_internal_mask_type = vbool8;
70   // Broadcast constructor
71   C10_ALWAYS_INLINE Vectorized(const c10::quint8& val)
72       : _vec0(vec_splats(val.val_)), _vec1(vec_splats(val.val_)) {}
73 
74   C10_ALWAYS_INLINE Vectorized(const Vectorized<c10::quint8>& other)
75       : _vec0{other._vec0}, _vec1(other._vec1) {}
76 
77   C10_ALWAYS_INLINE Vectorized(vuint8 v) : _vec0{v}, _vec1{v} {}
78   C10_ALWAYS_INLINE Vectorized(vbool8 vmask) : _vecb0{vmask}, _vecb1{vmask} {}
79   C10_ALWAYS_INLINE Vectorized(vuint8 v1, vuint8 v2) : _vec0{v1}, _vec1{v2} {}
80   C10_ALWAYS_INLINE Vectorized(vbool8 v1, vbool8 v2) : _vecb0{v1}, _vecb1{v2} {}
81 
82   C10_ALWAYS_INLINE const vec_internal_type& vec0() const {
83     return _vec0;
84   }
85   C10_ALWAYS_INLINE const vec_internal_type& vec1() const {
86     return _vec1;
87   }
88 
89   static C10_ALWAYS_INLINE Vectorized<c10::quint8> loadu(
90       const void* ptr,
91       int count = size()) {
92     if (count == size()) {
93       return {
94           vec_vsx_ld(offset0, reinterpret_cast<const value_type*>(ptr)),
95           vec_vsx_ld(offset16, reinterpret_cast<const value_type*>(ptr))};
96     }
97     __at_align__ value_type tmp_values[size()] = {};
98     std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type));
99     return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)};
100   }
101   void C10_ALWAYS_INLINE store(void* ptr, int count = size()) const {
102     if (count == size()) {
103       vec_vsx_st(_vec0, offset0, reinterpret_cast<value_type*>(ptr));
104       vec_vsx_st(_vec1, offset16, reinterpret_cast<value_type*>(ptr));
105     } else if (count > 0) {
106       __at_align__ value_type tmp_values[size()];
107       vec_vsx_st(_vec0, offset0, tmp_values);
108       vec_vsx_st(_vec1, offset16, tmp_values);
109       std::memcpy(
110           ptr, tmp_values, std::min(count, size()) * sizeof(value_type));
111     }
112   }
113 
114  public:
115   float_vec_return_type C10_ALWAYS_INLINE dequantize(
116       Vectorized<float> scale,
117       Vectorized<float> zero_point,
118       Vectorized<float> scale_zp_premul) const {
119     // unpacking unsigned as signed
120     vint16 vecshi0 = vec_unpackh((vint8)_vec0);
121     vint16 vecshi1 = vec_unpackl((vint8)_vec0);
122 
123     vint16 vecshi2 = vec_unpackh((vint8)_vec1);
124     vint16 vecshi3 = vec_unpackl((vint8)_vec1);
125 
126     // signed ->  unsigned
127     vecshi0 = vec_and(vecshi0, mask_unsigned);
128     vecshi1 = vec_and(vecshi1, mask_unsigned);
129 
130     vecshi2 = vec_and(vecshi2, mask_unsigned);
131     vecshi3 = vec_and(vecshi3, mask_unsigned);
132 
133     vint32 veci0 = vec_unpackh(vecshi0);
134     vint32 veci1 = vec_unpackl(vecshi0);
135 
136     vint32 veci2 = vec_unpackh(vecshi1);
137     vint32 veci3 = vec_unpackl(vecshi1);
138 
139     vint32 veci4 = vec_unpackh(vecshi2);
140     vint32 veci5 = vec_unpackl(vecshi2);
141 
142     vint32 veci6 = vec_unpackh(vecshi3);
143     vint32 veci7 = vec_unpackl(vecshi3);
144 
145     vfloat32 vecf0_0 = vec_float(veci0);
146     vfloat32 vecf1_0 = vec_float(veci1);
147 
148     vfloat32 vecf0_1 = vec_float(veci2);
149     vfloat32 vecf1_1 = vec_float(veci3);
150 
151     vfloat32 vecf0_2 = vec_float(veci4);
152     vfloat32 vecf1_2 = vec_float(veci5);
153 
154     vfloat32 vecf0_3 = vec_float(veci6);
155     vfloat32 vecf1_3 = vec_float(veci7);
156     vfloat32 scale_vec0 = scale.vec0();
157     vfloat32 scale_vec1 = scale.vec1();
158     vfloat32 scale_zp_premul0 = scale_zp_premul.vec0();
159     vfloat32 scale_zp_premul1 = scale_zp_premul.vec1();
160     return {
161         Vectorized<float>{
162             vec_madd(scale_vec0, vecf0_0, scale_zp_premul0),
163             vec_madd(scale_vec1, vecf1_0, scale_zp_premul1)},
164         Vectorized<float>{
165             vec_madd(scale_vec0, vecf0_1, scale_zp_premul0),
166             vec_madd(scale_vec1, vecf1_1, scale_zp_premul1)},
167         Vectorized<float>{
168             vec_madd(scale_vec0, vecf0_2, scale_zp_premul0),
169             vec_madd(scale_vec1, vecf1_2, scale_zp_premul1)},
170         Vectorized<float>{
171             vec_madd(scale_vec0, vecf0_3, scale_zp_premul0),
172             vec_madd(scale_vec1, vecf1_3, scale_zp_premul1)}};
173   }
174 
175   float_vec_return_type C10_ALWAYS_INLINE dequantize(
176       Vectorized<float> scale,
177       Vectorized<float> zero_point) const {
178     // unpacking unsigned as signed
179     vint16 vecshi0 = vec_unpackh((vint8)_vec0);
180     vint16 vecshi1 = vec_unpackl((vint8)_vec0);
181 
182     vint16 vecshi2 = vec_unpackh((vint8)_vec1);
183     vint16 vecshi3 = vec_unpackl((vint8)_vec1);
184 
185     // signed ->  unsigned
186     vecshi0 = vec_and(vecshi0, mask_unsigned);
187     vecshi1 = vec_and(vecshi1, mask_unsigned);
188 
189     vecshi2 = vec_and(vecshi2, mask_unsigned);
190     vecshi3 = vec_and(vecshi3, mask_unsigned);
191 
192     vint32 veci0 = vec_unpackh(vecshi0);
193     vint32 veci1 = vec_unpackl(vecshi0);
194 
195     vint32 veci2 = vec_unpackh(vecshi1);
196     vint32 veci3 = vec_unpackl(vecshi1);
197 
198     vint32 veci4 = vec_unpackh(vecshi2);
199     vint32 veci5 = vec_unpackl(vecshi2);
200 
201     vint32 veci6 = vec_unpackh(vecshi3);
202     vint32 veci7 = vec_unpackl(vecshi3);
203 
204     vfloat32 vecf0_0 = vec_float(veci0);
205     vfloat32 vecf1_0 = vec_float(veci1);
206 
207     vfloat32 vecf0_1 = vec_float(veci2);
208     vfloat32 vecf1_1 = vec_float(veci3);
209 
210     vfloat32 vecf0_2 = vec_float(veci4);
211     vfloat32 vecf1_2 = vec_float(veci5);
212 
213     vfloat32 vecf0_3 = vec_float(veci6);
214     vfloat32 vecf1_3 = vec_float(veci7);
215     vfloat32 scale_vec0 = scale.vec0();
216     vfloat32 scale_vec1 = scale.vec1();
217     vfloat32 zero_point0 = zero_point.vec0();
218     vfloat32 zero_point1 = zero_point.vec1();
219     return {
220         Vectorized<float>{
221             (vecf0_0 - zero_point0) * scale_vec0,
222             (vecf1_0 - zero_point1) * scale_vec1},
223         Vectorized<float>{
224             (vecf0_1 - zero_point0) * scale_vec0,
225             (vecf1_1 - zero_point1) * scale_vec1},
226         Vectorized<float>{
227             (vecf0_2 - zero_point0) * scale_vec0,
228             (vecf1_2 - zero_point1) * scale_vec1},
229         Vectorized<float>{
230             (vecf0_3 - zero_point0) * scale_vec0,
231             (vecf1_3 - zero_point1) * scale_vec1}};
232   }
233 
234   static Vectorized<c10::quint8> quantize(
235       const float_vec_return_type& rhs,
236       float scale,
237       int32_t zero_point,
238       float inverse_scale) {
239     // constexpr int32_t min_val = std::numeric_limits<value_type>::min();
240     // constexpr int32_t max_val = std::numeric_limits<value_type>::max();
241 
242     vfloat32 vec_inverse = vec_splats(inverse_scale);
243     vfloat32 vec_zero_point = vec_splats((float)zero_point);
244     // vuint32 vmin = vec_splats(min_val);
245     // vuint32 vmax = vec_splats(max_val);
246     Vectorized<float> vf0 = rhs[0];
247     Vectorized<float> vf1 = rhs[1];
248     Vectorized<float> vf2 = rhs[2];
249     Vectorized<float> vf3 = rhs[3];
250     vfloat32 vecf0 = vf0.vec0();
251     vfloat32 vecf1 = vf0.vec1();
252     vfloat32 vecf2 = vf1.vec0();
253     vfloat32 vecf3 = vf1.vec1();
254 
255     vfloat32 vecf4 = vf2.vec0();
256     vfloat32 vecf5 = vf2.vec1();
257     vfloat32 vecf6 = vf3.vec0();
258     vfloat32 vecf7 = vf3.vec1();
259 
260     vecf0 = vec_mul(vecf0, vec_inverse);
261     vecf1 = vec_mul(vecf1, vec_inverse);
262     vecf2 = vec_mul(vecf2, vec_inverse);
263     vecf3 = vec_mul(vecf3, vec_inverse);
264 
265     vecf4 = vec_mul(vecf4, vec_inverse);
266     vecf5 = vec_mul(vecf5, vec_inverse);
267     vecf6 = vec_mul(vecf6, vec_inverse);
268     vecf7 = vec_mul(vecf7, vec_inverse);
269 
270     vecf0 = vec_add(vec_rint(vecf0), vec_zero_point);
271     vecf1 = vec_add(vec_rint(vecf1), vec_zero_point);
272     vecf2 = vec_add(vec_rint(vecf2), vec_zero_point);
273     vecf3 = vec_add(vec_rint(vecf3), vec_zero_point);
274 
275     vecf4 = vec_add(vec_rint(vecf4), vec_zero_point);
276     vecf5 = vec_add(vec_rint(vecf5), vec_zero_point);
277     vecf6 = vec_add(vec_rint(vecf6), vec_zero_point);
278     vecf7 = vec_add(vec_rint(vecf7), vec_zero_point);
279 
280     vint32 veci0 = vec_signed(vecf0);
281     vint32 veci1 = vec_signed(vecf1);
282     vint32 veci2 = vec_signed(vecf2);
283     vint32 veci3 = vec_signed(vecf3);
284 
285     vint32 veci4 = vec_signed(vecf4);
286     vint32 veci5 = vec_signed(vecf5);
287     vint32 veci6 = vec_signed(vecf6);
288     vint32 veci7 = vec_signed(vecf7);
289 
290     vint16 vecshi0 = vec_packs(veci0, veci1);
291     vint16 vecshi1 = vec_packs(veci2, veci3);
292     vint16 vecshi2 = vec_packs(veci4, veci5);
293     vint16 vecshi3 = vec_packs(veci6, veci7);
294 
295     vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
296     vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
297 
298     return {vec0, vec1};
299   }
300 
301   Vectorized<c10::quint8> C10_ALWAYS_INLINE relu(Vectorized<c10::quint8> zero_point) const {
302     return {vec_max(_vec0, zero_point._vec0), vec_max(_vec1, zero_point._vec1)};
303   }
304 
305   Vectorized<c10::quint8> C10_ALWAYS_INLINE
306   relu6(Vectorized<c10::quint8> zero_point, Vectorized<c10::quint8> q_six) const {
307     vuint8 max0 = vec_max(_vec0, zero_point._vec0);
308     vuint8 max1 = vec_max(_vec1, zero_point._vec1);
309     return {vec_min(max0, q_six._vec0), vec_min(max1, q_six._vec1)};
310   }
311 
312   int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
313     vint16 vecshi0 = vec_unpackh((vint8)_vec0);
314     vint16 vecBshi0 = vec_unpackh((vint8)b._vec0);
315     vint16 vecshi1 = vec_unpackl((vint8)_vec0);
316     vint16 vecBshi1 = vec_unpackl((vint8)b._vec0);
317 
318     vint16 vecshi2 = vec_unpackh((vint8)_vec1);
319     vint16 vecBshi2 = vec_unpackh((vint8)b._vec1);
320     vint16 vecshi3 = vec_unpackl((vint8)_vec1);
321     vint16 vecBshi3 = vec_unpackl((vint8)b._vec1);
322 
323     vecshi0 = vec_and(vecshi0, mask_unsigned);
324     vecBshi0 = vec_and(vecBshi0, mask_unsigned);
325     vecshi1 = vec_and(vecshi1, mask_unsigned);
326     vecBshi1 = vec_and(vecBshi1, mask_unsigned);
327 
328     vecshi2 = vec_and(vecshi2, mask_unsigned);
329     vecBshi2 = vec_and(vecBshi2, mask_unsigned);
330     vecshi3 = vec_and(vecshi3, mask_unsigned);
331     vecBshi3 = vec_and(vecBshi3, mask_unsigned);
332 
333     vint32 veci0 = vec_unpackh(vecshi0);
334     vint32 vecBi0 = vec_unpackh(vecBshi0);
335     vint32 veci1 = vec_unpackl(vecshi0);
336     vint32 vecBi1 = vec_unpackl(vecBshi0);
337 
338     vint32 veci2 = vec_unpackh(vecshi1);
339     vint32 vecBi2 = vec_unpackh(vecBshi1);
340     vint32 veci3 = vec_unpackl(vecshi1);
341     vint32 vecBi3 = vec_unpackl(vecBshi1);
342 
343     vint32 veci4 = vec_unpackh(vecshi2);
344     vint32 vecBi4 = vec_unpackh(vecBshi2);
345     vint32 veci5 = vec_unpackl(vecshi2);
346     vint32 vecBi5 = vec_unpackl(vecBshi2);
347 
348     vint32 veci6 = vec_unpackh(vecshi3);
349     vint32 vecBi6 = vec_unpackh(vecBshi3);
350     vint32 veci7 = vec_unpackl(vecshi3);
351     vint32 vecBi7 = vec_unpackl(vecBshi3);
352 
353     return {
354         Vectorized<c10::qint32>(veci0 - vecBi0, veci1 - vecBi1),
355         Vectorized<c10::qint32>(veci2 - vecBi2, veci3 - vecBi3),
356         Vectorized<c10::qint32>(veci4 - vecBi4, veci5 - vecBi5),
357         Vectorized<c10::qint32>(veci6 - vecBi6, veci7 - vecBi7)};
358   }
359 
360   static Vectorized<c10::quint8> requantize_from_int(
361       const int_vec_return_type& inp,
362       float multiplier,
363       int32_t zero_point) {
364     vfloat32 vec_multiplier = vec_splats(multiplier);
365     vint32 vec_zero_point = vec_splats(zero_point);
366 
367     Vectorized<c10::qint32> vi0 = inp[0];
368     Vectorized<c10::qint32> vi1 = inp[1];
369     Vectorized<c10::qint32> vi2 = inp[2];
370     Vectorized<c10::qint32> vi3 = inp[3];
371 
372     vfloat32 vecf0 = vec_float(vi0.vec0());
373     vfloat32 vecf1 = vec_float(vi0.vec1());
374     vfloat32 vecf2 = vec_float(vi1.vec0());
375     vfloat32 vecf3 = vec_float(vi1.vec1());
376 
377     vfloat32 vecf4 = vec_float(vi2.vec0());
378     vfloat32 vecf5 = vec_float(vi2.vec1());
379     vfloat32 vecf6 = vec_float(vi3.vec0());
380     vfloat32 vecf7 = vec_float(vi3.vec1());
381 
382     vecf0 = vec_mul(vecf0, vec_multiplier);
383     vecf1 = vec_mul(vecf1, vec_multiplier);
384     vecf2 = vec_mul(vecf2, vec_multiplier);
385     vecf3 = vec_mul(vecf3, vec_multiplier);
386 
387     vecf4 = vec_mul(vecf4, vec_multiplier);
388     vecf5 = vec_mul(vecf5, vec_multiplier);
389     vecf6 = vec_mul(vecf6, vec_multiplier);
390     vecf7 = vec_mul(vecf7, vec_multiplier);
391 
392     vecf0 = vec_rint(vecf0);
393     vecf1 = vec_rint(vecf1);
394     vecf2 = vec_rint(vecf2);
395     vecf3 = vec_rint(vecf3);
396 
397     vecf4 = vec_rint(vecf4);
398     vecf5 = vec_rint(vecf5);
399     vecf6 = vec_rint(vecf6);
400     vecf7 = vec_rint(vecf7);
401 
402     vint32 veci0 = vec_signed(vecf0);
403     vint32 veci1 = vec_signed(vecf1);
404     vint32 veci2 = vec_signed(vecf2);
405     vint32 veci3 = vec_signed(vecf3);
406 
407     vint32 veci4 = vec_signed(vecf4);
408     vint32 veci5 = vec_signed(vecf5);
409     vint32 veci6 = vec_signed(vecf6);
410     vint32 veci7 = vec_signed(vecf7);
411 
412     veci0 = vec_add(veci0, vec_zero_point);
413     veci1 = vec_add(veci1, vec_zero_point);
414     veci2 = vec_add(veci2, vec_zero_point);
415     veci3 = vec_add(veci3, vec_zero_point);
416 
417     veci4 = vec_add(veci4, vec_zero_point);
418     veci5 = vec_add(veci5, vec_zero_point);
419     veci6 = vec_add(veci6, vec_zero_point);
420     veci7 = vec_add(veci7, vec_zero_point);
421 
422     vint16 vecshi0 = vec_packs(veci0, veci1);
423     vint16 vecshi1 = vec_packs(veci2, veci3);
424     vint16 vecshi2 = vec_packs(veci4, veci5);
425     vint16 vecshi3 = vec_packs(veci6, veci7);
426 
427     vuint8 vec0 = vec_packsu(vecshi0, vecshi1);
428     vuint8 vec1 = vec_packsu(vecshi2, vecshi3);
429 
430     return {vec0, vec1};
431   }
432 
433   DEFINE_MEMBER_OP(operator==, c10::quint8, vec_cmpeq)
434   DEFINE_MEMBER_OP(operator!=, c10::quint8, vec_cmpne)
435   DEFINE_MEMBER_OP(operator<, c10::quint8, vec_cmplt)
436   DEFINE_MEMBER_OP(operator<=, c10::quint8, vec_cmple)
437   DEFINE_MEMBER_OP(operator>, c10::quint8, vec_cmpgt)
438   DEFINE_MEMBER_OP(operator>=, c10::quint8, vec_cmpge)
439   DEFINE_MEMBER_OP(operator+, c10::quint8, vec_add)
440   DEFINE_MEMBER_OP(operator-, c10::quint8, vec_sub)
441   DEFINE_MEMBER_OP(operator*, c10::quint8, vec_mul)
442   DEFINE_MEMBER_EMULATE_BINARY_OP(operator/, c10::quint8, /)
443   DEFINE_MEMBER_OP(maximum, c10::quint8, vec_max)
444   DEFINE_MEMBER_OP(minimum, c10::quint8, vec_min)
445   DEFINE_MEMBER_OP(operator&, c10::quint8, vec_and)
446   DEFINE_MEMBER_OP(operator|, c10::quint8, vec_or)
447   DEFINE_MEMBER_OP(operator^, c10::quint8, vec_xor)
448 };
449 
450 template <>
451 Vectorized<c10::quint8> inline maximum(
452     const Vectorized<c10::quint8>& a,
453     const Vectorized<c10::quint8>& b) {
454   return a.maximum(b);
455 }
456 
457 template <>
458 Vectorized<c10::quint8> inline minimum(
459     const Vectorized<c10::quint8>& a,
460     const Vectorized<c10::quint8>& b) {
461   return a.minimum(b);
462 }
463 
464 template <>
465 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator+(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
466   return Vectorized<c10::quint8>{vec_add(a.vec0(), b.vec0()), vec_add(a.vec1(), b.vec1())};
467 }
468 
469 template <>
470 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator-(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
471   return Vectorized<c10::quint8>{vec_sub(a.vec0(), b.vec0()), vec_sub(a.vec1(), b.vec1())};
472 }
473 
474 template <>
475 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator*(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
476   return Vectorized<c10::quint8>{vec_mul(a.vec0(), b.vec0()), vec_mul(a.vec1(), b.vec1())};
477 }
478 
479 template <>
480 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator/(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
481   return Vectorized<c10::quint8>{a.vec0()/b.vec0(), a.vec1()/b.vec1()};
482 }
483 
484 template <>
485 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator&(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
486   return Vectorized<c10::quint8>{vec_and(a.vec0(), b.vec0()), vec_and(a.vec1(), b.vec1())};
487 }
488 
489 template <>
490 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator|(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
491   return Vectorized<c10::quint8>{vec_or(a.vec0(), b.vec0()), vec_or(a.vec1(), b.vec1())};
492 }
493 
494 template <>
495 Vectorized<c10::quint8> C10_ALWAYS_INLINE operator^(const Vectorized<c10::quint8>& a, const Vectorized<c10::quint8>& b) {
496   return Vectorized<c10::quint8>{vec_xor(a.vec0(), b.vec0()), vec_xor(a.vec1(), b.vec1())};
497 }
498 
499 } // namespace
500 } // namespace vec
501 } // namespace at
502