xref: /aosp_15_r20/external/XNNPACK/src/x8-zip/xm-neon.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright (c) Facebook, Inc. and its affiliates.
2*4bdc9457SAndroid Build Coastguard Worker // All rights reserved.
3*4bdc9457SAndroid Build Coastguard Worker //
4*4bdc9457SAndroid Build Coastguard Worker // Copyright 2019 Google LLC
5*4bdc9457SAndroid Build Coastguard Worker //
6*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
7*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
8*4bdc9457SAndroid Build Coastguard Worker 
9*4bdc9457SAndroid Build Coastguard Worker #include <arm_neon.h>
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/zip.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker 
// Zips (byte-transposes) m input rows of n bytes into n output rows of m bytes.
//
// Layout, as addressed by both code paths below:
//   input:  m rows of n bytes, row j starting at input + j * n;
//   output: n rows of m bytes, with output[k * m + j] = input[j * n + k]
//           for 0 <= j < m and 0 <= k < n.
//
// Paths:
//   - n >= 8 and m >= 4: NEON path. Handles 4 input rows per outer iteration and
//     8 bytes per row per inner iteration, with an overlapping 8-byte load for
//     the last n % 8 bytes of each row. When m is not a multiple of 4, the final
//     group is clamped to rows m-4..m-1 and rewrites some already-written output
//     columns with identical values.
//   - otherwise (including m < 4, where the 4-row NEON path would address rows
//     before `input`): scalar path.
//   - n == 0 or m == 0: nothing is written.
//
// The NEON path writes 4 bytes at a time at offsets that advance by m, so the
// store addresses are in general not 4-byte aligned.
void xnn_x8_zip_xm_ukernel__neon(
    size_t n,
    size_t m,
    const uint8_t* input,
    uint8_t* output)
{
  // Degenerate shapes: nothing to write. Without this guard the scalar
  // do/while loops below would wrap their counters around to SIZE_MAX.
  if (n == 0 || m == 0) {
    return;
  }

  if (n >= 8 && m >= 4) {
    // w points at the row following the previous group's 4th row (initially
    // row 0); adding 3 * n makes it the 4th row of the current group.
    const uint8_t* w = input;
    const size_t input_increment = n * 3;
    // After a group, output has advanced by n * m bytes; this modular
    // increment rewinds it to output row 0 and moves 4 columns to the right.
    const size_t output_increment = 4 - m * n;
    // Start of the last input row, and the last column at which a 4-wide
    // output group may start. Used to clamp the final group when m % 4 != 0.
    const uint8_t* last_input = w + n * (m - 1);
    uint8_t* last_output = (uint8_t*) ((uintptr_t) output + (m - 4));

    for (size_t i = 0; i < m; i += 4) {
      size_t k = n;
      w = (const uint8_t*) ((uintptr_t) w + input_increment);
      if (w >= last_input) {
        w = last_input;
      }
      const uint8_t* z = (const uint8_t*) ((uintptr_t) w - n);
      const uint8_t* y = (const uint8_t*) ((uintptr_t) z - n);
      const uint8_t* x = (const uint8_t*) ((uintptr_t) y - n);

      // After vzip_u8 + vzip_u16, the 32-bit lanes hold {x[t], y[t], z[t], w[t]}
      // for t = 0..7, in order: vxyzw_lo.val[0] (t = 0, 1), vxyzw_lo.val[1]
      // (t = 2, 3), vxyzw_hi.val[0] (t = 4, 5), vxyzw_hi.val[1] (t = 6, 7).
      // Each lane is one 4-byte slice of an output row.
      while (k >= 8) {
        const uint8x8_t vx = vld1_u8(x); x += 8;
        const uint8x8_t vy = vld1_u8(y); y += 8;
        const uint8x8_t vz = vld1_u8(z); z += 8;
        const uint8x8_t vw = vld1_u8(w); w += 8;

        const uint8x8x2_t vxy = vzip_u8(vx, vy);
        const uint8x8x2_t vzw = vzip_u8(vz, vw);
        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[0]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_lo.val[1]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[0]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 0);
        output = (uint8_t*) ((uintptr_t) output + m);

        vst1_lane_u32((void*) output, vreinterpret_u32_u16(vxyzw_hi.val[1]), 1);
        output = (uint8_t*) ((uintptr_t) output + m);

        k -= 8;
      }
      if (k != 0) {
        // 1..7 bytes remain per row. Move each pointer back by 8 - k so the
        // 8-byte load ends exactly at the end of its row (n >= 8 keeps the load
        // inside the row), then shift right by 8 * (8 - k) bits so the k
        // unprocessed bytes land in lanes 0..k-1. vshl_u64 shifts right when
        // the count is negative.
        const size_t address_increment = k - 8;  // modular: moves back by 8 - k
        x = (const uint8_t*) ((uintptr_t) x + address_increment);
        y = (const uint8_t*) ((uintptr_t) y + address_increment);
        z = (const uint8_t*) ((uintptr_t) z + address_increment);
        w = (const uint8_t*) ((uintptr_t) w + address_increment);
        // Build the negative count in range. Converting the wrapped size_t
        // 8 * address_increment to int64_t is implementation-defined, and with
        // a 32-bit size_t it yields a large positive value.
        const int64x1_t vshift = vmov_n_s64(-(int64_t) (8 * (8 - k)));

        const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift);
        const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift);
        const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift);
        // Advancing w leaves it at the start of the next row, as the main loop
        // does, which the next outer iteration's input_increment relies on.
        const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); w += 8;
        const uint8x8x2_t vxy = vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy));
        const uint8x8x2_t vzw = vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw));
        const uint16x4x2_t vxyzw_lo = vzip_u16(vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0]));
        const uint16x4x2_t vxyzw_hi = vzip_u16(vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1]));

        uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]);
        uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]);
        uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]);
        uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]);

        // Emit the k output rows in order, consuming lanes 4, 2, then 1 at a
        // time and shifting the not-yet-stored vectors down after each step.
        if (k & 4) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw0, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw1, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw1, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vxyzw0 = vxyzw2;
          vxyzw1 = vxyzw3;
        }

        if (k & 2) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);

          vst1_lane_u32((void*) output, vxyzw0, 1);
          output = (uint8_t*) ((uintptr_t) output + m);

          vxyzw0 = vxyzw1;
        }
        if (k & 1) {
          vst1_lane_u32((void*) output, vxyzw0, 0);
          output = (uint8_t*) ((uintptr_t) output + m);
        }
      }
      // Rewind to output row 0 at the next group's starting column, clamped so
      // the final group ends at column m - 1.
      output = (uint8_t*) ((uintptr_t) output + output_increment);
      if (output > last_output) {
        output = last_output;
      }
    }
  } else {
    // Scalar path: for each input column k, gather byte k of every row into
    // the next m output bytes.
    const uint8_t* i = input;
    uint8_t* o = output;
    size_t k = n;
    do {
      size_t l = m;
      const uint8_t* ii = i++;
      do {
        *o++ = *ii;
        ii += n;
      } while (--l != 0);
    } while (--k != 0);
  }
}
145