1 /*
2  * Copyright (c) 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 #if defined(ARM_COMPUTE_ENABLE_SME)
26 
27 #include <cstdint>
28 
namespace arm_conv {
namespace pooling {

/// @brief Generic-window average pooling for fp32 NHWC data using SME
///        streaming-mode SVE (depth-first: one output pixel, all channels).
///
/// For every channel c in [0, n_channels):
///   outptr[c] = (inptrs[0][c] + ... + inptrs[n_valid_cells - 1][c]) * (1 / window_cells)
///
/// @param window_cells  Divisor of the average. This is deliberately separate from
///                      n_valid_cells; presumably the caller passes the full window
///                      size so that padded cells count as zeros — TODO confirm
///                      against the caller. A value of 0 yields a +inf scale factor.
/// @param n_valid_cells Number of row pointers read from @p inptrs. It may be 0,
///                      in which case every output is 0 * scale.
/// @param n_channels    Number of consecutive floats read from each input row and
///                      written to @p outptr.
/// @param inptrs        Array of n_valid_cells pointers, each to at least
///                      n_channels readable floats.
/// @param outptr        Destination for n_channels floats.
///
/// Requires a CPU implementing SME: the body executes hand-encoded SMSTART/SMSTOP
/// instructions. Inputs are consumed four at a time and combined pairwise before
/// being added to the running accumulator. Results may therefore differ in the
/// last bits from a strictly sequential left-to-right sum.
///
/// NOTE(review): the in-string label for 0xd503477f reads "SMSTART ZA", but the
/// body only uses Z and P registers, never ZA. Confirm against the Arm ARM which
/// PSTATE bits this encoding sets, and whether enabling ZA is actually needed.
void sme_fp32_nhwc_avg_generic_depthfirst_impl(
  const uint64_t window_cells,
  const uint64_t n_valid_cells,
  uint64_t n_channels,
  const float *const *const inptrs,
  float *outptr
)
{
  // Broadcast into z6 by the asm (ld1rw). It must live in memory, because its
  // address is passed in as an operand.
  const float rescale_value = 1.0f / static_cast<float>(window_cells);

  // Register roles inside the asm:
  //   x28/x27/x26/x25 : channel offsets of vectors 0..3 of the current group
  //   p3/p2/p1/p0     : lane predicates for those four vectors
  //   z5/z4/z3/z2     : per-vector sum accumulators (z5 alone in the tail loop)
  //   x19             : cursor into inptrs; x24 = n_valid_cells / 4; x20 = remainder
  __asm__ __volatile__(
    ".inst 0xd503477f  // SMSTART ZA\n"
    "mov x28, #0x0\n"
    "cntw x27\n"
    "cntw x26, ALL, MUL #2\n"
    "cntw x25, ALL, MUL #3\n"
    "ptrue p0.b\n"
    "whilelt p3.s, x28, %x[n_channels]\n"
    "ld1rw { z6.s }, p0/Z, [%x[rescale_ptr]]\n"
    "whilelt p2.s, x27, %x[n_channels]\n"
    "whilelt p1.s, x26, %x[n_channels]\n"
    "whilelt p0.s, x25, %x[n_channels]\n"
    // Fewer than four vectors' worth of channels: go straight to the tail loop.
    "b.none 7f\n"
    "1:"  // 4-vectors of channels
    "lsr x24, %x[n_valid_cells], #0x2\n"
    "mov z5.b, #0x0\n"
    "mov z4.b, #0x0\n"
    "mov x19, %x[inptrs]\n"
    "mov z3.b, #0x0\n"
    "mov z2.b, #0x0\n"
    "cbz x24, 4f\n"
    // Software-pipelined: preload the first four rows before entering the loop.
    "ldp x23, x22, [x19, #0x0]\n"
    "subs x24, x24, #0x1\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "ldp x21, x20, [x19, #0x10]\n"
    "add x19, x19, #0x20\n"
    "ld1w { z0.s }, p3/Z, [x22, x28, LSL #2]\n"
    "ld1w { z31.s }, p3/Z, [x21, x28, LSL #2]\n"
    "ld1w { z30.s }, p3/Z, [x20, x28, LSL #2]\n"
    "ld1w { z29.s }, p2/Z, [x23, x27, LSL #2]\n"
    "ld1w { z22.s }, p2/Z, [x22, x27, LSL #2]\n"
    "ld1w { z28.s }, p2/Z, [x21, x27, LSL #2]\n"
    "ld1w { z18.s }, p2/Z, [x20, x27, LSL #2]\n"
    "ld1w { z27.s }, p1/Z, [x23, x26, LSL #2]\n"
    "ld1w { z21.s }, p1/Z, [x22, x26, LSL #2]\n"
    "ld1w { z26.s }, p1/Z, [x21, x26, LSL #2]\n"
    "ld1w { z17.s }, p1/Z, [x20, x26, LSL #2]\n"
    "ld1w { z25.s }, p0/Z, [x23, x25, LSL #2]\n"
    "ld1w { z20.s }, p0/Z, [x22, x25, LSL #2]\n"
    "ld1w { z24.s }, p0/Z, [x21, x25, LSL #2]\n"
    "ld1w { z16.s }, p0/Z, [x20, x25, LSL #2]\n"
    "beq 3f\n"
    "2:"  // 4-vectors of channels: 4 inputs loop
    "fadd z23.s, z1.s, z0.s\n"
    "fadd z19.s, z31.s, z30.s\n"
    "ldp x23, x22, [x19, #0x0]\n"
    "subs x24, x24, #0x1\n"
    "fadd z22.s, z29.s, z22.s\n"
    "fadd z18.s, z28.s, z18.s\n"
    "ldp x21, x20, [x19, #0x10]\n"
    "add x19, x19, #0x20\n"
    "fadd z21.s, z27.s, z21.s\n"
    "fadd z17.s, z26.s, z17.s\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "fadd z20.s, z25.s, z20.s\n"
    "fadd z16.s, z24.s, z16.s\n"
    "ld1w { z0.s }, p3/Z, [x22, x28, LSL #2]\n"
    "fadd z19.s, z23.s, z19.s\n"
    "fadd z18.s, z22.s, z18.s\n"
    "ld1w { z31.s }, p3/Z, [x21, x28, LSL #2]\n"
    "fadd z17.s, z21.s, z17.s\n"
    "fadd z16.s, z20.s, z16.s\n"
    "ld1w { z30.s }, p3/Z, [x20, x28, LSL #2]\n"
    "fadd z5.s, z5.s, z19.s\n"
    "fadd z4.s, z4.s, z18.s\n"
    "ld1w { z29.s }, p2/Z, [x23, x27, LSL #2]\n"
    "fadd z3.s, z3.s, z17.s\n"
    "fadd z2.s, z2.s, z16.s\n"
    "ld1w { z22.s }, p2/Z, [x22, x27, LSL #2]\n"
    "ld1w { z28.s }, p2/Z, [x21, x27, LSL #2]\n"
    "ld1w { z18.s }, p2/Z, [x20, x27, LSL #2]\n"
    "ld1w { z27.s }, p1/Z, [x23, x26, LSL #2]\n"
    "ld1w { z21.s }, p1/Z, [x22, x26, LSL #2]\n"
    "ld1w { z26.s }, p1/Z, [x21, x26, LSL #2]\n"
    "ld1w { z17.s }, p1/Z, [x20, x26, LSL #2]\n"
    "ld1w { z25.s }, p0/Z, [x23, x25, LSL #2]\n"
    "ld1w { z20.s }, p0/Z, [x22, x25, LSL #2]\n"
    "ld1w { z24.s }, p0/Z, [x21, x25, LSL #2]\n"
    "ld1w { z16.s }, p0/Z, [x20, x25, LSL #2]\n"
    "bgt 2b\n"
    "3:"  // 4-vectors of channels: 4 inputs tail
    "fadd z23.s, z1.s, z0.s\n"
    "fadd z19.s, z31.s, z30.s\n"
    "fadd z22.s, z29.s, z22.s\n"
    "fadd z18.s, z28.s, z18.s\n"
    "fadd z21.s, z27.s, z21.s\n"
    "fadd z17.s, z26.s, z17.s\n"
    "fadd z20.s, z25.s, z20.s\n"
    "fadd z16.s, z24.s, z16.s\n"
    "fadd z19.s, z23.s, z19.s\n"
    "fadd z18.s, z22.s, z18.s\n"
    "fadd z17.s, z21.s, z17.s\n"
    "fadd z16.s, z20.s, z16.s\n"
    "fadd z5.s, z5.s, z19.s\n"
    "fadd z4.s, z4.s, z18.s\n"
    "fadd z3.s, z3.s, z17.s\n"
    "fadd z2.s, z2.s, z16.s\n"
    "4:"  // 4-vectors of channels: After loop
    // Remaining n_valid_cells % 4 rows, one at a time.
    "ands x20, %x[n_valid_cells], #0x3\n"
    "beq 6f\n"
    "5:"  // 4-vectors of channels: Single input loop
    "ldr x23, [x19], #0x8\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "subs x20, x20, #0x1\n"
    "fadd z5.s, z5.s, z1.s\n"
    "ld1w { z29.s }, p2/Z, [x23, x27, LSL #2]\n"
    "fadd z4.s, z4.s, z29.s\n"
    "ld1w { z27.s }, p1/Z, [x23, x26, LSL #2]\n"
    "fadd z3.s, z3.s, z27.s\n"
    "ld1w { z25.s }, p0/Z, [x23, x25, LSL #2]\n"
    "fadd z2.s, z2.s, z25.s\n"
    "bgt 5b\n"
    "6:"  // 4-vectors of channels: Single input loop: End
    // Scale the sums by 1/window_cells and store; advance all offsets by 4 vectors.
    "fmul z5.s, z5.s, z6.s\n"
    "fmul z4.s, z4.s, z6.s\n"
    "st1w { z5.s }, p3, [%x[outptr], x28, LSL #2]\n"
    "incw x28, ALL, MUL #4\n"
    "fmul z3.s, z3.s, z6.s\n"
    "fmul z2.s, z2.s, z6.s\n"
    "st1w { z4.s }, p2, [%x[outptr], x27, LSL #2]\n"
    "incw x27, ALL, MUL #4\n"
    "st1w { z3.s }, p1, [%x[outptr], x26, LSL #2]\n"
    "incw x26, ALL, MUL #4\n"
    "st1w { z2.s }, p0, [%x[outptr], x25, LSL #2]\n"
    "incw x25, ALL, MUL #4\n"
    // Only p0 is refreshed here. Any active lane in the 4th vector means vectors
    // 0..2 are entirely in range, so p3/p2/p1 correctly stay all-true.
    "whilelt p0.s, x25, %x[n_channels]\n"
    "b.any 1b\n"
    "7:"  // Single vector of channels
    "whilelt p3.s, x28, %x[n_channels]\n"
    "b.none 14f\n"
    "8:"  // Single vector of channels: Loop
    "lsr x24, %x[n_valid_cells], #0x2\n"
    "mov z5.b, #0x0\n"
    "mov x19, %x[inptrs]\n"
    "cbz x24, 11f\n"
    "ldp x23, x22, [x19, #0x0]\n"
    "subs x24, x24, #0x1\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "ldp x21, x20, [x19, #0x10]\n"
    "add x19, x19, #0x20\n"
    "ld1w { z0.s }, p3/Z, [x22, x28, LSL #2]\n"
    "ld1w { z31.s }, p3/Z, [x21, x28, LSL #2]\n"
    "ld1w { z30.s }, p3/Z, [x20, x28, LSL #2]\n"
    "beq 10f\n"
    "9:"  // Single vector of channels: Loop: 4 inputs loop
    "fadd z23.s, z1.s, z0.s\n"
    "fadd z19.s, z31.s, z30.s\n"
    "ldp x23, x22, [x19, #0x0]\n"
    "subs x24, x24, #0x1\n"
    "fadd z19.s, z23.s, z19.s\n"
    "ldp x21, x20, [x19, #0x10]\n"
    "fadd z5.s, z5.s, z19.s\n"
    "add x19, x19, #0x20\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "ld1w { z0.s }, p3/Z, [x22, x28, LSL #2]\n"
    "ld1w { z31.s }, p3/Z, [x21, x28, LSL #2]\n"
    "ld1w { z30.s }, p3/Z, [x20, x28, LSL #2]\n"
    "bgt 9b\n"
    "10:"  // Single vector of channels: Loop: 4 inputs tail
    "fadd z23.s, z1.s, z0.s\n"
    "fadd z19.s, z31.s, z30.s\n"
    "fadd z19.s, z23.s, z19.s\n"
    "fadd z5.s, z5.s, z19.s\n"
    "11:"  // Single vector of channels: Loop: After loop
    "ands x20, %x[n_valid_cells], #0x3\n"
    "beq 13f\n"
    "12:"  // Single vector of channels: Loop: Single input loop
    "ldr x23, [x19], #0x8\n"
    "ld1w { z1.s }, p3/Z, [x23, x28, LSL #2]\n"
    "subs x20, x20, #0x1\n"
    "fadd z5.s, z5.s, z1.s\n"
    "bgt 12b\n"
    "13:"  // Single vector of channels: Loop: Single input loop: End
    "fmul z5.s, z5.s, z6.s\n"
    "st1w { z5.s }, p3, [%x[outptr], x28, LSL #2]\n"
    "incw x28\n"
    "whilelt p3.s, x28, %x[n_channels]\n"
    "b.any 8b\n"
    "14:"  // End
    ".inst 0xd503467f  // SMSTOP\n"
    :
    : [inptrs] "r" (inptrs), [n_channels] "r" (n_channels), [n_valid_cells] "r" (n_valid_cells), [outptr] "r" (outptr), [rescale_ptr] "r" (&rescale_value)
    : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"
  );
}

}  // namespace pooling
}  // namespace arm_conv
231 
232 #endif  // defined(ARM_COMPUTE_ENABLE_SME)
233