/*
 * Copyright 2023 Valve Corporation
 * SPDX-License-Identifier: MIT
 */

#include "nir_builder.h"
#include "nir_constant_expressions.h"
#include "radv_nir.h"

/* This pass optimizes shuffles and boolean ALU instructions whose source can be
 * expressed as a function of tid (only the subgroup/local invocation ID or
 * constants as inputs).
 * Shuffles are replaced by specialized intrinsics, boolean ALU by inverse_ballot.
 * The pass first computes the function-of-tid (fotid) mask, and then uses constant
 * folding to compute the source for each invocation.
 *
 * This pass assumes that local_invocation_index = subgroup_id * subgroup_size + subgroup_invocation_id.
 * That is not guaranteed by the VK spec, but it is how AMD hardware works as long as the
 * GFX12 INTERLEAVE_BITS_X/Y fields are not used. This is also the main reason why this
 * pass is currently RADV specific.
 */
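
/* As an illustrative example (not from the source): with a wave size of 32,
 *
 *    shuffle(data, subgroup_invocation_id ^ 1)
 *
 * has a read index that is a pure function of tid, so the pass can fold the
 * index to a constant per invocation and lower the shuffle to a cheaper lane
 * instruction such as a shuffle_xor or a masked swizzle.
 */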

#define NIR_MAX_SUBGROUP_SIZE     128
#define FOTID_MAX_RECURSION_DEPTH 16 /* totally arbitrary */

static inline unsigned
src_get_fotid_mask(nir_src src)
{
   return src.ssa->parent_instr->pass_flags;
}

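/* Remap the source def's per-component fotid mask through this ALU source's
 * swizzle, so that bit i of the result corresponds to ALU source component i.
 */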
static inline unsigned
alu_src_get_fotid_mask(nir_alu_instr *instr, unsigned idx)
{
   unsigned unswizzled = src_get_fotid_mask(instr->src[idx].src);
   unsigned result = 0;
   for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(instr, idx); i++) {
      bool is_fotid = unswizzled & (1u << instr->src[idx].swizzle[i]);
      result |= is_fotid << i;
   }
   return result;
}

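/* A destination component is a function of tid only if every source component
 * feeding it is: vectorized (size 0) sources contribute per component, while
 * sized sources must be fotid in all of their components.
 */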
static void
update_fotid_alu(nir_builder *b, nir_alu_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   /* For legacy reasons these are ALU instructions
    * when they should be intrinsics.
    */
   if (nir_op_is_derivative(instr->op))
      return;

   const nir_op_info *info = &nir_op_infos[instr->op];

   unsigned res = BITFIELD_MASK(instr->def.num_components);
   for (unsigned i = 0; res != 0 && i < info->num_inputs; i++) {
      unsigned src_mask = alu_src_get_fotid_mask(instr, i);
      if (info->input_sizes[i] == 0)
         res &= src_mask;
      else if (src_mask != BITFIELD_MASK(info->input_sizes[i]))
         res = 0;
   }

   instr->instr.pass_flags = (uint8_t)res;
}

static void
update_fotid_intrinsic(nir_builder *b, nir_intrinsic_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   switch (instr->intrinsic) {
   case nir_intrinsic_load_subgroup_invocation: {
      instr->instr.pass_flags = 1;
      break;
   }
   case nir_intrinsic_load_local_invocation_id: {
      if (b->shader->info.workgroup_size_variable)
         break;
      /* This assumes linear subgroup dispatch. */
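      /* The first N components of the local invocation ID are a function of
       * tid if the product of the first N workgroup dimensions is exactly the
       * subgroup size (each subgroup then covers whole rows/slices), or if the
       * entire workgroup fits into a single subgroup (all three components).
       */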
      unsigned partial_size = 1;
      for (unsigned i = 0; i < 3; i++) {
         partial_size *= b->shader->info.workgroup_size[i];
         if (partial_size == options->hw_subgroup_size)
            instr->instr.pass_flags = (uint8_t)BITFIELD_MASK(i + 1);
      }
      if (partial_size <= options->hw_subgroup_size)
         instr->instr.pass_flags = 0x7;
      break;
   }
   case nir_intrinsic_load_local_invocation_index: {
      if (b->shader->info.workgroup_size_variable)
         break;
      unsigned workgroup_size =
         b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
      if (workgroup_size <= options->hw_subgroup_size)
         instr->instr.pass_flags = 0x1;
      break;
   }
   case nir_intrinsic_inverse_ballot: {
      if (src_get_fotid_mask(instr->src[0]) == BITFIELD_MASK(instr->src[0].ssa->num_components)) {
         instr->instr.pass_flags = 0x1;
      }
      break;
   }
   default: {
      break;
   }
   }
}

static void
update_fotid_load_const(nir_load_const_instr *instr)
{
   instr->instr.pass_flags = (uint8_t)BITFIELD_MASK(instr->def.num_components);
}

static bool
update_fotid_instr(nir_builder *b, nir_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   /* Gather a mask of components that are functions of tid. */
   instr->pass_flags = 0;

   switch (instr->type) {
   case nir_instr_type_alu:
      update_fotid_alu(b, nir_instr_as_alu(instr), options);
      break;
   case nir_instr_type_intrinsic:
      update_fotid_intrinsic(b, nir_instr_as_intrinsic(instr), options);
      break;
   case nir_instr_type_load_const:
      update_fotid_load_const(nir_instr_as_load_const(instr));
      break;
   default:
      break;
   }

   return false;
}

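/* Evaluate one scalar of a fotid expression for the given invocation by
 * recursively substituting constants for the tid intrinsics and constant
 * folding the ALU instructions above them.
 */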
static bool
constant_fold_scalar(nir_scalar s, unsigned invocation_id, nir_shader *shader, nir_const_value *dest, unsigned depth)
{
   if (depth > FOTID_MAX_RECURSION_DEPTH)
      return false;

   memset(dest, 0, sizeof(*dest));

   if (nir_scalar_is_alu(s)) {
      nir_alu_instr *alu = nir_instr_as_alu(s.def->parent_instr);
      nir_const_value sources[NIR_ALU_MAX_INPUTS][NIR_MAX_VEC_COMPONENTS];
      const nir_op_info *op_info = &nir_op_infos[alu->op];

      unsigned bit_size = 0;
      if (!nir_alu_type_get_type_size(op_info->output_type))
         bit_size = alu->def.bit_size;

      for (unsigned i = 0; i < op_info->num_inputs; i++) {
         if (!bit_size && !nir_alu_type_get_type_size(op_info->input_types[i]))
            bit_size = alu->src[i].src.ssa->bit_size;

         unsigned offset = 0;
         unsigned num_comp = op_info->input_sizes[i];
         if (num_comp == 0) {
            num_comp = 1;
            offset = s.comp;
         }

         for (unsigned j = 0; j < num_comp; j++) {
            nir_scalar src_scalar = nir_get_scalar(alu->src[i].src.ssa, alu->src[i].swizzle[offset + j]);
            if (!constant_fold_scalar(src_scalar, invocation_id, shader, &sources[i][j], depth + 1))
               return false;
         }
      }

      if (!bit_size)
         bit_size = 32;

      unsigned exec_mode = shader->info.float_controls_execution_mode;

      nir_const_value *srcs[NIR_ALU_MAX_INPUTS];
      for (unsigned i = 0; i < op_info->num_inputs; ++i)
         srcs[i] = sources[i];
      nir_const_value dests[NIR_MAX_VEC_COMPONENTS];
      if (op_info->output_size == 0) {
         nir_eval_const_opcode(alu->op, dests, 1, bit_size, srcs, exec_mode);
         *dest = dests[0];
      } else {
         nir_eval_const_opcode(alu->op, dests, s.def->num_components, bit_size, srcs, exec_mode);
         *dest = dests[s.comp];
      }
      return true;
   } else if (nir_scalar_is_intrinsic(s)) {
      switch (nir_scalar_intrinsic_op(s)) {
      case nir_intrinsic_load_subgroup_invocation:
      case nir_intrinsic_load_local_invocation_index: {
         *dest = nir_const_value_for_uint(invocation_id, s.def->bit_size);
         return true;
      }
      case nir_intrinsic_load_local_invocation_id: {
         unsigned local_ids[3];
         local_ids[2] = invocation_id / (shader->info.workgroup_size[0] * shader->info.workgroup_size[1]);
         unsigned xy = invocation_id % (shader->info.workgroup_size[0] * shader->info.workgroup_size[1]);
         local_ids[1] = xy / shader->info.workgroup_size[0];
         local_ids[0] = xy % shader->info.workgroup_size[0];
         *dest = nir_const_value_for_uint(local_ids[s.comp], s.def->bit_size);
         return true;
      }
      case nir_intrinsic_inverse_ballot: {
         nir_def *src = nir_instr_as_intrinsic(s.def->parent_instr)->src[0].ssa;
         unsigned comp = invocation_id / src->bit_size;
         unsigned bit = invocation_id % src->bit_size;
         if (!constant_fold_scalar(nir_get_scalar(src, comp), invocation_id, shader, dest, depth + 1))
            return false;
         uint64_t ballot = nir_const_value_as_uint(*dest, src->bit_size);
         *dest = nir_const_value_for_bool(ballot & (1ull << bit), 1);
         return true;
      }
      default:
         break;
      }
   } else if (nir_scalar_is_const(s)) {
      *dest = nir_scalar_as_const_value(s);
      return true;
   }

   unreachable("unhandled scalar type");
   return false;
}

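/* src_invoc[i] is the invocation that invocation i reads from (UINT8_MAX if
 * unknown or unused); reads_zero[i] marks invocations that must end up reading
 * zero so that a bcsel selecting a constant-zero source can be removed.
 */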
struct fotid_context {
   const radv_nir_opt_tid_function_options *options;
   uint8_t src_invoc[NIR_MAX_SUBGROUP_SIZE];
   bool reads_zero[NIR_MAX_SUBGROUP_SIZE];
   nir_shader *shader;
};

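/* For every invocation, constant fold the shuffle index to find which
 * invocation it reads from.
 */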
static bool
gather_read_invocation_shuffle(nir_def *src, struct fotid_context *ctx)
{
   nir_scalar s = {src, 0};

   /* Recursive constant folding for each invocation */
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      nir_const_value value;
      if (!constant_fold_scalar(s, i, ctx->shader, &value, 0))
         return false;
      ctx->src_invoc[i] = MIN2(nir_const_value_as_uint(value, src->bit_size), UINT8_MAX);
   }

   return true;
}

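/* Return the single bcsel that uses def as one of its data sources (not as
 * the condition), or NULL.
 */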
static nir_alu_instr *
get_singular_user_bcsel(nir_def *def, unsigned *src_idx)
{
   if (def->num_components != 1 || !list_is_singular(&def->uses))
      return NULL;

   nir_alu_instr *bcsel = NULL;
   nir_foreach_use_including_if_safe (src, def) {
      if (nir_src_is_if(src) || nir_src_parent_instr(src)->type != nir_instr_type_alu)
         return NULL;
      bcsel = nir_instr_as_alu(nir_src_parent_instr(src));
      if (bcsel->op != nir_op_bcsel || bcsel->def.num_components != 1)
         return NULL;
      *src_idx = list_entry(src, nir_alu_src, src) - bcsel->src;
      break;
   }
   assert(*src_idx < 3);

   if (*src_idx == 0)
      return NULL;
   return bcsel;
}

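/* Mark invocations for which the bcsel selects the other source, since those
 * may read an undefined value from the shuffle. Returns true if the other
 * source is constant zero, in which case the bcsel itself can be removed as
 * long as the replacement guarantees zero for the marked invocations.
 */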
static bool
gather_invocation_uses(nir_alu_instr *bcsel, unsigned shuffle_idx, struct fotid_context *ctx)
{
   if (!alu_src_get_fotid_mask(bcsel, 0))
      return false;

   nir_scalar s = {bcsel->src[0].src.ssa, bcsel->src[0].swizzle[0]};

   bool can_remove_bcsel =
      nir_src_is_const(bcsel->src[3 - shuffle_idx].src) && nir_src_as_uint(bcsel->src[3 - shuffle_idx].src) == 0;

   /* Recursive constant folding for each invocation */
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      nir_const_value value;
      if (!constant_fold_scalar(s, i, ctx->shader, &value, 0)) {
         can_remove_bcsel = false;
         continue;
      }

      /* This invocation selects the other source,
       * so we can read an undefined result. */
      if (nir_const_value_as_bool(value, 1) == (shuffle_idx != 1)) {
         ctx->src_invoc[i] = UINT8_MAX;
         ctx->reads_zero[i] = can_remove_bcsel;
      }
   }

   if (can_remove_bcsel) {
      return true;
   } else {
      memset(ctx->reads_zero, 0, sizeof(ctx->reads_zero));
      return false;
   }
}

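/* Try to express the read index as (tid & and_mask) ^ xor_mask by classifying
 * each bit position across all invocations: always one, always zero, always a
 * copy of the corresponding tid bit, or always its inverse.
 *
 * For example (illustrative): if every invocation i reads invocation i ^ 3,
 * all bits are copies except the low two, which are inverted, giving
 * and_mask = 0x7f and xor_mask = 0x3 (a shuffle_xor or masked swizzle).
 */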
static nir_def *
try_opt_bitwise_mask(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   unsigned one = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned zero = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned copy = NIR_MAX_SUBGROUP_SIZE - 1;
   unsigned invert = NIR_MAX_SUBGROUP_SIZE - 1;

   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      unsigned read = ctx->src_invoc[i];
      if (read >= ctx->options->hw_subgroup_size)
         continue; /* undefined result */

      copy &= ~(read ^ i);
      invert &= read ^ i;
      one &= read;
      zero &= ~read;
   }

   /* We didn't find valid masks for at least one bit. */
   if ((copy | zero | one | invert) != NIR_MAX_SUBGROUP_SIZE - 1)
      return NULL;

   unsigned and_mask = copy | invert;
   unsigned xor_mask = (one | invert) & ~copy;

#if 0
   fprintf(stderr, "and %x, xor %x \n", and_mask, xor_mask);

   assert(false);
#endif

   if ((and_mask & (ctx->options->hw_subgroup_size - 1)) == 0) {
      return nir_read_invocation(b, def, nir_imm_int(b, xor_mask));
   } else if (and_mask == 0x7f && xor_mask == 0) {
      return def;
   } else if (ctx->options->use_shuffle_xor && and_mask == 0x7f) {
      return nir_shuffle_xor(b, def, nir_imm_int(b, xor_mask));
   } else if (ctx->options->use_masked_swizzle_amd && (and_mask & 0x60) == 0x60 && xor_mask <= 0x1f) {
      return nir_masked_swizzle_amd(b, def, (xor_mask << 10) | (and_mask & 0x1f), .fetch_inactive = true);
   }

   return NULL;
}

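/* Try to express the shuffle as a clustered rotate: within each cluster of
 * csize invocations, invocation i reads invocation (i + delta) % csize.
 */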
static nir_def *
try_opt_rotate(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   for (unsigned csize = 4; csize <= ctx->options->hw_subgroup_size; csize *= 2) {
      unsigned cmask = csize - 1;

      unsigned delta = UINT_MAX;
      for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
         if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
            continue;

         if (ctx->src_invoc[i] >= i)
            delta = ctx->src_invoc[i] - i;
         else
            delta = csize - i + ctx->src_invoc[i];
         break;
      }

      if (delta >= csize || delta == 0)
         continue;

      bool use_rotate = true;
      for (unsigned i = 0; use_rotate && i < ctx->options->hw_subgroup_size; i++) {
         if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
            continue;
         use_rotate &= (((i + delta) & cmask) + (i & ~cmask)) == ctx->src_invoc[i];
      }

      if (use_rotate)
         return nir_rotate(b, def, nir_imm_int(b, delta), .cluster_size = csize);
   }

   return NULL;
}

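/* Try to express the shuffle as a DPP row shift: every invocation reads
 * invocation i + delta within its group of 16, and lanes whose source falls
 * outside the group receive zero, which is required for invocations marked
 * in reads_zero.
 */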
static nir_def *
try_opt_dpp16_shift(nir_builder *b, nir_def *def, struct fotid_context *ctx)
{
   int delta = INT_MAX;
   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
         continue;
      delta = ctx->src_invoc[i] - i;
      break;
   }

   if (delta < -15 || delta > 15 || delta == 0)
      return NULL;

   for (unsigned i = 0; i < ctx->options->hw_subgroup_size; i++) {
      int read = i + delta;
      bool out_of_bounds = (read & ~0xf) != (i & ~0xf);
      if (ctx->reads_zero[i] && !out_of_bounds)
         return NULL;
      if (ctx->src_invoc[i] >= ctx->options->hw_subgroup_size)
         continue;
      if (read != ctx->src_invoc[i] || out_of_bounds)
         return NULL;
   }

   return nir_dpp16_shift_amd(b, def, .base = delta);
}

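/* Replace a shuffle whose index is a function of tid with a cheaper subgroup
 * operation (read_invocation with a constant index, shuffle_xor, masked
 * swizzle, clustered rotate or dpp16_shift), possibly also folding away a
 * bcsel that zeroes the unused invocations.
 */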
static bool
opt_fotid_shuffle(nir_builder *b, nir_intrinsic_instr *instr, const radv_nir_opt_tid_function_options *options,
                  bool revisit_bcsel)
{
   if (instr->intrinsic != nir_intrinsic_shuffle)
      return false;
   if (!instr->src[1].ssa->parent_instr->pass_flags)
      return false;

   unsigned src_idx = 0;
   nir_alu_instr *bcsel = get_singular_user_bcsel(&instr->def, &src_idx);
   /* Skip this shuffle; it will be revisited later, once
    * the function-of-tid mask has been set on the bcsel.
    */
   if (bcsel && !revisit_bcsel)
      return false;

   /* We already tried (and failed) to optimize this shuffle. */
   if (!bcsel && revisit_bcsel)
      return false;

   struct fotid_context ctx = {
      .options = options,
      .reads_zero = {0},
      .shader = b->shader,
   };

   memset(ctx.src_invoc, 0xff, sizeof(ctx.src_invoc));

   if (!gather_read_invocation_shuffle(instr->src[1].ssa, &ctx))
      return false;

   /* Generalize src_invoc by taking into account which invocations
    * do not use the shuffle result because of the bcsel.
    */
   bool can_remove_bcsel = false;
   if (bcsel)
      can_remove_bcsel = gather_invocation_uses(bcsel, src_idx, &ctx);

#if 0
   for (int i = 0; i < options->hw_subgroup_size; i++) {
      fprintf(stderr, "invocation %d reads %d\n", i, ctx.src_invoc[i]);
   }

   for (int i = 0; i < options->hw_subgroup_size; i++) {
      fprintf(stderr, "invocation %d zero %d\n", i, ctx.reads_zero[i]);
   }
#endif

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *res = NULL;

   if (can_remove_bcsel && options->use_dpp16_shift_amd) {
      res = try_opt_dpp16_shift(b, instr->src[0].ssa, &ctx);
      if (res) {
         nir_def_rewrite_uses(&bcsel->def, res);
         return true;
      }
   }

   if (!res)
      res = try_opt_bitwise_mask(b, instr->src[0].ssa, &ctx);
   if (!res && options->use_clustered_rotate)
      res = try_opt_rotate(b, instr->src[0].ssa, &ctx);

   if (res) {
      nir_def_replace(&instr->def, res);
      return true;
   } else {
      return false;
   }
}

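/* Replace a 1-bit ALU result that is a function of tid with
 * inverse_ballot(constant), evaluating the boolean for every invocation to
 * build the constant ballot.
 */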
static bool
opt_fotid_bool(nir_builder *b, nir_alu_instr *instr, const radv_nir_opt_tid_function_options *options)
{
   nir_scalar s = {&instr->def, 0};

   b->cursor = nir_after_instr(&instr->instr);

   nir_def *ballot_comp[NIR_MAX_VEC_COMPONENTS];

   for (unsigned comp = 0; comp < options->hw_ballot_num_comp; comp++) {
      uint64_t cballot = 0;
      for (unsigned i = 0; i < options->hw_ballot_bit_size; i++) {
         unsigned invocation_id = comp * options->hw_ballot_bit_size + i;
         if (invocation_id >= options->hw_subgroup_size)
            break;
         nir_const_value value;
         if (!constant_fold_scalar(s, invocation_id, b->shader, &value, 0))
            return false;
         cballot |= nir_const_value_as_uint(value, 1) << i;
      }
      ballot_comp[comp] = nir_imm_intN_t(b, cballot, options->hw_ballot_bit_size);
   }

   nir_def *ballot = nir_vec(b, ballot_comp, options->hw_ballot_num_comp);
   nir_def *res = nir_inverse_ballot(b, 1, ballot);
   res->parent_instr->pass_flags = 1;

   nir_def_replace(&instr->def, res);
   return true;
}

static bool
visit_instr(nir_builder *b, nir_instr *instr, void *params)
{
   const radv_nir_opt_tid_function_options *options = params;
   update_fotid_instr(b, instr, options);

   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);

      if (alu->op == nir_op_bcsel && alu->def.bit_size != 1) {
         /* Revisit shuffles that we skipped previously. */
         bool progress = false;
         for (unsigned i = 1; i < 3; i++) {
            nir_instr *src_instr = alu->src[i].src.ssa->parent_instr;
            if (src_instr->type == nir_instr_type_intrinsic) {
               nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(src_instr);
               progress |= opt_fotid_shuffle(b, intrin, options, true);
               if (list_is_empty(&alu->def.uses))
                  break;
            }
         }
         return progress;
      }

      if (!options->hw_ballot_bit_size || !options->hw_ballot_num_comp)
         return false;
      if (alu->def.bit_size != 1 || alu->def.num_components > 1 || !instr->pass_flags)
         return false;
      return opt_fotid_bool(b, alu, options);
   }
   case nir_instr_type_intrinsic: {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      return opt_fotid_shuffle(b, intrin, options, false);
   }
   default:
      return false;
   }
}

bool
radv_nir_opt_tid_function(nir_shader *shader, const radv_nir_opt_tid_function_options *options)
{
   return nir_shader_instructions_pass(shader, visit_instr, nir_metadata_control_flow, (void *)options);
}
567