/*
 * Copyright © 2021 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "ac_nir_helpers.h"
#include "nir_builder.h"

/*
 * Lower NIR cross-stage I/O intrinsics into the memory accesses that actually happen on the HW.
 *
 * These HW stages are only used when a Geometry Shader is present.
 * The Export Shader (ES) is the HW stage that runs the SW stage before the GS;
 * it can be either a VS or a TES.
 *
 * * GFX6-8:
 *   ES and GS are separate HW stages.
 *   I/O is passed between them through VRAM.
 * * GFX9+:
 *   ES and GS are merged into a single HW stage.
 *   I/O is passed between them through LDS.
 */

typedef struct {
   /* Which hardware generation we're dealing with */
   enum amd_gfx_level gfx_level;

   /* I/O semantic -> real location used by lowering. */
   ac_nir_map_io_driver_location map_io;

   /* Stride of an ES invocation's outputs in the ESGS ring, in bytes. */
   unsigned esgs_itemsize;

   /* Enable the fix for triangle strip adjacency in the geometry shader. */
   bool gs_triangle_strip_adjacency_fix;

   /* Bit mask of inputs read by the GS;
    * this is used for linking ES outputs to GS inputs.
    */
   uint64_t gs_inputs_read;
} lower_esgs_io_state;

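/* Load num_components x bit_size bits from a buffer, split into 32-bit loads plus
 * one smaller load for any remaining bytes. Consecutive dwords of the value are
 * component_stride bytes apart; the GS input path uses this to read the ESGS ring
 * on GFX6-8.
 */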
static nir_def *
emit_split_buffer_load(nir_builder *b, unsigned num_components, unsigned bit_size,
                       unsigned component_stride, nir_def *desc, nir_def *v_off, nir_def *s_off)
{
   unsigned total_bytes = num_components * bit_size / 8u;
   unsigned full_dwords = total_bytes / 4u;
   unsigned remaining_bytes = total_bytes - full_dwords * 4u;

   /* Accommodate max number of split 64-bit loads */
   nir_def *comps[NIR_MAX_VEC_COMPONENTS * 2u];

   /* Assume that 1x32-bit load is better than 1x16-bit + 1x8-bit */
   if (remaining_bytes == 3) {
      remaining_bytes = 0;
      full_dwords++;
   }

   nir_def *zero = nir_imm_int(b, 0);

   for (unsigned i = 0; i < full_dwords; ++i)
      comps[i] = nir_load_buffer_amd(b, 1, 32, desc, v_off, s_off, zero,
                                     .base = component_stride * i, .memory_modes = nir_var_shader_in,
                                     .access = ACCESS_COHERENT);

   if (remaining_bytes)
      comps[full_dwords] = nir_load_buffer_amd(b, 1, remaining_bytes * 8, desc, v_off, s_off, zero,
                                               .base = component_stride * full_dwords,
                                               .memory_modes = nir_var_shader_in,
                                               .access = ACCESS_COHERENT);

   return nir_extract_bits(b, comps, full_dwords + !!remaining_bytes, 0, num_components, bit_size);
}

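/* Store a value to a buffer, splitting each enabled writemask range into stores of
 * at most 4 bytes that never cross a 4-byte boundary. Used for writing ES outputs
 * to the ESGS ring in VRAM on GFX6-8.
 */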
static void
emit_split_buffer_store(nir_builder *b, nir_def *d, nir_def *desc, nir_def *v_off, nir_def *s_off,
                        unsigned bit_size, unsigned const_offset, unsigned writemask, bool swizzled, bool slc)
{
   nir_def *zero = nir_imm_int(b, 0);

   while (writemask) {
      int start, count;
      u_bit_scan_consecutive_range(&writemask, &start, &count);
      assert(start >= 0 && count >= 0);

      unsigned bytes = count * bit_size / 8u;
      unsigned start_byte = start * bit_size / 8u;

      while (bytes) {
         unsigned store_bytes = MIN2(bytes, 4u);
         if ((start_byte % 4) == 1 || (start_byte % 4) == 3)
            store_bytes = MIN2(store_bytes, 1);
         else if ((start_byte % 4) == 2)
            store_bytes = MIN2(store_bytes, 2);

         nir_def *store_val = nir_extract_bits(b, &d, 1, start_byte * 8u, 1, store_bytes * 8u);
         nir_store_buffer_amd(b, store_val, desc, v_off, s_off, zero,
                              .base = start_byte + const_offset, .memory_modes = nir_var_shader_out,
                              .access = ACCESS_COHERENT |
                                        (slc ? ACCESS_NON_TEMPORAL : 0) |
                                        (swizzled ? ACCESS_IS_SWIZZLED_AMD : 0));

         start_byte += store_bytes;
         bytes -= store_bytes;
      }
   }
}

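/* Lower store_output in the ES:
 * - drop writes to outputs the GS never reads (and to gl_Layer / gl_ViewportIndex),
 * - GFX6-8: store the value to the ESGS ring buffer in VRAM,
 * - GFX9+: store the value to LDS at the current invocation's slot.
 */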
static bool
lower_es_output_store(nir_builder *b,
                      nir_intrinsic_instr *intrin,
                      void *state)
{
   if (intrin->intrinsic != nir_intrinsic_store_output)
      return false;

   /* The ARB_shader_viewport_layer_array spec contains the
    * following issue:
    *
    *    2) What happens if gl_ViewportIndex or gl_Layer is
    *    written in the vertex shader and a geometry shader is
    *    present?
    *
    *    RESOLVED: The value written by the last vertex processing
    *    stage is used. If the last vertex processing stage
    *    (vertex, tessellation evaluation or geometry) does not
    *    statically assign to gl_ViewportIndex or gl_Layer, index
    *    or layer zero is assumed.
    *
    * Vulkan spec 15.7 Built-In Variables:
    *
    *   The last active pre-rasterization shader stage (in pipeline order)
    *   controls the Layer that is used. Outputs in previous shader stages
    *   are not used, even if the last stage fails to write the Layer.
    *
    *   The last active pre-rasterization shader stage (in pipeline order)
    *   controls the ViewportIndex that is used. Outputs in previous shader
    *   stages are not used, even if the last stage fails to write the
    *   ViewportIndex.
    *
    * So writes to those outputs in ES are simply ignored.
    */
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   if (io_sem.location == VARYING_SLOT_LAYER || io_sem.location == VARYING_SLOT_VIEWPORT) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   lower_esgs_io_state *st = (lower_esgs_io_state *) state;

   /* When an ES output isn't read by the GS, don't emit anything. */
   if (io_sem.no_varying || !(st->gs_inputs_read & BITFIELD64_BIT(io_sem.location))) {
      nir_instr_remove(&intrin->instr);
      return true;
   }

   const unsigned write_mask = nir_intrinsic_write_mask(intrin);

   b->cursor = nir_before_instr(&intrin->instr);
   unsigned mapped = ac_nir_map_io_location(io_sem.location, st->gs_inputs_read, st->map_io);
   nir_def *io_off = ac_nir_calc_io_off(b, intrin, nir_imm_int(b, 16u), 4u, mapped);
   nir_def *store_val = intrin->src[0].ssa;

   if (st->gfx_level <= GFX8) {
      /* GFX6-8: ES is a separate HW stage, data is passed from ES to GS in VRAM. */
      nir_def *ring = nir_load_ring_esgs_amd(b);
      nir_def *es2gs_off = nir_load_ring_es2gs_offset_amd(b);
      AC_NIR_STORE_IO(b, store_val, 0, write_mask, io_sem.high_16bits, emit_split_buffer_store,
                      ring, io_off, es2gs_off, store_val->bit_size, store_const_offset,
                      store_write_mask, true, true);
   } else {
      /* GFX9+: ES is merged into GS, data is passed through LDS. */
      nir_def *vertex_idx = nir_load_local_invocation_index(b);
      nir_def *off = nir_iadd(b, nir_imul_imm(b, vertex_idx, st->esgs_itemsize), io_off);
      AC_NIR_STORE_IO(b, store_val, 0, write_mask, io_sem.high_16bits, nir_store_shared, off,
                      .write_mask = store_write_mask, .base = store_const_offset);
   }

   nir_instr_remove(&intrin->instr);
   return true;
}

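/* Return the ES vertex offset for the given GS input vertex.
 * With the triangle strip adjacency fix enabled, odd primitives read their
 * offsets from a rotated set of GS vertex offset registers.
 */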
static nir_def *
gs_get_vertex_offset(nir_builder *b, lower_esgs_io_state *st, unsigned vertex_index)
{
   nir_def *origin = nir_load_gs_vertex_offset_amd(b, .base = vertex_index);
   if (!st->gs_triangle_strip_adjacency_fix)
      return origin;

   unsigned fixed_index;
   if (st->gfx_level < GFX9) {
      /* Rotate the vertex index by 2. */
      fixed_index = (vertex_index + 4) % 6;
   } else {
      /* This issue has been fixed for GFX10+. */
      assert(st->gfx_level == GFX9);
      /* The 6 vertex offsets are packed into 3 VGPRs on GFX9+. */
      fixed_index = (vertex_index + 2) % 3;
   }
   nir_def *fixed = nir_load_gs_vertex_offset_amd(b, .base = fixed_index);

   nir_def *prim_id = nir_load_primitive_id(b);
   /* Odd primitive IDs use the fixed offset. */
   nir_def *cond = nir_i2b(b, nir_iand_imm(b, prim_id, 1));
   return nir_bcsel(b, cond, fixed, origin);
}

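/* GFX6-8: each GS vertex offset occupies a full VGPR, so a non-constant vertex
 * index is lowered to a chain of bcsel between the individual offsets.
 */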
static nir_def *
gs_per_vertex_input_vertex_offset_gfx6(nir_builder *b, lower_esgs_io_state *st,
                                       nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src))
      return gs_get_vertex_offset(b, st, nir_src_as_uint(*vertex_src));

   nir_def *vertex_offset = gs_get_vertex_offset(b, st, 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; ++i) {
      nir_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_def *elem = gs_get_vertex_offset(b, st, i);
      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return vertex_offset;
}

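/* GFX9-11: two 16-bit vertex offsets are packed into each VGPR, so extract the
 * correct half (low or high 16 bits) for the requested vertex.
 */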
static nir_def *
gs_per_vertex_input_vertex_offset_gfx9(nir_builder *b, lower_esgs_io_state *st,
                                       nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src)) {
      unsigned vertex = nir_src_as_uint(*vertex_src);
      return nir_ubfe_imm(b, gs_get_vertex_offset(b, st, vertex / 2u),
                          (vertex & 1u) * 16u, 16u);
   }

   nir_def *vertex_offset = gs_get_vertex_offset(b, st, 0);

   for (unsigned i = 1; i < b->shader->info.gs.vertices_in; i++) {
      nir_def *cond = nir_ieq_imm(b, vertex_src->ssa, i);
      nir_def *elem = gs_get_vertex_offset(b, st, i / 2u * 2u);
      if (i % 2u)
         elem = nir_ishr_imm(b, elem, 16u);

      vertex_offset = nir_bcsel(b, cond, elem, vertex_offset);
   }

   return nir_iand_imm(b, vertex_offset, 0xffffu);
}

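/* GFX12: three vertex offsets are packed into each VGPR at 9-bit intervals;
 * extract the 8-bit offset for the requested vertex.
 */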
static nir_def *
gs_per_vertex_input_vertex_offset_gfx12(nir_builder *b, lower_esgs_io_state *st,
                                        nir_src *vertex_src)
{
   if (nir_src_is_const(*vertex_src)) {
      unsigned vertex = nir_src_as_uint(*vertex_src);
      return nir_ubfe_imm(b, gs_get_vertex_offset(b, st, vertex / 3),
                          (vertex % 3) * 9, 8);
   }

   nir_def *bitoffset = nir_imul_imm(b, nir_umod_imm(b, vertex_src->ssa, 3), 9);
   nir_def *lt3 = nir_ult(b, vertex_src->ssa, nir_imm_int(b, 3));

   return nir_bcsel(b, lt3,
                    nir_ubfe(b, gs_get_vertex_offset(b, st, 0), bitoffset, nir_imm_int(b, 8)),
                    nir_ubfe(b, gs_get_vertex_offset(b, st, 1), bitoffset, nir_imm_int(b, 8)));
}

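/* Compute the byte offset of a per-vertex GS input: the vertex offset (scaled by
 * the ESGS vertex stride on GFX9+) plus the dword offset of the input slot,
 * converted to bytes. On GFX6-8 this addresses the ESGS ring in VRAM; on GFX9+
 * it addresses LDS.
 */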
static nir_def *
gs_per_vertex_input_offset(nir_builder *b,
                           lower_esgs_io_state *st,
                           nir_intrinsic_instr *instr)
{
   nir_src *vertex_src = nir_get_io_arrayed_index_src(instr);
   nir_def *vertex_offset;

   if (st->gfx_level >= GFX12)
      vertex_offset = gs_per_vertex_input_vertex_offset_gfx12(b, st, vertex_src);
   else if (st->gfx_level >= GFX9)
      vertex_offset = gs_per_vertex_input_vertex_offset_gfx9(b, st, vertex_src);
   else
      vertex_offset = gs_per_vertex_input_vertex_offset_gfx6(b, st, vertex_src);

   /* GFX6-8 can't emulate VGT_ESGS_RING_ITEMSIZE because it uses the register to determine
    * the allocation size of the ESGS ring buffer in memory.
    */
   if (st->gfx_level >= GFX9)
      vertex_offset = nir_imul(b, vertex_offset, nir_load_esgs_vertex_stride_amd(b));

   unsigned base_stride = st->gfx_level >= GFX9 ? 1 : 64 /* Wave size on GFX6-8 */;
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(instr);
   unsigned mapped = ac_nir_map_io_location(io_sem.location, st->gs_inputs_read, st->map_io);
   nir_def *io_off = ac_nir_calc_io_off(b, instr, nir_imm_int(b, base_stride * 4u), base_stride, mapped);
   nir_def *off = nir_iadd(b, io_off, vertex_offset);
   return nir_imul_imm(b, off, 4u);
}

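/* Replace load_per_vertex_input in the GS with a load from LDS (GFX9+)
 * or from the ESGS ring buffer in VRAM (GFX6-8).
 */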
static nir_def *
lower_gs_per_vertex_input_load(nir_builder *b,
                               nir_instr *instr,
                               void *state)
{
   lower_esgs_io_state *st = (lower_esgs_io_state *) state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
   nir_def *off = gs_per_vertex_input_offset(b, st, intrin);
   nir_def *load = NULL;

   if (st->gfx_level >= GFX9) {
      AC_NIR_LOAD_IO(load, b, intrin->num_components, intrin->def.bit_size, io_sem.high_16bits,
                     nir_load_shared, off);
   } else {
      AC_NIR_LOAD_IO(load, b, intrin->num_components, intrin->def.bit_size, io_sem.high_16bits,
                     emit_split_buffer_load, 4 * 64, nir_load_ring_esgs_amd(b), off, nir_imm_int(b, 0));
   }

   return load;
}

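/* Filter for nir_shader_lower_instructions: match only load_per_vertex_input. */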
static bool
filter_load_per_vertex_input(const nir_instr *instr, UNUSED const void *state)
{
   return instr->type == nir_instr_type_intrinsic &&
          nir_instr_as_intrinsic(instr)->intrinsic == nir_intrinsic_load_per_vertex_input;
}

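/* Lower ES (VS or TES) outputs to ESGS ring stores (GFX6-8) or LDS stores (GFX9+). */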
void
ac_nir_lower_es_outputs_to_mem(nir_shader *shader,
                               ac_nir_map_io_driver_location map,
                               enum amd_gfx_level gfx_level,
                               unsigned esgs_itemsize,
                               uint64_t gs_inputs_read)
{
   lower_esgs_io_state state = {
      .gfx_level = gfx_level,
      .esgs_itemsize = esgs_itemsize,
      .map_io = map,
      .gs_inputs_read = gs_inputs_read,
   };

   nir_shader_intrinsics_pass(shader, lower_es_output_store,
                              nir_metadata_control_flow,
                              &state);
}

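/* Lower GS per-vertex inputs to ESGS ring loads (GFX6-8) or LDS loads (GFX9+). */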
void
ac_nir_lower_gs_inputs_to_mem(nir_shader *shader,
                              ac_nir_map_io_driver_location map,
                              enum amd_gfx_level gfx_level,
                              bool triangle_strip_adjacency_fix)
{
   lower_esgs_io_state state = {
      .gfx_level = gfx_level,
      .map_io = map,
      .gs_triangle_strip_adjacency_fix = triangle_strip_adjacency_fix,
      .gs_inputs_read = shader->info.inputs_read,
   };

   nir_shader_lower_instructions(shader,
                                 filter_load_per_vertex_input,
                                 lower_gs_per_vertex_input_load,
                                 &state);
}
365