/*
 * Copyright © 2023 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

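/* Returns the ballot mask of all enabled invocations in the calling
 * invocation's cluster.  For the full 32-invocation cluster size this is
 * simply the ballot of enabled channels.
 */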
static nir_def *
cluster_mask(nir_builder *b, unsigned cluster_size)
{
   nir_def *mask = nir_ballot(b, 1, 32, nir_imm_true(b));

   if (cluster_size < 32) {
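      /* Clear the low bits of the invocation index to find the first lane
       * of this invocation's cluster, then shift a cluster_size-wide bit
       * mask up to that lane.
       */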
      nir_def *idx = nir_load_subgroup_invocation(b);
      nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));

      nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
      cluster_mask = nir_ishl(b, cluster_mask, cluster);

      mask = nir_iand(b, mask, cluster_mask);
   }

   return mask;
}

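/* Boolean scans/reductions work on ballots: the whole cluster's worth of
 * 1-bit data fits in a single 32-bit ballot, so the reduction turns into a
 * bit test on the ballot masked to the channels of interest.
 */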
static nir_def *
build_scan_bool(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
{
   /* Handle a couple of special cases first */
   if (op == nir_intrinsic_reduce && cluster_size == 32) {
      switch (red_op) {
      case nir_op_iand:
         return nir_vote_all(b, 1, data);
      case nir_op_ior:
         return nir_vote_any(b, 1, data);
      case nir_op_ixor:
         /* The generic path is fine */
         break;
      default:
         unreachable("Unsupported boolean reduction op");
      }
   }

   nir_def *mask = cluster_mask(b, cluster_size);
   switch (op) {
   case nir_intrinsic_exclusive_scan:
      mask = nir_iand(b, mask, nir_load_subgroup_lt_mask(b, 1, 32));
      break;
   case nir_intrinsic_inclusive_scan:
      mask = nir_iand(b, mask, nir_load_subgroup_le_mask(b, 1, 32));
      break;
   case nir_intrinsic_reduce:
      break;
   default:
      unreachable("Unsupported scan/reduce op");
   }

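   /* With the booleans packed into a ballot, AND asks whether no masked bit
    * is unset, OR asks whether any masked bit is set, and XOR is a parity
    * check on the masked bits.
    */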
   data = nir_ballot(b, 1, 32, data);

   switch (red_op) {
   case nir_op_iand:
      return nir_ieq_imm(b, nir_iand(b, nir_inot(b, data), mask), 0);
   case nir_op_ior:
      return nir_ine_imm(b, nir_iand(b, data, mask), 0);
   case nir_op_ixor: {
      nir_def *count = nir_bit_count(b, nir_iand(b, data, mask));
      return nir_ine_imm(b, nir_iand_imm(b, count, 1), 0);
   }
   default:
      unreachable("Unsupported boolean reduction op");
   }
}

static nir_def *
build_identity(nir_builder *b, unsigned bit_size, nir_op op)
{
   nir_const_value ident_const = nir_alu_binop_identity(op, bit_size);
   return nir_build_imm(b, 1, bit_size, &ident_const);
}

/* Implementation of scan/reduce that assumes a full subgroup */
static nir_def *
build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
{
   switch (op) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan: {
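      /* Inclusive scan by stride doubling: at stride i, every invocation
       * with a buddy i lanes below folds that buddy's value into its own,
       * so after log2(cluster_size) steps each invocation holds the scan of
       * everything at or below it.
       */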
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, i);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
         nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
         data = nir_bcsel(b, has_buddy, accum, data);
      }

      if (op == nir_intrinsic_exclusive_scan) {
         /* For exclusive scans, we need to shift one more time and fill in
          * the bottom channel with identity.
          */
         assert(cluster_size == 32);
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, 1);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
         nir_def *identity = build_identity(b, data->bit_size, red_op);
         data = nir_bcsel(b, has_buddy, buddy_data, identity);
      }

      return data;
   }

   case nir_intrinsic_reduce: {
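      /* Butterfly reduction: XOR-shuffling with doubling strides combines
       * values across the whole cluster and leaves every invocation holding
       * the full reduction.
       */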
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
         data = nir_build_alu2(b, red_op, data, buddy_data);
      }
      return data;
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

/* Fully generic implementation of scan/reduce that takes a mask */
static nir_def *
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                  nir_def *data, nir_def *mask, unsigned max_mask_bits)
{
   nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, 32);

   /* Mask of all channels whose values we still need to accumulate.  Our
    * own value is already included because data itself serves as the
    * accumulator, so we only need to consider lower indexed invocations.
    */
   nir_def *remaining = nir_iand(b, mask, lt_mask);

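   /* Each step folds in the buddy's accumulator, which already covers every
    * channel that buddy has accumulated, so the set of covered channels at
    * least doubles per iteration and log2(max_mask_bits) steps suffice.
    */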
   for (unsigned i = 1; i < max_mask_bits; i *= 2) {
      /* At each step, our buddy channel is the first channel we have yet to
       * take into account in the accumulator.
       */
      nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
      nir_def *buddy = nir_ufind_msb(b, remaining);

      /* Accumulate with our buddy channel, if any */
      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
      data = nir_bcsel(b, has_buddy, accum, data);

      /* We just took into account everything in our buddy's accumulator
       * from the previous step.  The only things remaining are whatever
       * channels were remaining for our buddy.
       */
      nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
      remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
   }

   switch (op) {
   case nir_intrinsic_exclusive_scan: {
      /* For exclusive scans, we need to shift one more time and fill in the
       * bottom channel with identity.
       *
       * Some of this will get CSE'd with the first step but that's okay.
       * The code is cleaner this way.
       */
      nir_def *lower = nir_iand(b, mask, lt_mask);
      nir_def *has_buddy = nir_ine_imm(b, lower, 0);
      nir_def *buddy = nir_ufind_msb(b, lower);

      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      nir_def *identity = build_identity(b, data->bit_size, red_op);
      return nir_bcsel(b, has_buddy, buddy_data, identity);
   }

   case nir_intrinsic_inclusive_scan:
      return data;

   case nir_intrinsic_reduce: {
      /* For reductions, we need to take the top value of the scan */
      nir_def *idx = nir_ufind_msb(b, mask);
      return nir_shuffle(b, data, idx);
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

static bool
nak_nir_lower_scan_reduce_intrin(nir_builder *b,
                                 nir_intrinsic_instr *intrin,
                                 UNUSED void *_data)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_reduce:
      break;
   default:
      return false;
   }

   const nir_op red_op = nir_intrinsic_reduction_op(intrin);

   /* Grab the cluster size, defaulting to 32 */
   unsigned cluster_size = 32;
   if (nir_intrinsic_has_cluster_size(intrin)) {
      cluster_size = nir_intrinsic_cluster_size(intrin);
      if (cluster_size == 0 || cluster_size > 32)
         cluster_size = 32;
   }

   b->cursor = nir_before_instr(&intrin->instr);

   nir_def *data;
   if (cluster_size == 1) {
      /* Simple case where we're not actually doing any reducing at all. */
      assert(intrin->intrinsic == nir_intrinsic_reduce);
      data = intrin->src[0].ssa;
   } else if (intrin->src[0].ssa->bit_size == 1) {
      data = build_scan_bool(b, intrin->intrinsic, red_op,
                             intrin->src[0].ssa, cluster_size);
   } else {
      /* First, we need a mask of all invocations to be included in the
       * reduction or scan.  For trivial cluster sizes, that's just the mask
       * of enabled channels.
       */
      nir_def *mask = cluster_mask(b, cluster_size);

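      /* Take the fast path when every channel in the mask is enabled;
       * otherwise fall back to the fully generic mask-driven version.  A
       * phi selects whichever side actually ran.
       */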
      nir_def *full, *partial;
      nir_push_if(b, nir_ieq_imm(b, mask, -1));
      {
         full = build_scan_full(b, intrin->intrinsic, red_op,
                                intrin->src[0].ssa, cluster_size);
      }
      nir_push_else(b, NULL);
      {
         partial = build_scan_reduce(b, intrin->intrinsic, red_op,
                                     intrin->src[0].ssa, mask, cluster_size);
      }
      nir_pop_if(b, NULL);
      data = nir_if_phi(b, full, partial);
   }

   nir_def_replace(&intrin->def, data);

   return true;
}

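/* Lowers subgroup scan/reduce intrinsics to ballot/shuffle-based sequences. */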
bool
nak_nir_lower_scan_reduce(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, nak_nir_lower_scan_reduce_intrin,
                                     nir_metadata_none, NULL);
}