xref: /aosp_15_r20/external/mesa3d/src/asahi/compiler/agx_lower_divergent_shuffle.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2024 Alyssa Rosenzweig
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "agx_builder.h"
7 #include "agx_compiler.h"
8 #include "agx_opcodes.h"
9 
10 static bool
is_shuffle(enum agx_opcode op)11 is_shuffle(enum agx_opcode op)
12 {
13    switch (op) {
14    case AGX_OPCODE_SHUFFLE:
15    case AGX_OPCODE_SHUFFLE_UP:
16    case AGX_OPCODE_SHUFFLE_DOWN:
17    case AGX_OPCODE_SHUFFLE_XOR:
18    case AGX_OPCODE_QUAD_SHUFFLE:
19    case AGX_OPCODE_QUAD_SHUFFLE_UP:
20    case AGX_OPCODE_QUAD_SHUFFLE_DOWN:
21    case AGX_OPCODE_QUAD_SHUFFLE_XOR:
22       return true;
23    default:
24       return false;
25    }
26 }
27 
28 /*
29  * AGX shuffle instructions read indices to shuffle with from the entire quad
30  * and accumulate them. That means that an inactive thread anywhere in the quad
31  * can make the whole shuffle undefined! To workaround, we reserve a scratch
32  * register (r0h) which we keep zero throughout the program... except for when
33  * actually shuffling, when we copy the shuffle index into r0h for the
34  * operation. This ensures that inactive threads read 0 for their index and
35  * hence do not contribute to the accumulated index.
36  */
37 void
agx_lower_divergent_shuffle(agx_context * ctx)38 agx_lower_divergent_shuffle(agx_context *ctx)
39 {
40    agx_builder b = agx_init_builder(ctx, agx_before_function(ctx));
41    agx_index scratch = agx_register(1, AGX_SIZE_16);
42 
43    assert(ctx->any_quad_divergent_shuffle);
44    agx_mov_imm_to(&b, scratch, 0);
45 
46    agx_foreach_block(ctx, block) {
47       bool needs_zero = false;
48 
49       agx_foreach_instr_in_block_safe(block, I) {
50          if (is_shuffle(I->op) && I->src[1].type == AGX_INDEX_REGISTER) {
51             assert(I->dest[0].value != scratch.value);
52             assert(I->src[0].value != scratch.value);
53             assert(I->src[1].value != scratch.value);
54 
55             /* Use scratch register for our input, then zero it at the end of
56              * the block so all inactive threads read zero.
57              */
58             b.cursor = agx_before_instr(I);
59             agx_mov_to(&b, scratch, I->src[1]);
60             needs_zero = true;
61 
62             I->src[1] = scratch;
63          }
64       }
65 
66       if (needs_zero) {
67          b.cursor = agx_after_block_logical(block);
68          agx_mov_imm_to(&b, scratch, 0);
69       }
70    }
71 }
72