/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir_builder.h"
#include "pipe/p_defines.h"
#include "util/bitset.h"
#include "util/u_dynarray.h"
#include "agx_nir_lower_gs.h"
#include "agx_state.h"
#include "nir.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
#include "shader_enums.h"

#define AGX_TEXTURE_DESC_STRIDE 24

/*
 * Lower all system values to uniform loads. This pass tries to compact ranges
 * of contiguous uploaded uniforms to reduce the draw-time overhead of uploading
 * many tiny ranges. To do so, it works in 4 steps:
 *
 * 1. Lower NIR sysvals to loads from the system value buffers.
 * 2. Walk the NIR, recording loads from system value buffers.
 * 3. Walk the ranges of uniforms needed, compacting into contiguous ranges.
 * 4. Fill in the load_preamble instructions with the real uniforms.
 */

#define MAX_TABLE_SIZE sizeof(struct agx_stage_uniforms)
static_assert(sizeof(struct agx_draw_uniforms) <= MAX_TABLE_SIZE, "packed");

struct table_state {
   /* Bitset of 16-bit uniforms pushed */
   BITSET_DECLARE(pushed, MAX_TABLE_SIZE / 2);

   /* Element size in 16-bit units, so we may split ranges of different sizes
    * to guarantee natural alignment.
    */
   uint8_t element_size[MAX_TABLE_SIZE / 2];
};

struct state {
   gl_shader_stage stage, hw_stage;

   /* Array of nir_intrinsic_instr's to fix up at the end */
   struct util_dynarray loads;

   struct table_state tables[AGX_NUM_SYSVAL_TABLES];
};

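/* Sysval loads are tracked as load_sysval_agx intrinsics until step 4. The
 * intrinsic's desc_set index carries the sysval table and its binding index
 * carries the byte offset within that table.
 */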
static nir_def *
load_sysval(nir_builder *b, unsigned dim, unsigned bitsize, uint8_t table,
            uint16_t offset)
{
   return nir_load_sysval_agx(b, dim, bitsize, .desc_set = table,
                              .binding = offset);
}

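/* Sysval "pointers" below are byte offsets into the root/stage uniform
 * structs, computed by taking member addresses off a NULL struct pointer
 * (morally offsetof). They are never dereferenced.
 */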
static nir_def *
load_sysval_root(nir_builder *b, unsigned dim, unsigned bitsize, void *ptr)
{
   return load_sysval(b, dim, bitsize, AGX_SYSVAL_TABLE_ROOT, (uintptr_t)ptr);
}

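/* Load an element from a sysval array. A constant index becomes a direct
 * sysval (uniform) load; a dynamic index instead chases the table's GPU base
 * address and loads from global memory.
 */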
static nir_def *
load_sysval_indirect(nir_builder *b, unsigned dim, unsigned bitsize,
                     uint8_t table, void *base, nir_def *offset_el)
{
   nir_scalar scalar = {offset_el, 0};
   unsigned stride = (dim * bitsize) / 8;

   if (nir_scalar_is_const(scalar)) {
      /* Load the sysval directly */
      return load_sysval(
         b, dim, bitsize, table,
         (uintptr_t)base + (nir_scalar_as_uint(scalar) * stride));
   } else {
      /* Load the base address of the table */
      struct agx_draw_uniforms *u = NULL;
      nir_def *table_base = load_sysval_root(b, 1, 64, &u->tables[table]);

      /* Load address of the array in the table */
      nir_def *array_base = nir_iadd_imm(b, table_base, (uintptr_t)base);

      /* Index into the table and load */
      nir_def *address = nir_iadd(
         b, array_base, nir_u2u64(b, nir_imul_imm(b, offset_el, stride)));
      return nir_load_global_constant(b, address, bitsize / 8, dim, bitsize);
   }
}

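/* A vertex shader with info.vs.tes_agx set is really a lowered tessellation
 * evaluation shader, so its sysvals live in the TES stage table.
 */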
static unsigned
stage_table(nir_builder *b)
{
   gl_shader_stage stage = b->shader->info.stage;
   if (stage == MESA_SHADER_VERTEX && b->shader->info.vs.tes_agx)
      stage = MESA_SHADER_TESS_EVAL;

   assert(stage < PIPE_SHADER_TYPES);
   return AGX_SYSVAL_STAGE(stage);
}

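/* UBO loads become global constant loads from the buffer's GPU address, which
 * is itself fetched from the stage's table of UBO base addresses.
 */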
static nir_def *
load_ubo(nir_builder *b, nir_intrinsic_instr *intr, void *bases)
{
   nir_def *base =
      load_sysval_indirect(b, 1, 64, stage_table(b), bases, intr->src[0].ssa);

   nir_def *address = nir_iadd(b, base, nir_u2u64(b, intr->src[1].ssa));

   return nir_load_global_constant(b, address, nir_intrinsic_align(intr),
                                   intr->num_components, intr->def.bit_size);
}

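/* Texture handles are (uniform index, byte offset) pairs. Setting .flags on
 * the sysval load makes step 4 substitute the uniform index of texture_base
 * itself rather than a load of its contents.
 */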
static nir_def *
load_texture_handle(nir_builder *b, nir_intrinsic_instr *intr, void *base)
{
   nir_def *uniform =
      nir_load_sysval_agx(b, 1, 64, .desc_set = stage_table(b),
                          .binding = (uintptr_t)base, .flags = ~0);

   return nir_vec2(
      b, nir_u2u32(b, uniform),
      nir_imul_imm(b, nir_u2u32(b, intr->src[0].ssa), AGX_TEXTURE_DESC_STRIDE));
}

static nir_def *
load_provoking_vtx(nir_builder *b)
{
   struct agx_draw_uniforms *u = NULL;
   return load_sysval_root(b, 1, 16, &u->provoking_vertex);
}

static nir_def *
lower_intrinsic(nir_builder *b, nir_intrinsic_instr *intr,
                bool lower_draw_params)
{
   struct agx_draw_uniforms *u = NULL;
   struct agx_stage_uniforms *s = NULL;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      return load_ubo(b, intr, s->ubo_base);
   case nir_intrinsic_load_texture_handle_agx:
      return load_texture_handle(b, intr, &s->texture_base);
   case nir_intrinsic_load_sampler_handle_agx:
      return load_sysval_indirect(b, 1, 16, stage_table(b), &s->sampler_handle,
                                  intr->src[0].ssa);
   case nir_intrinsic_load_vbo_base_agx:
      return load_sysval_indirect(b, 1, 64, AGX_SYSVAL_TABLE_ROOT,
                                  &u->attrib_base, intr->src[0].ssa);
   case nir_intrinsic_load_attrib_clamp_agx:
      return load_sysval_indirect(b, 1, 32, AGX_SYSVAL_TABLE_ROOT,
                                  &u->attrib_clamp, intr->src[0].ssa);
   case nir_intrinsic_load_blend_const_color_r_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[0]);
   case nir_intrinsic_load_blend_const_color_g_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[1]);
   case nir_intrinsic_load_blend_const_color_b_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[2]);
   case nir_intrinsic_load_blend_const_color_a_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[3]);
   case nir_intrinsic_load_api_sample_mask_agx:
      return load_sysval_root(b, 1, 16, &u->sample_mask);
   case nir_intrinsic_load_sample_positions_agx:
      return load_sysval_root(b, 1, 32, &u->ppp_multisamplectl);
   case nir_intrinsic_load_stat_query_address_agx:
      return load_sysval_root(
         b, 1, 64, &u->pipeline_statistics[nir_intrinsic_base(intr)]);
   case nir_intrinsic_load_ssbo_address:
      assert(nir_src_as_uint(intr->src[1]) == 0);
      return load_sysval_indirect(b, 1, 64, stage_table(b), &s->ssbo_base,
                                  intr->src[0].ssa);
   case nir_intrinsic_get_ubo_size:
      return load_sysval_indirect(b, 1, 32, stage_table(b), &s->ubo_size,
                                  intr->src[0].ssa);
   case nir_intrinsic_get_ssbo_size:
      return load_sysval_indirect(b, 1, 32, stage_table(b), &s->ssbo_size,
                                  intr->src[0].ssa);
   case nir_intrinsic_load_input_assembly_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->input_assembly);
   case nir_intrinsic_load_geometry_param_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->geometry_params);
   case nir_intrinsic_load_vs_output_buffer_agx:
      return nir_load_global_constant(
         b, load_sysval_root(b, 1, 64, &u->vertex_output_buffer_ptr), 8, 1, 64);
   case nir_intrinsic_load_vs_outputs_agx:
      return load_sysval_root(b, 1, 64, &u->vertex_outputs);
   case nir_intrinsic_load_tess_param_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->tess_params);
   case nir_intrinsic_load_fixed_point_size_agx:
      return load_sysval_root(b, 1, 32, &u->fixed_point_size);
   case nir_intrinsic_load_tex_sprite_mask_agx:
      return load_sysval_root(b, 1, 16, &u->sprite_mask);
   case nir_intrinsic_load_shader_part_tests_zs_agx:
      return load_sysval_root(b, 1, 16, &u->no_epilog_discard);
   case nir_intrinsic_load_clip_z_coeff_agx:
      return nir_f2f32(b, load_sysval_root(b, 1, 16, &u->clip_z_coeff));
   case nir_intrinsic_load_depth_never_agx:
      /* TODO: Do we need this workaround for anything in GL? */
      return nir_imm_intN_t(b, 0, 16);
   case nir_intrinsic_load_uvs_index_agx:
      return load_sysval_root(
         b, 1, 16, &u->uvs_index[nir_intrinsic_io_semantics(intr).location]);
   case nir_intrinsic_load_is_first_fan_agx:
      return nir_ieq_imm(b, load_provoking_vtx(b), 1);
   case nir_intrinsic_load_provoking_last:
      return nir_b2b32(b, nir_ieq_imm(b, load_provoking_vtx(b), 2));
   default:
      break;
   }

   if (!lower_draw_params)
      return NULL;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_num_workgroups:
      return load_sysval(b, 3, 32, AGX_SYSVAL_TABLE_GRID, 0);
   case nir_intrinsic_load_first_vertex:
      return load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 0);
   case nir_intrinsic_load_base_instance:
      return load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 4);
   case nir_intrinsic_load_base_vertex:
      /* first vertex if indexed, 0 otherwise. More efficient for our hw than
       * the lowering in NIR.
       */
      return nir_bcsel(
         b, nir_i2b(b, load_sysval_root(b, 1, 16, &u->is_indexed_draw)),
         load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 0), nir_imm_int(b, 0));
   case nir_intrinsic_load_draw_id:
      return load_sysval_root(b, 1, 32, &u->draw_id);
   default:
      return NULL;
   }
}

/* Step 1. Lower NIR sysvals */
static bool
lower_sysvals(nir_builder *b, nir_instr *instr, void *data)
{
   bool *lower_draw_params = data;
   b->cursor = nir_before_instr(instr);
   nir_def *old;
   nir_def *replacement = NULL;

   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
      old = &intr->def;
      replacement = lower_intrinsic(b, intr, *lower_draw_params);
   } else if (instr->type == nir_instr_type_tex) {
      nir_tex_instr *tex = nir_instr_as_tex(instr);
      old = &tex->def;

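      /* The only texture op lowered here is the LOD bias query, which reads
       * the sampler's entry in the stage's lod_bias array.
       */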
      if (tex->op != nir_texop_lod_bias_agx)
         return false;

      struct agx_stage_uniforms *s = NULL;

      int src_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_offset);
      if (src_idx >= 0) {
         replacement = load_sysval_indirect(
            b, 1, 16, stage_table(b), s->lod_bias, tex->src[src_idx].src.ssa);
      } else {
         replacement = load_sysval(b, 1, 16, stage_table(b),
                                   (uintptr_t)&s->lod_bias[tex->sampler_index]);
      }
   }

   if (replacement != NULL) {
      nir_def_rewrite_uses(old, replacement);
      return true;
   } else {
      return false;
   }
}

/* Step 2: Record system value loads */
static bool
record_loads(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_sysval_agx)
      return false;

   assert(intr->def.bit_size >= 16 && "no 8-bit sysvals");
   unsigned dim = intr->def.num_components;
   unsigned element_size = intr->def.bit_size / 16;
   unsigned length = dim * element_size;

   struct state *state = data;
   struct table_state *table = &state->tables[nir_intrinsic_desc_set(intr)];
   unsigned offset = nir_intrinsic_binding(intr);
   assert((offset % 2) == 0 && "all entries are aligned by ABI");
   BITSET_SET_RANGE(table->pushed, (offset / 2), (offset / 2) + length - 1);

   for (unsigned i = 0; i < length; ++i) {
      if (table->element_size[(offset / 2) + i])
         assert((table->element_size[(offset / 2) + i]) == element_size);
      else
         table->element_size[(offset / 2) + i] = element_size;
   }

   util_dynarray_append(&state->loads, nir_intrinsic_instr *, intr);
   return false;
}

/* Step 3: Decide where to push the system values */
static struct agx_push_range *
find_push_range_containing(struct agx_compiled_shader *shader, uint8_t table,
                           uint16_t offset)
{
   for (unsigned i = 0; i < shader->push_range_count; ++i) {
      struct agx_push_range *range = &shader->push[i];

      if (range->table != table)
         continue;

      /* range->length is 16-bit words, need to convert. offset is bytes. */
      uint16_t length_B = range->length * 2;

      if (range->offset <= offset && offset < (range->offset + length_B))
         return range;
   }

   unreachable("no containing range");
}

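/* Lay out the pushed ranges of a single table, splitting whenever the element
 * size changes so every range can be pushed with natural alignment. Returns
 * the next free uniform slot.
 */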
static unsigned
lay_out_table(struct agx_compiled_shader *shader, struct table_state *state,
              unsigned table_index, unsigned uniform)
{
   unsigned start, end;
   BITSET_FOREACH_RANGE(start, end, state->pushed, sizeof(state->pushed) * 8) {
      unsigned range_start = start;

      do {
         uint8_t size = state->element_size[range_start];

         /* Find a range of constant element size. [range_start, range_end).
          * Ranges may be at most 64 halves.
          */
         unsigned range_end;
         for (range_end = range_start + 1;
              range_end < end && state->element_size[range_end] == size &&
              range_end < range_start + 64;
              ++range_end)
            ;

         /* Now make the range with the given size (naturally aligned) */
         uniform = ALIGN_POT(uniform, size);

         assert((shader->push_range_count < ARRAY_SIZE(shader->push)) &&
                "AGX_MAX_PUSH_RANGES must be an upper bound");

         /* Offsets must be aligned to 4 bytes; this may require pushing a
          * little more than intended (otherwise we would need extra copies).
          */
         range_start = ROUND_DOWN_TO(range_start, 4 / 2);

         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = uniform,
            .table = table_index,
            .offset = range_start * 2 /* bytes, not elements */,
            .length = (range_end - range_start),
         };

         uniform += (range_end - range_start);
         range_start = range_end;
      } while (range_start < end);
   }

   return uniform;
}

static unsigned
lay_out_uniforms(struct agx_compiled_shader *shader, struct state *state)
{
   unsigned uniform = 0;

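   /* Graphics stages push a few ranges at fixed uniform indices up front,
    * presumably so shader parts built separately from this shader
    * (prologs/epilogs) can find them without knowing its layout.
    */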
   if (state->stage == PIPE_SHADER_VERTEX ||
       state->stage == PIPE_SHADER_TESS_EVAL) {
      unsigned count =
         DIV_ROUND_UP(BITSET_LAST_BIT(shader->attrib_components_read), 4);

      struct agx_draw_uniforms *u = NULL;

      if (count) {
         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = 0,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->attrib_base,
            .length = 4 * count,
         };

         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = 4 * count,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->attrib_clamp,
            .length = 2 * count,
         };
      }

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 6 * count,
         .table = AGX_SYSVAL_TABLE_PARAMS,
         .offset = 0,
         .length = 4,
      };

      uniform = (6 * count) + 4;

      if (state->hw_stage == PIPE_SHADER_COMPUTE) {
         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = (6 * count) + 8,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->input_assembly,
            .length = 4,
         };

         uniform = (6 * count) + 12;
      }
   } else if (state->stage == PIPE_SHADER_FRAGMENT) {
      struct agx_draw_uniforms *u = NULL;
      struct agx_stage_uniforms *s = NULL;
      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 0,
         .table = AGX_SYSVAL_TABLE_FS,
         .offset = (uintptr_t)&s->texture_base,
         .length = 4,
      };

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 4,
         .table = AGX_SYSVAL_TABLE_ROOT,
         .offset = (uintptr_t)&u->blend_constant,
         .length = 8,
      };

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 12,
         .table = AGX_SYSVAL_TABLE_ROOT,
         .offset = (uintptr_t)&u->tables[AGX_SYSVAL_TABLE_ROOT],
         .length = 4,
      };

      uniform = 16;
   }

   /* Lay out each system value table. We do this backwards to ensure the first
    * uniform goes to the bindless texture base.
    */
   for (int t = AGX_NUM_SYSVAL_TABLES - 1; t >= 0; --t)
      uniform = lay_out_table(shader, &state->tables[t], t, uniform);

   /* Step 4: Fill in the loads */
   util_dynarray_foreach(&state->loads, nir_intrinsic_instr *, intr_) {
      nir_intrinsic_instr *intr = *intr_;
      uint8_t table = nir_intrinsic_desc_set(intr);
      uint16_t offset = nir_intrinsic_binding(intr);
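      /* A set flags bit (see load_texture_handle) requests the uniform index
       * itself rather than a load of the uniform's contents.
       */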
      bool load_uniform_location = nir_intrinsic_flags(intr);

      struct agx_push_range *range =
         find_push_range_containing(shader, table, offset);
      unsigned base = range->uniform + ((offset - range->offset) / 2);

      nir_builder b = nir_builder_at(nir_instr_remove(&(intr->instr)));
      nir_def *repl;

      if (load_uniform_location) {
         repl = nir_imm_int(&b, base);
      } else {
         repl = nir_load_preamble(&b, intr->def.num_components,
                                  intr->def.bit_size, .base = base);
      }

      nir_def_rewrite_uses(&intr->def, repl);
   }

   return uniform;
}

bool
agx_nir_lower_sysvals(nir_shader *shader, enum pipe_shader_type desc_stage,
                      bool lower_draw_params)
{
   /* Override the stage for the duration of the pass. XXX: should refactor,
    * but it's annoying!
    */
   enum pipe_shader_type phys_stage = shader->info.stage;
   shader->info.stage = desc_stage;

   bool progress = nir_shader_instructions_pass(
      shader, lower_sysvals, nir_metadata_control_flow, &lower_draw_params);

   shader->info.stage = phys_stage;
   return progress;
}

bool
agx_nir_layout_uniforms(nir_shader *shader,
                        struct agx_compiled_shader *compiled,
                        unsigned *push_size)
{
   struct state state = {
      .stage = compiled->stage,
      .hw_stage = shader->info.stage,
   };

   nir_shader_intrinsics_pass(shader, record_loads, nir_metadata_control_flow,
                              &state);

   *push_size = lay_out_uniforms(compiled, &state);

   util_dynarray_fini(&state.loads);

   /* Make sure texture handles have constants associated */
   nir_opt_constant_folding(shader);

   return true;
}