xref: /aosp_15_r20/external/mesa3d/src/freedreno/ir3/ir3_opt_predicates.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2024 Igalia S.L.
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "ir3.h"
7 #include "ir3_shader.h"
8 
/* This pass tries to optimize away cmps.s.ne instructions created by
 * ir3_get_predicate in order to write predicates. It does two things:
 *  - Look through chains of multiple cmps.s.ne instructions and remove all but
 *    the first.
 *  - If the source of the cmps.s.ne can write directly to predicates (true for
 *    bitops on a6xx+), remove the cmps.s.ne.
 *
 * In both cases, no instructions are actually removed; instead, clones are
 * made and we rely on DCE to remove anything that became unused. Note that
 * it's fine to always make a clone: even when the original instruction is
 * also used for non-predicate sources (so it won't be DCE'd), we replaced a
 * cmps.s.ne with another instruction, so this pass should never increase
 * instruction count.
 */

23 
/* State shared by all helpers of the predicate optimization pass. */
struct opt_predicates_ctx {
   /* The IR being processed; its block list is walked by the pass and its
    * compiler description is consulted for hardware capabilities.
    */
   struct ir3 *ir;

   /* Cache keyed by an original instruction; the value is the clone of that
    * instruction whose destination was flagged as a predicate. Consulted
    * before cloning so that each instruction is cloned at most once.
    */
   struct hash_table *predicate_clones;
};
32 
33 static struct ir3_instruction *
clone_with_predicate_dst(struct opt_predicates_ctx * ctx,struct ir3_instruction * instr)34 clone_with_predicate_dst(struct opt_predicates_ctx *ctx,
35                          struct ir3_instruction *instr)
36 {
37    struct hash_entry *entry =
38       _mesa_hash_table_search(ctx->predicate_clones, instr);
39    if (entry)
40       return entry->data;
41 
42    assert(instr->dsts_count == 1);
43 
44    struct ir3_instruction *clone = ir3_instr_clone(instr);
45    ir3_instr_move_after(clone, instr);
46    clone->dsts[0]->flags |= IR3_REG_PREDICATE;
47    clone->dsts[0]->flags &= ~(IR3_REG_HALF | IR3_REG_SHARED);
48    _mesa_hash_table_insert(ctx->predicate_clones, instr, clone);
49    return clone;
50 }
51 
52 static bool
is_shared_or_const(struct ir3_register * reg)53 is_shared_or_const(struct ir3_register *reg)
54 {
55    return reg->flags & (IR3_REG_CONST | IR3_REG_SHARED);
56 }
57 
58 static bool
cat2_needs_scalar_alu(struct ir3_instruction * instr)59 cat2_needs_scalar_alu(struct ir3_instruction *instr)
60 {
61    return is_shared_or_const(instr->srcs[0]) &&
62           (instr->srcs_count == 1 || is_shared_or_const(instr->srcs[1]));
63 }
64 
65 static bool
can_write_predicate(struct opt_predicates_ctx * ctx,struct ir3_instruction * instr)66 can_write_predicate(struct opt_predicates_ctx *ctx,
67                     struct ir3_instruction *instr)
68 {
69    switch (instr->opc) {
70    case OPC_CMPS_S:
71    case OPC_CMPS_U:
72    case OPC_CMPS_F:
73       return !cat2_needs_scalar_alu(instr);
74    case OPC_AND_B:
75    case OPC_OR_B:
76    case OPC_NOT_B:
77    case OPC_XOR_B:
78    case OPC_GETBIT_B:
79       return ctx->ir->compiler->bitops_can_write_predicates &&
80              !cat2_needs_scalar_alu(instr);
81    default:
82       return false;
83    }
84 }
85 
86 /* Detects the pattern used by ir3_get_predicate to write a predicate register:
87  * cmps.s.ne pssa_x, ssa_y, 0
88  */
89 static bool
is_gpr_to_predicate_mov(struct ir3_instruction * instr)90 is_gpr_to_predicate_mov(struct ir3_instruction *instr)
91 {
92    return (instr->opc == OPC_CMPS_S) &&
93           (instr->cat2.condition == IR3_COND_NE) &&
94           (instr->srcs[0]->flags & IR3_REG_SSA) &&
95           (instr->srcs[1]->flags & IR3_REG_IMMED) &&
96           (instr->srcs[1]->iim_val == 0);
97 }
98 
99 /* Look through a chain of cmps.s.ne 0 instructions to find the initial source.
100  * Return it if it can write to predicates. Otherwise, return the first
101  * cmps.s.ne in the chain.
102  */
103 static struct ir3_register *
resolve_predicate_def(struct opt_predicates_ctx * ctx,struct ir3_register * src)104 resolve_predicate_def(struct opt_predicates_ctx *ctx, struct ir3_register *src)
105 {
106    struct ir3_register *def = src->def;
107 
108    while (is_gpr_to_predicate_mov(def->instr)) {
109       struct ir3_register *next_def = def->instr->srcs[0]->def;
110 
111       if (!can_write_predicate(ctx, next_def->instr))
112          return def;
113 
114       def = next_def;
115    }
116 
117    return def;
118 }
119 
120 /* Find all predicate sources and try to replace their defs with instructions
121  * that can directly write to predicates.
122  */
123 static bool
opt_instr(struct opt_predicates_ctx * ctx,struct ir3_instruction * instr)124 opt_instr(struct opt_predicates_ctx *ctx, struct ir3_instruction *instr)
125 {
126    bool progress = false;
127 
128    foreach_src (src, instr) {
129       if (!(src->flags & IR3_REG_PREDICATE))
130          continue;
131 
132       struct ir3_register *def = resolve_predicate_def(ctx, src);
133 
134       if (src->def == def)
135          continue;
136 
137       assert(can_write_predicate(ctx, def->instr) &&
138              !(def->flags & IR3_REG_PREDICATE));
139 
140       struct ir3_instruction *predicate =
141          clone_with_predicate_dst(ctx, def->instr);
142       assert(predicate->dsts_count == 1);
143 
144       src->def = predicate->dsts[0];
145       progress = true;
146    }
147 
148    return progress;
149 }
150 
151 static bool
opt_blocks(struct opt_predicates_ctx * ctx)152 opt_blocks(struct opt_predicates_ctx *ctx)
153 {
154    bool progress = false;
155 
156    foreach_block (block, &ctx->ir->block_list) {
157       foreach_instr (instr, &block->instr_list) {
158          progress |= opt_instr(ctx, instr);
159       }
160    }
161 
162    return progress;
163 }
164 
165 bool
ir3_opt_predicates(struct ir3 * ir,struct ir3_shader_variant * v)166 ir3_opt_predicates(struct ir3 *ir, struct ir3_shader_variant *v)
167 {
168    struct opt_predicates_ctx *ctx = rzalloc(NULL, struct opt_predicates_ctx);
169    ctx->ir = ir;
170    ctx->predicate_clones = _mesa_pointer_hash_table_create(ctx);
171 
172    bool progress = opt_blocks(ctx);
173 
174    ralloc_free(ctx);
175    return progress;
176 }
177