/*
 * Copyright 2023 Alyssa Rosenzweig
 * Copyright 2022 Collabora Ltd.
 * SPDX-License-Identifier: MIT
 */

/* Bottom-up local scheduler to reduce register pressure */

#include "util/dag.h"
#include "agx_compiler.h"
#include "agx_opcodes.h"

struct sched_ctx {
   /* Dependency graph */
   struct dag *dag;

   /* Live set */
   BITSET_WORD *live;
};

struct sched_node {
   struct dag_node dag;

   /* Instruction this node represents */
   agx_instr *instr;
};

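/* Record that a must come after b in the final program order. A NULL node
 * means there is nothing to order against.
 */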
static void
add_dep(struct sched_node *a, struct sched_node *b)
{
   assert(a != b && "no self-dependencies");

   if (a && b)
      dag_add_edge(&a->dag, &b->dag, 0);
}

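/* Order a after the current tail of a serialization chain and make a the new
 * tail, totally ordering every node threaded through the same chain.
 */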
static void
serialize(struct sched_node *a, struct sched_node **b)
{
   add_dep(a, *b);
   *b = a;
}

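/*
 * Build the dependency DAG for a block: SSA read-after-write dependencies,
 * plus serialization chains keeping memory, coverage, and preload side
 * effects in order. Control flow at the end of the block is left untouched.
 */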
static struct dag *
create_dag(agx_context *ctx, agx_block *block, void *memctx)
{
   struct dag *dag = dag_create(ctx);

   struct sched_node **last_write =
      calloc(ctx->alloc, sizeof(struct sched_node *));
   struct sched_node *coverage = NULL;
   struct sched_node *preload = NULL;

   /* Last memory load, to serialize stores against */
   struct sched_node *memory_load = NULL;

   /* Last memory store, to serialize loads and stores against */
   struct sched_node *memory_store = NULL;

   agx_foreach_instr_in_block(block, I) {
      /* Don't touch control flow */
      if (instr_after_logical_end(I))
         break;

      struct sched_node *node = rzalloc(memctx, struct sched_node);
      node->instr = I;
      dag_init_node(dag, &node->dag);

      /* Reads depend on writes, no other hazards in SSA */
      agx_foreach_ssa_src(I, s) {
         add_dep(node, last_write[I->src[s].value]);
      }

      agx_foreach_ssa_dest(I, d) {
         assert(I->dest[d].value < ctx->alloc);
         last_write[I->dest[d].value] = node;
      }

      /* Classify the instruction and add dependencies according to the class */
      enum agx_schedule_class dep = agx_opcodes_info[I->op].schedule_class;
      assert(dep != AGX_SCHEDULE_CLASS_INVALID && "invalid instruction seen");

      bool barrier = dep == AGX_SCHEDULE_CLASS_BARRIER;
      bool discards =
         I->op == AGX_OPCODE_SAMPLE_MASK || I->op == AGX_OPCODE_ZS_EMIT;

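      /* Order memory operations: a store must stay after prior loads and
       * stores, a load after prior stores, and atomics/barriers after both.
       * Independent loads may still reorder freely among themselves.
       */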
      if (dep == AGX_SCHEDULE_CLASS_STORE)
         add_dep(node, memory_load);
      else if (dep == AGX_SCHEDULE_CLASS_ATOMIC || barrier)
         serialize(node, &memory_load);

      if (dep == AGX_SCHEDULE_CLASS_LOAD || dep == AGX_SCHEDULE_CLASS_STORE ||
          dep == AGX_SCHEDULE_CLASS_ATOMIC || barrier)
         serialize(node, &memory_store);

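      /* Instructions that touch the coverage mask (and barriers) stay
       * mutually ordered.
       */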
      if (dep == AGX_SCHEDULE_CLASS_COVERAGE || barrier)
         serialize(node, &coverage);

      /* Make sure side effects happen before a discard */
      if (discards)
         add_dep(node, memory_store);

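      /* Preloads stay serialized at the top of the block; every other
       * instruction is ordered after the last preload.
       */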
      if (dep == AGX_SCHEDULE_CLASS_PRELOAD)
         serialize(node, &preload);
      else
         add_dep(node, preload);
   }

   free(last_write);

   return dag;
}

/*
 * Calculate the change in register pressure from scheduling a given
 * instruction. Equivalently, calculate the difference in the number of live
 * registers before and after the instruction, given the live set after the
 * instruction. This calculation follows immediately from the dataflow
 * definition of liveness:
 *
 *      live_in = (live_out - KILL) + GEN
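 *
 * Scheduling bottom-up, a destination that is live-out is KILLed above its
 * definition (subtracting its size) and a source not already live is GENed
 * (adding its size); sizes are counted in 16-bit register units.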
 */
static signed
calculate_pressure_delta(agx_instr *I, BITSET_WORD *live)
{
   signed delta = 0;

   /* Destinations must be unique */
   agx_foreach_ssa_dest(I, d) {
      if (BITSET_TEST(live, I->dest[d].value))
         delta -= agx_index_size_16(I->dest[d]);
   }

   agx_foreach_ssa_src(I, src) {
      /* Filter duplicates */
      bool dupe = false;

      for (unsigned i = 0; i < src; ++i) {
         if (agx_is_equiv(I->src[i], I->src[src])) {
            dupe = true;
            break;
         }
      }

      if (!dupe && !BITSET_TEST(live, I->src[src].value))
         delta += agx_index_size_16(I->src[src]);
   }

   return delta;
}

/*
 * Choose the next instruction, bottom-up. For now we use a simple greedy
 * heuristic: choose the instruction that has the best effect on liveness, while
 * hoisting sample_mask.
 */
static struct sched_node *
choose_instr(struct sched_ctx *s)
{
   int32_t min_delta = INT32_MAX;
   struct sched_node *best = NULL;

   list_for_each_entry(struct sched_node, n, &s->dag->heads, dag.link) {
      /* Heuristic: hoist sample_mask/zs_emit. This allows depth/stencil tests
       * to run earlier, and potentially to discard the entire quad invocation
       * earlier, reducing how much redundant fragment shader we run.
       *
       * Since we schedule backwards, we make that happen by only choosing
       * sample_mask when all other instructions have been exhausted.
       */
      if (n->instr->op == AGX_OPCODE_SAMPLE_MASK ||
          n->instr->op == AGX_OPCODE_ZS_EMIT) {

         if (!best) {
            best = n;
            assert(min_delta == INT32_MAX);
         }

         continue;
      }

      /* Heuristic: sink wait_pix to increase parallelism. Since wait_pix does
       * not read or write registers, this has no effect on pressure.
       */
      if (n->instr->op == AGX_OPCODE_WAIT_PIX)
         return n;

      int32_t delta = calculate_pressure_delta(n->instr, s->live);

      if (delta < min_delta) {
         best = n;
         min_delta = delta;
      }
   }

   return best;
}

static void
pressure_schedule_block(agx_context *ctx, agx_block *block, struct sched_ctx *s)
{
   /* off by a constant, that's ok */
   signed pressure = 0;
   signed orig_max_pressure = 0;
   unsigned nr_ins = 0;

   memcpy(s->live, block->live_out,
          BITSET_WORDS(ctx->alloc) * sizeof(BITSET_WORD));

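   /* First pass: walk the existing schedule bottom-up, updating liveness as
    * we go, to measure the maximum pressure of the block as-is.
    */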
   agx_foreach_instr_in_block_rev(block, I) {
      pressure += calculate_pressure_delta(I, s->live);
      orig_max_pressure = MAX2(pressure, orig_max_pressure);
      agx_liveness_ins_update(s->live, I);
      nr_ins++;
   }

   memcpy(s->live, block->live_out,
          BITSET_WORDS(ctx->alloc) * sizeof(BITSET_WORD));

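   /* Second pass: greedily build a new bottom-up schedule, tracking its
    * maximum pressure the same way so the two schedules are comparable.
    */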
   /* off by a constant, that's ok */
   signed max_pressure = 0;
   pressure = 0;

   struct sched_node **schedule = calloc(nr_ins, sizeof(struct sched_node *));
   nr_ins = 0;

   while (!list_is_empty(&s->dag->heads)) {
      struct sched_node *node = choose_instr(s);
      pressure += calculate_pressure_delta(node->instr, s->live);
      max_pressure = MAX2(pressure, max_pressure);
      dag_prune_head(s->dag, &node->dag);

      schedule[nr_ins++] = node;
      agx_liveness_ins_update(s->live, node->instr);
   }

   /* Bail if it looks like it's worse */
   if (max_pressure >= orig_max_pressure) {
      free(schedule);
      return;
   }

   /* Apply the schedule */
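   /* schedule[] is in bottom-up order, so prepending each instruction to the
    * block's list rebuilds top-down program order.
    */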
   for (unsigned i = 0; i < nr_ins; ++i) {
      agx_remove_instruction(schedule[i]->instr);
      list_add(&schedule[i]->instr->link, &block->instructions);
   }

   free(schedule);
}

void
agx_pressure_schedule(agx_context *ctx)
{
   agx_compute_liveness(ctx);
   void *memctx = ralloc_context(ctx);
   BITSET_WORD *live =
      ralloc_array(memctx, BITSET_WORD, BITSET_WORDS(ctx->alloc));

   agx_foreach_block(ctx, block) {
      struct sched_ctx sctx = {
         .dag = create_dag(ctx, block, memctx),
         .live = live,
      };

      pressure_schedule_block(ctx, block, &sctx);
   }

   /* Clean up after liveness analysis */
   agx_foreach_instr_global(ctx, I) {
      agx_foreach_ssa_src(I, s)
         I->src[s].kill = false;
   }

   ralloc_free(memctx);
}