/*
 * Copyright © 2023 Collabora, Ltd.
 * SPDX-License-Identifier: MIT
 */

#include "nak_private.h"
#include "nir_builder.h"

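/* Returns the ballot mask of invocations that participate in the current
 * invocation's cluster: every enabled channel when cluster_size covers the
 * full 32-wide subgroup, otherwise just the enabled channels within the
 * cluster_size-aligned cluster containing this invocation.
 */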
static nir_def *
cluster_mask(nir_builder *b, unsigned cluster_size)
{
   nir_def *mask = nir_ballot(b, 1, 32, nir_imm_true(b));

   if (cluster_size < 32) {
      nir_def *idx = nir_load_subgroup_invocation(b);
      nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));

      nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
      cluster_mask = nir_ishl(b, cluster_mask, cluster);

      mask = nir_iand(b, mask, cluster_mask);
   }

   return mask;
}

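/* Boolean scans and reductions are lowered without any shuffles.  The cluster
 * mask is narrowed to the channels that participate (lower channels for an
 * exclusive scan, lower-or-equal for an inclusive scan) and the reduction
 * itself is done with bit operations on a ballot of the data.
 */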
static nir_def *
build_scan_bool(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
{
   /* Handle a couple of special cases first */
   if (op == nir_intrinsic_reduce && cluster_size == 32) {
      switch (red_op) {
      case nir_op_iand:
         return nir_vote_all(b, 1, data);
      case nir_op_ior:
         return nir_vote_any(b, 1, data);
      case nir_op_ixor:
         /* The generic path is fine */
         break;
      default:
         unreachable("Unsupported boolean reduction op");
      }
   }

   nir_def *mask = cluster_mask(b, cluster_size);
   switch (op) {
   case nir_intrinsic_exclusive_scan:
      mask = nir_iand(b, mask, nir_load_subgroup_lt_mask(b, 1, 32));
      break;
   case nir_intrinsic_inclusive_scan:
      mask = nir_iand(b, mask, nir_load_subgroup_le_mask(b, 1, 32));
      break;
   case nir_intrinsic_reduce:
      break;
   default:
      unreachable("Unsupported scan/reduce op");
   }

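   /* Ballot the boolean so each participating channel shows up as one bit and
    * the reduction becomes scalar bit math: iand checks that no participating
    * bit is clear, ior that any participating bit is set, and ixor takes the
    * parity of the popcount.
    */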
   data = nir_ballot(b, 1, 32, data);

   switch (red_op) {
   case nir_op_iand:
      return nir_ieq_imm(b, nir_iand(b, nir_inot(b, data), mask), 0);
   case nir_op_ior:
      return nir_ine_imm(b, nir_iand(b, data, mask), 0);
   case nir_op_ixor: {
      nir_def *count = nir_bit_count(b, nir_iand(b, data, mask));
      return nir_ine_imm(b, nir_iand_imm(b, count, 1), 0);
   }
   default:
      unreachable("Unsupported boolean reduction op");
   }
}

static nir_def *
build_identity(nir_builder *b, unsigned bit_size, nir_op op)
{
   nir_const_value ident_const = nir_alu_binop_identity(op, bit_size);
   return nir_build_imm(b, 1, bit_size, &ident_const);
}

/* Implementation of scan/reduce that assumes a full subgroup */
static nir_def *
build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                nir_def *data, unsigned cluster_size)
{
   switch (op) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan: {
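      /* Log-step (Hillis-Steele style) scan: at step i, every channel that
       * has a buddy i channels below it folds that buddy's value into its
       * own.  After log2(cluster_size) steps each channel holds the inclusive
       * scan of its cluster; exclusive scans are derived from it below.
       */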
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, i);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
         nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
         data = nir_bcsel(b, has_buddy, accum, data);
      }

      if (op == nir_intrinsic_exclusive_scan) {
         /* For exclusive scans, we need to shift one more time and fill in the
          * bottom channel with identity.
          */
         assert(cluster_size == 32);
         nir_def *idx = nir_load_subgroup_invocation(b);
         nir_def *has_buddy = nir_ige_imm(b, idx, 1);

         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
         nir_def *identity = build_identity(b, data->bit_size, red_op);
         data = nir_bcsel(b, has_buddy, buddy_data, identity);
      }

      return data;
   }

   case nir_intrinsic_reduce: {
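      /* Butterfly reduction: nir_shuffle_xor exchanges values between
       * channels whose indices differ only in the bit corresponding to i, so
       * after log2(cluster_size) steps every channel in the cluster holds the
       * full reduction.
       */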
      for (unsigned i = 1; i < cluster_size; i *= 2) {
         nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
         data = nir_build_alu2(b, red_op, data, buddy_data);
      }
      return data;
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

/* Fully generic implementation of scan/reduce that takes a mask */
static nir_def *
build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                  nir_def *data, nir_def *mask, unsigned max_mask_bits)
{
   nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, 32);

   /* Mask of all channels whose values we still need to accumulate.  Our own
    * value is already accounted for, since data itself acts as the
    * accumulator, so we only need to consider lower-indexed invocations.
    */
   nir_def *remaining = nir_iand(b, mask, lt_mask);

   for (unsigned i = 1; i < max_mask_bits; i *= 2) {
      /* At each step, our buddy channel is the first channel we have yet to
       * take into account in the accumulator.
       */
      nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
      nir_def *buddy = nir_ufind_msb(b, remaining);

      /* Accumulate with our buddy channel, if any */
      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
      data = nir_bcsel(b, has_buddy, accum, data);

      /* We just took into account everything in our buddy's accumulator from
       * the previous step.  The only things remaining are whatever channels
       * were remaining for our buddy.
       */
      nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
      remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
   }

   switch (op) {
   case nir_intrinsic_exclusive_scan: {
      /* For exclusive scans, we need to shift one more time and fill in the
       * bottom channel with identity.
       *
       * Some of this will get CSE'd with the first step but that's okay. The
       * code is cleaner this way.
       */
      nir_def *lower = nir_iand(b, mask, lt_mask);
      nir_def *has_buddy = nir_ine_imm(b, lower, 0);
      nir_def *buddy = nir_ufind_msb(b, lower);

      nir_def *buddy_data = nir_shuffle(b, data, buddy);
      nir_def *identity = build_identity(b, data->bit_size, red_op);
      return nir_bcsel(b, has_buddy, buddy_data, identity);
   }

   case nir_intrinsic_inclusive_scan:
      return data;

   case nir_intrinsic_reduce: {
      /* For reductions, we need to take the top value of the scan */
      nir_def *idx = nir_ufind_msb(b, mask);
      return nir_shuffle(b, data, idx);
   }

   default:
      unreachable("Unsupported scan/reduce op");
   }
}

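/* Per-intrinsic callback: rewrites each scan/reduce intrinsic in terms of
 * ballots, votes, and shuffles, choosing between the full-subgroup fast path
 * and the generic masked path where necessary.
 */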
static bool
nak_nir_lower_scan_reduce_intrin(nir_builder *b,
                                 nir_intrinsic_instr *intrin,
                                 UNUSED void *_data)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_reduce:
      break;
   default:
      return false;
   }

   const nir_op red_op = nir_intrinsic_reduction_op(intrin);

   /* Grab the cluster size, defaulting to 32 */
   unsigned cluster_size = 32;
   if (nir_intrinsic_has_cluster_size(intrin)) {
      cluster_size = nir_intrinsic_cluster_size(intrin);
      if (cluster_size == 0 || cluster_size > 32)
         cluster_size = 32;
   }

   b->cursor = nir_before_instr(&intrin->instr);

   nir_def *data;
   if (cluster_size == 1) {
      /* Simple case where we're not actually doing any reducing at all. */
      assert(intrin->intrinsic == nir_intrinsic_reduce);
      data = intrin->src[0].ssa;
   } else if (intrin->src[0].ssa->bit_size == 1) {
      data = build_scan_bool(b, intrin->intrinsic, red_op,
                             intrin->src[0].ssa, cluster_size);
   } else {
      /* First, we need a mask of all invocations to be included in the
       * reduction or scan.  For trivial cluster sizes, that's just the mask
       * of enabled channels.
       */
      nir_def *mask = cluster_mask(b, cluster_size);

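      /* If every channel in the subgroup is enabled, take the fixed-stride
       * full-subgroup path; otherwise fall back to the fully generic masked
       * scan/reduce.
       */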
      nir_def *full, *partial;
      nir_push_if(b, nir_ieq_imm(b, mask, -1));
      {
         full = build_scan_full(b, intrin->intrinsic, red_op,
                                intrin->src[0].ssa, cluster_size);
      }
      nir_push_else(b, NULL);
      {
         partial = build_scan_reduce(b, intrin->intrinsic, red_op,
                                     intrin->src[0].ssa, mask, cluster_size);
      }
      nir_pop_if(b, NULL);
      data = nir_if_phi(b, full, partial);
   }

   nir_def_replace(&intrin->def, data);

   return true;
}

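/* Lowers subgroup reduce, inclusive_scan, and exclusive_scan intrinsics to
 * ballot, vote, and shuffle operations.
 */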
bool
nak_nir_lower_scan_reduce(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, nak_nir_lower_scan_reduce_intrin,
                                     nir_metadata_none, NULL);
}