/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir_builder.h"
#include "pipe/p_defines.h"
#include "util/bitset.h"
#include "util/u_dynarray.h"
#include "agx_nir_lower_gs.h"
#include "agx_state.h"
#include "nir.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
#include "shader_enums.h"

#define AGX_TEXTURE_DESC_STRIDE 24

/*
 * Lower all system values to uniform loads. This pass tries to compact ranges
 * of contiguous uploaded uniforms to reduce the draw-time overhead of uploading
 * many tiny ranges. To do so, it works in 4 steps:
 *
 * 1. Lower NIR sysvals to loads from the system value buffers.
 * 2. Walk the NIR, recording loads from system value buffers.
 * 3. Walk the ranges of uniforms needed, compacting into contiguous ranges.
 * 4. Fill in the load_preamble instructions with the real uniforms.
 */

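/* The per-table bookkeeping below is sized for the largest sysval table, so
 * every table that can be pushed (in particular the root draw table) must fit
 * in MAX_TABLE_SIZE bytes.
 */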
#define MAX_TABLE_SIZE sizeof(struct agx_stage_uniforms)
static_assert(sizeof(struct agx_draw_uniforms) <= MAX_TABLE_SIZE, "packed");

struct table_state {
   /* Bitset of 16-bit uniforms pushed */
   BITSET_DECLARE(pushed, MAX_TABLE_SIZE / 2);

   /* Element size in 16-bit units, so we may split ranges of different sizes
    * to guarantee natural alignment.
    */
   uint8_t element_size[MAX_TABLE_SIZE / 2];
};

struct state {
   gl_shader_stage stage, hw_stage;

   /* Array of nir_intrinsic_instr's to fix up at the end */
   struct util_dynarray loads;

   struct table_state tables[AGX_NUM_SYSVAL_TABLES];
};

static nir_def *
load_sysval(nir_builder *b, unsigned dim, unsigned bitsize, uint8_t table,
            uint16_t offset)
{
   return nir_load_sysval_agx(b, dim, bitsize, .desc_set = table,
                              .binding = offset);
}

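/* Sysvals in the root (draw) table are addressed by pointing into a NULL-based
 * struct agx_draw_uniforms, so the pointer value is simply the byte offset of
 * the member within the table.
 */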
static nir_def *
load_sysval_root(nir_builder *b, unsigned dim, unsigned bitsize, void *ptr)
{
   return load_sysval(b, dim, bitsize, AGX_SYSVAL_TABLE_ROOT, (uintptr_t)ptr);
}

static nir_def *
load_sysval_indirect(nir_builder *b, unsigned dim, unsigned bitsize,
                     uint8_t table, void *base, nir_def *offset_el)
{
   nir_scalar scalar = {offset_el, 0};
   unsigned stride = (dim * bitsize) / 8;

   if (nir_scalar_is_const(scalar)) {
      /* Load the sysval directly */
      return load_sysval(
         b, dim, bitsize, table,
         (uintptr_t)base + (nir_scalar_as_uint(scalar) * stride));
   } else {
      /* Load the base address of the table */
      struct agx_draw_uniforms *u = NULL;
      nir_def *table_base = load_sysval_root(b, 1, 64, &u->tables[table]);

      /* Load address of the array in the table */
      nir_def *array_base = nir_iadd_imm(b, table_base, (uintptr_t)base);

      /* Index into the table and load */
      nir_def *address = nir_iadd(
         b, array_base, nir_u2u64(b, nir_imul_imm(b, offset_el, stride)));
      return nir_load_global_constant(b, address, bitsize / 8, dim, bitsize);
   }
}

static unsigned
stage_table(nir_builder *b)
{
   gl_shader_stage stage = b->shader->info.stage;
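   /* vs.tes_agx marks a tessellation evaluation shader that runs as a hardware
    * vertex shader; it still uses the TES sysval table.
    */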
   if (stage == MESA_SHADER_VERTEX && b->shader->info.vs.tes_agx)
      stage = MESA_SHADER_TESS_EVAL;

   assert(stage < PIPE_SHADER_TYPES);
   return AGX_SYSVAL_STAGE(stage);
}

static nir_def *
load_ubo(nir_builder *b, nir_intrinsic_instr *intr, void *bases)
{
   nir_def *base =
      load_sysval_indirect(b, 1, 64, stage_table(b), bases, intr->src[0].ssa);

   nir_def *address = nir_iadd(b, base, nir_u2u64(b, intr->src[1].ssa));

   return nir_load_global_constant(b, address, nir_intrinsic_align(intr),
                                   intr->num_components, intr->def.bit_size);
}

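/* Texture handles are a vec2 of (32-bit uniform holding the texture table
 * base, byte offset of the descriptor within the table). The ~0 flags ask the
 * layout step to replace the sysval load with the uniform's location rather
 * than its value.
 */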
static nir_def *
load_texture_handle(nir_builder *b, nir_intrinsic_instr *intr, void *base)
{
   nir_def *uniform =
      nir_load_sysval_agx(b, 1, 64, .desc_set = stage_table(b),
                          .binding = (uintptr_t)base, .flags = ~0);

   return nir_vec2(
      b, nir_u2u32(b, uniform),
      nir_imul_imm(b, nir_u2u32(b, intr->src[0].ssa), AGX_TEXTURE_DESC_STRIDE));
}

static nir_def *
load_provoking_vtx(nir_builder *b)
{
   struct agx_draw_uniforms *u = NULL;
   return load_sysval_root(b, 1, 16, &u->provoking_vertex);
}

static nir_def *
lower_intrinsic(nir_builder *b, nir_intrinsic_instr *intr,
                bool lower_draw_params)
{
   struct agx_draw_uniforms *u = NULL;
   struct agx_stage_uniforms *s = NULL;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_ubo:
      return load_ubo(b, intr, s->ubo_base);
   case nir_intrinsic_load_texture_handle_agx:
      return load_texture_handle(b, intr, &s->texture_base);
   case nir_intrinsic_load_sampler_handle_agx:
      return load_sysval_indirect(b, 1, 16, stage_table(b), &s->sampler_handle,
                                  intr->src[0].ssa);
   case nir_intrinsic_load_vbo_base_agx:
      return load_sysval_indirect(b, 1, 64, AGX_SYSVAL_TABLE_ROOT,
                                  &u->attrib_base, intr->src[0].ssa);
   case nir_intrinsic_load_attrib_clamp_agx:
      return load_sysval_indirect(b, 1, 32, AGX_SYSVAL_TABLE_ROOT,
                                  &u->attrib_clamp, intr->src[0].ssa);
   case nir_intrinsic_load_blend_const_color_r_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[0]);
   case nir_intrinsic_load_blend_const_color_g_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[1]);
   case nir_intrinsic_load_blend_const_color_b_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[2]);
   case nir_intrinsic_load_blend_const_color_a_float:
      return load_sysval_root(b, 1, 32, &u->blend_constant[3]);
   case nir_intrinsic_load_api_sample_mask_agx:
      return load_sysval_root(b, 1, 16, &u->sample_mask);
   case nir_intrinsic_load_sample_positions_agx:
      return load_sysval_root(b, 1, 32, &u->ppp_multisamplectl);
   case nir_intrinsic_load_stat_query_address_agx:
      return load_sysval_root(
         b, 1, 64, &u->pipeline_statistics[nir_intrinsic_base(intr)]);
   case nir_intrinsic_load_ssbo_address:
      assert(nir_src_as_uint(intr->src[1]) == 0);
      return load_sysval_indirect(b, 1, 64, stage_table(b), &s->ssbo_base,
                                  intr->src[0].ssa);
   case nir_intrinsic_get_ubo_size:
      return load_sysval_indirect(b, 1, 32, stage_table(b), &s->ubo_size,
                                  intr->src[0].ssa);
   case nir_intrinsic_get_ssbo_size:
      return load_sysval_indirect(b, 1, 32, stage_table(b), &s->ssbo_size,
                                  intr->src[0].ssa);
   case nir_intrinsic_load_input_assembly_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->input_assembly);
   case nir_intrinsic_load_geometry_param_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->geometry_params);
   case nir_intrinsic_load_vs_output_buffer_agx:
      return nir_load_global_constant(
         b, load_sysval_root(b, 1, 64, &u->vertex_output_buffer_ptr), 8, 1, 64);
   case nir_intrinsic_load_vs_outputs_agx:
      return load_sysval_root(b, 1, 64, &u->vertex_outputs);
   case nir_intrinsic_load_tess_param_buffer_agx:
      return load_sysval_root(b, 1, 64, &u->tess_params);
   case nir_intrinsic_load_fixed_point_size_agx:
      return load_sysval_root(b, 1, 32, &u->fixed_point_size);
   case nir_intrinsic_load_tex_sprite_mask_agx:
      return load_sysval_root(b, 1, 16, &u->sprite_mask);
   case nir_intrinsic_load_shader_part_tests_zs_agx:
      return load_sysval_root(b, 1, 16, &u->no_epilog_discard);
   case nir_intrinsic_load_clip_z_coeff_agx:
      return nir_f2f32(b, load_sysval_root(b, 1, 16, &u->clip_z_coeff));
   case nir_intrinsic_load_depth_never_agx:
      /* TODO: Do we need this workaround for anything in GL? */
      return nir_imm_intN_t(b, 0, 16);
   case nir_intrinsic_load_uvs_index_agx:
      return load_sysval_root(
         b, 1, 16, &u->uvs_index[nir_intrinsic_io_semantics(intr).location]);
   case nir_intrinsic_load_is_first_fan_agx:
      return nir_ieq_imm(b, load_provoking_vtx(b), 1);
   case nir_intrinsic_load_provoking_last:
      return nir_b2b32(b, nir_ieq_imm(b, load_provoking_vtx(b), 2));
   default:
      break;
   }

   if (!lower_draw_params)
      return NULL;

   switch (intr->intrinsic) {
   case nir_intrinsic_load_num_workgroups:
      return load_sysval(b, 3, 32, AGX_SYSVAL_TABLE_GRID, 0);
   case nir_intrinsic_load_first_vertex:
      return load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 0);
   case nir_intrinsic_load_base_instance:
      return load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 4);
   case nir_intrinsic_load_base_vertex:
      /* first vertex if indexed, 0 otherwise. More efficient for our hw than
       * the lowering in NIR.
       */
      return nir_bcsel(
         b, nir_i2b(b, load_sysval_root(b, 1, 16, &u->is_indexed_draw)),
         load_sysval(b, 1, 32, AGX_SYSVAL_TABLE_PARAMS, 0), nir_imm_int(b, 0));
   case nir_intrinsic_load_draw_id:
      return load_sysval_root(b, 1, 32, &u->draw_id);
   default:
      return NULL;
   }
}

/* Step 1. Lower NIR sysvals */
static bool
lower_sysvals(nir_builder *b, nir_instr *instr, void *data)
{
   bool *lower_draw_params = data;
   b->cursor = nir_before_instr(instr);
   nir_def *old;
   nir_def *replacement = NULL;

   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
      old = &intr->def;
      replacement = lower_intrinsic(b, intr, *lower_draw_params);
   } else if (instr->type == nir_instr_type_tex) {
      nir_tex_instr *tex = nir_instr_as_tex(instr);
      old = &tex->def;

      if (tex->op != nir_texop_lod_bias_agx)
         return false;

      struct agx_stage_uniforms *s = NULL;

      int src_idx = nir_tex_instr_src_index(tex, nir_tex_src_texture_offset);
      if (src_idx >= 0) {
         replacement = load_sysval_indirect(
            b, 1, 16, stage_table(b), s->lod_bias, tex->src[src_idx].src.ssa);
      } else {
         replacement = load_sysval(b, 1, 16, stage_table(b),
                                   (uintptr_t)&s->lod_bias[tex->sampler_index]);
      }
   }

   if (replacement != NULL) {
      nir_def_rewrite_uses(old, replacement);
      return true;
   } else {
      return false;
   }
}

/* Step 2: Record system value loads */
static bool
record_loads(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_sysval_agx)
      return false;

   assert(intr->def.bit_size >= 16 && "no 8-bit sysvals");
   unsigned dim = intr->def.num_components;
   unsigned element_size = intr->def.bit_size / 16;
   unsigned length = dim * element_size;

   struct state *state = data;
   struct table_state *table = &state->tables[nir_intrinsic_desc_set(intr)];
   unsigned offset = nir_intrinsic_binding(intr);
   assert((offset % 2) == 0 && "all entries are aligned by ABI");

   BITSET_SET_RANGE(table->pushed, (offset / 2), (offset / 2) + length - 1);

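   /* All loads touching a given 16-bit slot must agree on element size, so the
    * layout step can split ranges on naturally aligned boundaries.
    */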
   for (unsigned i = 0; i < length; ++i) {
      if (table->element_size[(offset / 2) + i])
         assert((table->element_size[(offset / 2) + i]) == element_size);
      else
         table->element_size[(offset / 2) + i] = element_size;
   }

   util_dynarray_append(&state->loads, nir_intrinsic_instr *, intr);
   return false;
}

/* Step 3: Decide where to push the system values */
static struct agx_push_range *
find_push_range_containing(struct agx_compiled_shader *shader, uint8_t table,
                           uint16_t offset)
{
   for (unsigned i = 0; i < shader->push_range_count; ++i) {
      struct agx_push_range *range = &shader->push[i];

      if (range->table != table)
         continue;

      /* range->length is in 16-bit words and offset is in bytes, so convert */
      uint16_t length_B = range->length * 2;

      if (range->offset <= offset && offset < (range->offset + length_B))
         return range;
   }

   unreachable("no containing range");
}

static unsigned
lay_out_table(struct agx_compiled_shader *shader, struct table_state *state,
              unsigned table_index, unsigned uniform)
{
   unsigned start, end;
   BITSET_FOREACH_RANGE(start, end, state->pushed, sizeof(state->pushed) * 8) {
      unsigned range_start = start;

      do {
         uint8_t size = state->element_size[range_start];

         /* Find a range of constant element size. [range_start, range_end).
          * Ranges may be at most 64 halves.
          */
         unsigned range_end;
         for (range_end = range_start + 1;
              range_end < end && state->element_size[range_end] == size &&
              range_end < range_start + 64;
              ++range_end)
            ;

         /* Now make the range with the given size (naturally aligned) */
         uniform = ALIGN_POT(uniform, size);

         assert((shader->push_range_count < ARRAY_SIZE(shader->push)) &&
                "AGX_MAX_PUSH_RANGES must be an upper bound");

         /* Offsets must be aligned to 4 bytes; this may require pushing a
          * little more than intended (otherwise we would need extra copies).
          */
         range_start = ROUND_DOWN_TO(range_start, 4 / 2);

         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = uniform,
            .table = table_index,
            .offset = range_start * 2 /* bytes, not elements */,
            .length = (range_end - range_start),
         };

         uniform += (range_end - range_start);
         range_start = range_end;
      } while (range_start < end);
   }

   return uniform;
}

static unsigned
lay_out_uniforms(struct agx_compiled_shader *shader, struct state *state)
{
   unsigned uniform = 0;

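   /* Certain draw-time state is pushed at fixed uniform slots ahead of the
    * general table layout: attribute bases/clamps and draw parameters for
    * VS/TES, and the texture base, blend constant and root table address for
    * FS. The compacted sysval tables are laid out after these.
    */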
   if (state->stage == PIPE_SHADER_VERTEX ||
       state->stage == PIPE_SHADER_TESS_EVAL) {
      unsigned count =
         DIV_ROUND_UP(BITSET_LAST_BIT(shader->attrib_components_read), 4);

      struct agx_draw_uniforms *u = NULL;

      if (count) {
         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = 0,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->attrib_base,
            .length = 4 * count,
         };

         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = 4 * count,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->attrib_clamp,
            .length = 2 * count,
         };
      }

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 6 * count,
         .table = AGX_SYSVAL_TABLE_PARAMS,
         .offset = 0,
         .length = 4,
      };

      uniform = (6 * count) + 4;

      if (state->hw_stage == PIPE_SHADER_COMPUTE) {
         shader->push[shader->push_range_count++] = (struct agx_push_range){
            .uniform = (6 * count) + 8,
            .table = AGX_SYSVAL_TABLE_ROOT,
            .offset = (uintptr_t)&u->input_assembly,
            .length = 4,
         };

         uniform = (6 * count) + 12;
      }
   } else if (state->stage == PIPE_SHADER_FRAGMENT) {
      struct agx_draw_uniforms *u = NULL;
      struct agx_stage_uniforms *s = NULL;
      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 0,
         .table = AGX_SYSVAL_TABLE_FS,
         .offset = (uintptr_t)&s->texture_base,
         .length = 4,
      };

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 4,
         .table = AGX_SYSVAL_TABLE_ROOT,
         .offset = (uintptr_t)&u->blend_constant,
         .length = 8,
      };

      shader->push[shader->push_range_count++] = (struct agx_push_range){
         .uniform = 12,
         .table = AGX_SYSVAL_TABLE_ROOT,
         .offset = (uintptr_t)&u->tables[AGX_SYSVAL_TABLE_ROOT],
         .length = 4,
      };

      uniform = 16;
   }

   /* Lay out each system value table. We do this backwards to ensure the first
    * uniform goes to the bindless texture base.
    */
   for (int t = AGX_NUM_SYSVAL_TABLES - 1; t >= 0; --t)
      uniform = lay_out_table(shader, &state->tables[t], t, uniform);

   /* Step 4: Fill in the loads */
   util_dynarray_foreach(&state->loads, nir_intrinsic_instr *, intr_) {
      nir_intrinsic_instr *intr = *intr_;
      uint8_t table = nir_intrinsic_desc_set(intr);
      uint16_t offset = nir_intrinsic_binding(intr);
      bool load_uniform_location = nir_intrinsic_flags(intr);

      struct agx_push_range *range =
         find_push_range_containing(shader, table, offset);
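      /* Byte offset within the range, converted to 16-bit uniform words */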
      unsigned base = range->uniform + ((offset - range->offset) / 2);

      nir_builder b = nir_builder_at(nir_instr_remove(&(intr->instr)));
      nir_def *repl;

      if (load_uniform_location) {
         repl = nir_imm_int(&b, base);
      } else {
         repl = nir_load_preamble(&b, intr->def.num_components,
                                  intr->def.bit_size, .base = base);
      }

      nir_def_rewrite_uses(&intr->def, repl);
   }

   return uniform;
}

bool
agx_nir_lower_sysvals(nir_shader *shader, enum pipe_shader_type desc_stage,
                      bool lower_draw_params)
{
   /* Override the stage for the duration of the pass. XXX: should refactor,
    * but it's annoying!
    */
   enum pipe_shader_type phys_stage = shader->info.stage;
   shader->info.stage = desc_stage;

   bool progress = nir_shader_instructions_pass(
      shader, lower_sysvals, nir_metadata_control_flow, &lower_draw_params);

   shader->info.stage = phys_stage;
   return progress;
}

bool
agx_nir_layout_uniforms(nir_shader *shader,
                        struct agx_compiled_shader *compiled,
                        unsigned *push_size)
{
   struct state state = {
      .stage = compiled->stage,
      .hw_stage = shader->info.stage,
   };

   nir_shader_intrinsics_pass(shader, record_loads, nir_metadata_control_flow,
                              &state);

   *push_size = lay_out_uniforms(compiled, &state);

   util_dynarray_fini(&state.loads);

   /* Make sure texture handles have constants associated */
   nir_opt_constant_folding(shader);

   return true;
}