1 /*
2 * Copyright 2024 Alyssa Rosenzweig
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "agx_builder.h"
7 #include "agx_compiler.h"
8 #include "agx_opcodes.h"
9
10 static bool
is_shuffle(enum agx_opcode op)11 is_shuffle(enum agx_opcode op)
12 {
13 switch (op) {
14 case AGX_OPCODE_SHUFFLE:
15 case AGX_OPCODE_SHUFFLE_UP:
16 case AGX_OPCODE_SHUFFLE_DOWN:
17 case AGX_OPCODE_SHUFFLE_XOR:
18 case AGX_OPCODE_QUAD_SHUFFLE:
19 case AGX_OPCODE_QUAD_SHUFFLE_UP:
20 case AGX_OPCODE_QUAD_SHUFFLE_DOWN:
21 case AGX_OPCODE_QUAD_SHUFFLE_XOR:
22 return true;
23 default:
24 return false;
25 }
26 }
27
28 /*
29 * AGX shuffle instructions read indices to shuffle with from the entire quad
30 * and accumulate them. That means that an inactive thread anywhere in the quad
31 * can make the whole shuffle undefined! To workaround, we reserve a scratch
32 * register (r0h) which we keep zero throughout the program... except for when
33 * actually shuffling, when we copy the shuffle index into r0h for the
34 * operation. This ensures that inactive threads read 0 for their index and
35 * hence do not contribute to the accumulated index.
36 */
37 void
agx_lower_divergent_shuffle(agx_context * ctx)38 agx_lower_divergent_shuffle(agx_context *ctx)
39 {
40 agx_builder b = agx_init_builder(ctx, agx_before_function(ctx));
41 agx_index scratch = agx_register(1, AGX_SIZE_16);
42
43 assert(ctx->any_quad_divergent_shuffle);
44 agx_mov_imm_to(&b, scratch, 0);
45
46 agx_foreach_block(ctx, block) {
47 bool needs_zero = false;
48
49 agx_foreach_instr_in_block_safe(block, I) {
50 if (is_shuffle(I->op) && I->src[1].type == AGX_INDEX_REGISTER) {
51 assert(I->dest[0].value != scratch.value);
52 assert(I->src[0].value != scratch.value);
53 assert(I->src[1].value != scratch.value);
54
55 /* Use scratch register for our input, then zero it at the end of
56 * the block so all inactive threads read zero.
57 */
58 b.cursor = agx_before_instr(I);
59 agx_mov_to(&b, scratch, I->src[1]);
60 needs_zero = true;
61
62 I->src[1] = scratch;
63 }
64 }
65
66 if (needs_zero) {
67 b.cursor = agx_after_block_logical(block);
68 agx_mov_imm_to(&b, scratch, 0);
69 }
70 }
71 }
72