xref: /aosp_15_r20/external/mesa3d/src/amd/common/ac_nir_lower_ngg.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2021 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "ac_nir.h"
8 #include "ac_nir_helpers.h"
9 #include "amdgfxregs.h"
10 #include "nir_builder.h"
11 #include "nir_xfb_info.h"
12 #include "util/u_math.h"
13 #include "util/u_vector.h"
14 
15 #define SPECIAL_MS_OUT_MASK \
16    (BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) | \
17     BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | \
18     BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
19 
20 #define MS_PRIM_ARG_EXP_MASK \
21    (VARYING_BIT_LAYER | \
22     VARYING_BIT_VIEWPORT | \
23     VARYING_BIT_PRIMITIVE_SHADING_RATE)
24 
25 #define MS_VERT_ARG_EXP_MASK \
26    (VARYING_BIT_CULL_DIST0 | \
27     VARYING_BIT_CULL_DIST1 | \
28     VARYING_BIT_CLIP_DIST0 | \
29     VARYING_BIT_CLIP_DIST1 | \
30     VARYING_BIT_PSIZ)
31 
32 enum {
33    nggc_passflag_used_by_pos = 1,
34    nggc_passflag_used_by_other = 2,
35    nggc_passflag_used_by_both = nggc_passflag_used_by_pos | nggc_passflag_used_by_other,
36 };
37 
38 typedef struct
39 {
40    nir_def *ssa;
41    nir_variable *var;
42 } reusable_nondeferred_variable;
43 
44 typedef struct
45 {
46    gl_varying_slot slot;
47    nir_def *chan[4];
48 } vs_output;
49 
50 typedef struct
51 {
52    const ac_nir_lower_ngg_options *options;
53 
54    nir_variable *position_value_var;
55    nir_variable *prim_exp_arg_var;
56    nir_variable *es_accepted_var;
57    nir_variable *gs_accepted_var;
58    nir_variable *gs_exported_var;
59    nir_variable *gs_vtx_indices_vars[3];
60 
61    nir_def *vtx_addr[3];
62 
63    struct u_vector reusable_nondeferred_variables;
64 
65    bool early_prim_export;
66    bool streamout_enabled;
67    bool has_user_edgeflags;
68    bool skip_primitive_id;
69    unsigned max_num_waves;
70 
71    /* LDS params */
72    unsigned pervertex_lds_bytes;
73 
74    uint64_t inputs_needed_by_pos;
75    uint64_t inputs_needed_by_others;
76 
77    nir_instr *compact_arg_stores[4];
78    nir_intrinsic_instr *overwrite_args;
79    nir_variable *repacked_rel_patch_id;
80 
81    /* clip distance */
82    nir_variable *clip_vertex_var;
83    nir_variable *clipdist_neg_mask_var;
84    bool has_clipdist;
85 
86    /* outputs */
87    ac_nir_prerast_out out;
88 } lower_ngg_nogs_state;
89 
90 typedef struct
91 {
92    const ac_nir_lower_ngg_options *options;
93 
94    nir_function_impl *impl;
95    int const_out_vtxcnt[4];
96    int const_out_prmcnt[4];
97    unsigned max_num_waves;
98    unsigned num_vertices_per_primitive;
99    nir_def *lds_addr_gs_out_vtx;
100    nir_def *lds_addr_gs_scratch;
101    unsigned lds_bytes_per_gs_out_vertex;
102    unsigned lds_offs_primflags;
103    bool output_compile_time_known;
104    bool streamout_enabled;
105    /* Outputs */
106    ac_nir_prerast_out out;
107    /* Count per stream. */
108    nir_def *vertex_count[4];
109    nir_def *primitive_count[4];
110 } lower_ngg_gs_state;
111 
112 /* LDS layout of Mesh Shader workgroup info. */
113 enum {
114    /* DW0: number of primitives */
115    lds_ms_num_prims = 0,
116    /* DW1: number of vertices */
117    lds_ms_num_vtx = 4,
118    /* DW2: workgroup index within the current dispatch */
119    lds_ms_wg_index = 8,
120    /* DW3: number of API workgroups in flight */
121    lds_ms_num_api_waves = 12,
122 };
123 
124 /* Potential location for Mesh Shader outputs. */
125 typedef enum {
126    ms_out_mode_lds,
127    ms_out_mode_scratch_ring,
128    ms_out_mode_attr_ring,
129    ms_out_mode_var,
130 } ms_out_mode;
131 
132 typedef struct
133 {
134    uint64_t mask; /* Mask of output locations */
135    uint32_t addr; /* Base address */
136 } ms_out_part;
137 
138 typedef struct
139 {
140    /* Mesh shader LDS layout. For details, see ms_calculate_output_layout. */
141    struct {
142       uint32_t workgroup_info_addr;
143       ms_out_part vtx_attr;
144       ms_out_part prm_attr;
145       uint32_t indices_addr;
146       uint32_t cull_flags_addr;
147       uint32_t total_size;
148    } lds;
149 
150    /* VRAM "mesh shader scratch ring" layout for outputs that don't fit into the LDS.
151     * Not to be confused with scratch memory.
152     */
153    struct {
154       ms_out_part vtx_attr;
155       ms_out_part prm_attr;
156    } scratch_ring;
157 
158    /* VRAM attributes ring (GFX11 only) for all non-position outputs.
159     * GFX11 doesn't have to reload attributes from this ring at the end of the shader.
160     */
161    struct {
162       ms_out_part vtx_attr;
163       ms_out_part prm_attr;
164    } attr_ring;
165 
166    /* Outputs without cross-invocation access can be stored in variables. */
167    struct {
168       ms_out_part vtx_attr;
169       ms_out_part prm_attr;
170    } var;
171 } ms_out_mem_layout;
172 
173 typedef struct
174 {
175    enum amd_gfx_level gfx_level;
176    bool fast_launch_2;
177    bool vert_multirow_export;
178    bool prim_multirow_export;
179 
180    ms_out_mem_layout layout;
181    uint64_t per_vertex_outputs;
182    uint64_t per_primitive_outputs;
183    unsigned vertices_per_prim;
184 
185    unsigned wave_size;
186    unsigned api_workgroup_size;
187    unsigned hw_workgroup_size;
188 
189    nir_def *workgroup_index;
190    nir_variable *out_variables[VARYING_SLOT_MAX * 4];
191    nir_variable *primitive_count_var;
192    nir_variable *vertex_count_var;
193 
194    /* True if the lowering needs to insert the layer output. */
195    bool insert_layer_output;
196    /* True if cull flags are used */
197    bool uses_cull_flags;
198 
199    struct {
200       /* Bitmask of components used: 4 bits per slot, 1 bit per component. */
201       uint32_t components_mask;
202    } output_info[VARYING_SLOT_MAX];
203 
204    /* Used by outputs export. */
205    nir_def *outputs[VARYING_SLOT_MAX][4];
206    uint32_t clipdist_enable_mask;
207    const uint8_t *vs_output_param_offset;
208    bool has_param_exports;
209 
210    /* True if the lowering needs to insert shader query. */
211    bool has_query;
212 } lower_ngg_ms_state;
213 
214 /* Per-vertex LDS layout of culling shaders */
215 enum {
216    /* Position of the ES vertex (at the beginning for alignment reasons) */
217    lds_es_pos_x = 0,
218    lds_es_pos_y = 4,
219    lds_es_pos_z = 8,
220    lds_es_pos_w = 12,
221 
222    /* 1 when the vertex is accepted, 0 if it should be culled */
223    lds_es_vertex_accepted = 16,
224    /* ID of the thread which will export the current thread's vertex */
225    lds_es_exporter_tid = 17,
226    /* bit i is set when the i'th clip distance of a vertex is negative */
227    lds_es_clipdist_neg_mask = 18,
228    /* TES only, relative patch ID, less than max workgroup size */
229    lds_es_tes_rel_patch_id = 19,
230 
231    /* Repacked arguments - also listed separately for VS and TES */
232    lds_es_arg_0 = 20,
233 };
234 
235 typedef struct {
236    nir_def *num_repacked_invocations;
237    nir_def *repacked_invocation_index;
238 } wg_repack_result;
239 
240 /**
241  * Computes a horizontal sum of 8-bit packed values loaded from LDS.
242  *
243  * Each lane N will sum packed bytes 0 to N-1.
244  * We only care about the results from up to wave_id+1 lanes.
245  * (Other lanes are not deactivated but their calculation is not used.)
246  */
247 static nir_def *
summarize_repack(nir_builder * b,nir_def * packed_counts,unsigned num_lds_dwords)248 summarize_repack(nir_builder *b, nir_def *packed_counts, unsigned num_lds_dwords)
249 {
250    /* We'll use shift to filter out the bytes not needed by the current lane.
251     *
252     * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
253     * However, two shifts are needed because one can't go all the way,
254     * so the shift amount is half that (and in bits).
255     *
256     * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
257     * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
258     * therefore v_dot can get rid of the unneeded values.
259     * This sequence is preferable because it better hides the latency of the LDS.
260     *
261     * If the v_dot instruction can't be used, we left-shift the packed bytes.
262     * This will shift out the unneeded bytes and shift in zeroes instead,
263     * then we sum them using v_msad_u8.
264     */
265 
266    nir_def *lane_id = nir_load_subgroup_invocation(b);
267    nir_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
268    bool use_dot = b->shader->options->has_udot_4x8;
269 
270    if (num_lds_dwords == 1) {
271       nir_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
272 
273       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
274       nir_def *packed = nir_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
275 
276       /* Horizontally add the packed bytes. */
277       if (use_dot) {
278          return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
279       } else {
280          nir_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
281          return nir_msad_4x8(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
282       }
283    } else if (num_lds_dwords == 2) {
284       nir_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
285 
286       /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
287       nir_def *packed_dw0 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
288       nir_def *packed_dw1 = nir_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
289 
290       /* Horizontally add the packed bytes. */
291       if (use_dot) {
292          nir_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
293          return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
294       } else {
295          nir_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
296          nir_def *sum = nir_msad_4x8(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
297          return nir_msad_4x8(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
298       }
299    } else {
300       unreachable("Unimplemented NGG wave count");
301    }
302 }
303 
304 /**
305  * Repacks invocations in the current workgroup to eliminate gaps between them.
306  *
307  * Uses 1 dword of LDS per 4 waves (1 byte of LDS per wave).
308  * Assumes that all invocations in the workgroup are active (exec = -1).
309  */
310 static wg_repack_result
repack_invocations_in_workgroup(nir_builder * b,nir_def * input_bool,nir_def * lds_addr_base,unsigned max_num_waves,unsigned wave_size)311 repack_invocations_in_workgroup(nir_builder *b, nir_def *input_bool,
312                                 nir_def *lds_addr_base, unsigned max_num_waves,
313                                 unsigned wave_size)
314 {
315    /* Input boolean: 1 if the current invocation should survive the repack. */
316    assert(input_bool->bit_size == 1);
317 
318    /* STEP 1. Count surviving invocations in the current wave.
319     *
320     * Implemented by a scalar instruction that simply counts the number of bits set in a 32/64-bit mask.
321     */
322 
323    nir_def *input_mask = nir_ballot(b, 1, wave_size, input_bool);
324    nir_def *surviving_invocations_in_current_wave = nir_bit_count(b, input_mask);
325 
326    /* If we know at compile time that the workgroup has only 1 wave, no further steps are necessary. */
327    if (max_num_waves == 1) {
328       wg_repack_result r = {
329          .num_repacked_invocations = surviving_invocations_in_current_wave,
330          .repacked_invocation_index = nir_mbcnt_amd(b, input_mask, nir_imm_int(b, 0)),
331       };
332       return r;
333    }
334 
335    /* STEP 2. Waves tell each other their number of surviving invocations.
336     *
337     * Each wave activates only its first lane (exec = 1), which stores the number of surviving
338     * invocations in that wave into the LDS, then reads the numbers from every wave.
339     *
340     * The workgroup size of NGG shaders is at most 256, which means
341     * the maximum number of waves is 4 in Wave64 mode and 8 in Wave32 mode.
342     * Each wave writes 1 byte, so it's up to 8 bytes, so at most 2 dwords are necessary.
343     */
344 
345    const unsigned num_lds_dwords = DIV_ROUND_UP(max_num_waves, 4);
346    assert(num_lds_dwords <= 2);
347 
348    nir_def *wave_id = nir_load_subgroup_id(b);
349    nir_def *lds_offset = nir_iadd(b, lds_addr_base, wave_id);
350    nir_def *dont_care = nir_undef(b, 1, num_lds_dwords * 32);
351    nir_if *if_first_lane = nir_push_if(b, nir_elect(b, 1));
352 
353    nir_store_shared(b, nir_u2u8(b, surviving_invocations_in_current_wave), lds_offset);
354 
355    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
356                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
357 
358    nir_def *packed_counts =
359       nir_load_shared(b, 1, num_lds_dwords * 32, lds_addr_base, .align_mul = 8u);
360 
361    nir_pop_if(b, if_first_lane);
362 
363    packed_counts = nir_if_phi(b, packed_counts, dont_care);
364 
365    /* STEP 3. Compute the repacked invocation index and the total number of surviving invocations.
366     *
367     * By now, every wave knows the number of surviving invocations in all waves.
368     * Each number is 1 byte, and they are packed into up to 2 dwords.
369     *
370     * Each lane N will sum the number of surviving invocations from waves 0 to N-1.
371     * If the workgroup has M waves, then each wave will use only its first M+1 lanes for this.
372     * (Other lanes are not deactivated but their calculation is not used.)
373     *
374     * - We read the sum from the lane whose id is the current wave's id.
375     *   Add the masked bitcount to this, and we get the repacked invocation index.
376     * - We read the sum from the lane whose id is the number of waves in the workgroup.
377     *   This is the total number of surviving invocations in the workgroup.
378     */
379 
380    nir_def *num_waves = nir_load_num_subgroups(b);
381    nir_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
382 
383    nir_def *wg_repacked_index_base = nir_read_invocation(b, sum, wave_id);
384    nir_def *wg_num_repacked_invocations = nir_read_invocation(b, sum, num_waves);
385    nir_def *wg_repacked_index = nir_mbcnt_amd(b, input_mask, wg_repacked_index_base);
386 
387    wg_repack_result r = {
388       .num_repacked_invocations = wg_num_repacked_invocations,
389       .repacked_invocation_index = wg_repacked_index,
390    };
391 
392    return r;
393 }
394 
395 static nir_def *
pervertex_lds_addr(nir_builder * b,nir_def * vertex_idx,unsigned per_vtx_bytes)396 pervertex_lds_addr(nir_builder *b, nir_def *vertex_idx, unsigned per_vtx_bytes)
397 {
398    return nir_imul_imm(b, vertex_idx, per_vtx_bytes);
399 }
400 
401 static nir_def *
emit_pack_ngg_prim_exp_arg(nir_builder * b,unsigned num_vertices_per_primitives,nir_def * vertex_indices[3],nir_def * is_null_prim,enum amd_gfx_level gfx_level)402 emit_pack_ngg_prim_exp_arg(nir_builder *b, unsigned num_vertices_per_primitives,
403                            nir_def *vertex_indices[3], nir_def *is_null_prim,
404                            enum amd_gfx_level gfx_level)
405 {
406    nir_def *arg = nir_load_initial_edgeflags_amd(b);
407 
408    for (unsigned i = 0; i < num_vertices_per_primitives; ++i) {
409       assert(vertex_indices[i]);
410       arg = nir_ior(b, arg, nir_ishl_imm(b, vertex_indices[i],
411                                          (gfx_level >= GFX12 ? 9u : 10u) * i));
412    }
413 
414    if (is_null_prim) {
415       if (is_null_prim->bit_size == 1)
416          is_null_prim = nir_b2i32(b, is_null_prim);
417       assert(is_null_prim->bit_size == 32);
418       arg = nir_ior(b, arg, nir_ishl_imm(b, is_null_prim, 31u));
419    }
420 
421    return arg;
422 }
423 
424 static void
alloc_vertices_and_primitives(nir_builder * b,nir_def * num_vtx,nir_def * num_prim)425 alloc_vertices_and_primitives(nir_builder *b,
426                               nir_def *num_vtx,
427                               nir_def *num_prim)
428 {
429    /* The caller should only call this conditionally on wave 0.
430     *
431     * Send GS Alloc Request message from the first wave of the group to SPI.
432     * Message payload (in the m0 register) is:
433     * - bits 0..10: number of vertices in group
434     * - bits 12..22: number of primitives in group
435     */
436 
437    nir_def *m0 = nir_ior(b, nir_ishl_imm(b, num_prim, 12), num_vtx);
438    nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_GS_ALLOC_REQ);
439 }
440 
441 static void
alloc_vertices_and_primitives_gfx10_workaround(nir_builder * b,nir_def * num_vtx,nir_def * num_prim)442 alloc_vertices_and_primitives_gfx10_workaround(nir_builder *b,
443                                                nir_def *num_vtx,
444                                                nir_def *num_prim)
445 {
446    /* HW workaround for a GPU hang with 100% culling on GFX10.
447     * We always have to export at least 1 primitive.
448     * Export a degenerate triangle using vertex 0 for all 3 vertices.
449     *
450     * NOTE: We rely on the caller to set the vertex count also to 0 when the primitive count is 0.
451     */
452    nir_def *is_prim_cnt_0 = nir_ieq_imm(b, num_prim, 0);
453    nir_if *if_prim_cnt_0 = nir_push_if(b, is_prim_cnt_0);
454    {
455       nir_def *one = nir_imm_int(b, 1);
456       alloc_vertices_and_primitives(b, one, one);
457 
458       nir_def *tid = nir_load_subgroup_invocation(b);
459       nir_def *is_thread_0 = nir_ieq_imm(b, tid, 0);
460       nir_if *if_thread_0 = nir_push_if(b, is_thread_0);
461       {
462          /* The vertex indices are 0, 0, 0. */
463          nir_export_amd(b, nir_imm_zero(b, 4, 32),
464                         .base = V_008DFC_SQ_EXP_PRIM,
465                         .flags = AC_EXP_FLAG_DONE,
466                         .write_mask = 1);
467 
468          /* The HW culls primitives with NaN. -1 is also NaN and can save
469           * a dword in binary code by inlining constant.
470           */
471          nir_export_amd(b, nir_imm_ivec4(b, -1, -1, -1, -1),
472                         .base = V_008DFC_SQ_EXP_POS,
473                         .flags = AC_EXP_FLAG_DONE,
474                         .write_mask = 0xf);
475       }
476       nir_pop_if(b, if_thread_0);
477    }
478    nir_push_else(b, if_prim_cnt_0);
479    {
480       alloc_vertices_and_primitives(b, num_vtx, num_prim);
481    }
482    nir_pop_if(b, if_prim_cnt_0);
483 }
484 
485 static void
ngg_nogs_init_vertex_indices_vars(nir_builder * b,nir_function_impl * impl,lower_ngg_nogs_state * s)486 ngg_nogs_init_vertex_indices_vars(nir_builder *b, nir_function_impl *impl, lower_ngg_nogs_state *s)
487 {
488    for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
489       s->gs_vtx_indices_vars[v] = nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx_addr");
490 
491       nir_def *vtx;
492 
493       if (s->options->gfx_level >= GFX12) {
494          vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 9 * v, 8);
495       } else if (s->options->passthrough) {
496          vtx = nir_ubfe_imm(b, nir_load_packed_passthrough_primitive_amd(b), 10 * v, 9);
497       } else {
498          vtx = nir_ubfe_imm(b, nir_load_gs_vertex_offset_amd(b, .base = v / 2u),
499                             (v & 1u) * 16u, 16u);
500       }
501 
502       nir_store_var(b, s->gs_vtx_indices_vars[v], vtx, 0x1);
503    }
504 }
505 
506 static nir_def *
emit_ngg_nogs_prim_exp_arg(nir_builder * b,lower_ngg_nogs_state * s)507 emit_ngg_nogs_prim_exp_arg(nir_builder *b, lower_ngg_nogs_state *s)
508 {
509    if (s->options->gfx_level >= GFX12 || s->options->passthrough) {
510       return nir_load_packed_passthrough_primitive_amd(b);
511    } else {
512       nir_def *vtx_idx[3] = {0};
513 
514       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v)
515          vtx_idx[v] = nir_load_var(b, s->gs_vtx_indices_vars[v]);
516 
517       return emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive, vtx_idx, NULL,
518                                         s->options->gfx_level);
519    }
520 }
521 
522 static nir_def *
has_input_vertex(nir_builder * b)523 has_input_vertex(nir_builder *b)
524 {
525    return nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b));
526 }
527 
528 static nir_def *
has_input_primitive(nir_builder * b)529 has_input_primitive(nir_builder *b)
530 {
531    return nir_is_subgroup_invocation_lt_amd(b,
532                                             nir_ushr_imm(b, nir_load_merged_wave_info_amd(b), 8));
533 }
534 
535 static void
nogs_prim_gen_query(nir_builder * b,lower_ngg_nogs_state * s)536 nogs_prim_gen_query(nir_builder *b, lower_ngg_nogs_state *s)
537 {
538    if (!s->options->has_gen_prim_query)
539       return;
540 
541    nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
542    {
543       /* Activate only 1 lane and add the number of primitives to query result. */
544       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
545       {
546          /* Number of input primitives in the current wave. */
547          nir_def *num_input_prims = nir_ubfe_imm(b, nir_load_merged_wave_info_amd(b),
548                                                      8, 8);
549 
550          /* Add to stream 0 primitive generated counter. */
551          nir_atomic_add_gen_prim_count_amd(b, num_input_prims, .stream_id = 0);
552       }
553       nir_pop_if(b, if_elected);
554    }
555    nir_pop_if(b, if_shader_query);
556 }
557 
558 static void
emit_ngg_nogs_prim_export(nir_builder * b,lower_ngg_nogs_state * s,nir_def * arg)559 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *s, nir_def *arg)
560 {
561    nir_if *if_gs_thread = nir_push_if(b, nir_load_var(b, s->gs_exported_var));
562    {
563       if (!arg)
564          arg = emit_ngg_nogs_prim_exp_arg(b, s);
565 
566       /* pack user edge flag info into arg */
567       if (s->has_user_edgeflags) {
568          /* Workgroup barrier: wait for ES threads store user edge flags to LDS */
569          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
570                             .memory_scope = SCOPE_WORKGROUP,
571                             .memory_semantics = NIR_MEMORY_ACQ_REL,
572                             .memory_modes = nir_var_mem_shared);
573 
574          unsigned edge_flag_bits = ac_get_all_edge_flag_bits(s->options->gfx_level);
575          nir_def *mask = nir_imm_intN_t(b, ~edge_flag_bits, 32);
576 
577          unsigned edge_flag_offset = 0;
578          if (s->streamout_enabled) {
579             unsigned packed_location =
580                util_bitcount64(b->shader->info.outputs_written &
581                                BITFIELD64_MASK(VARYING_SLOT_EDGE));
582             edge_flag_offset = packed_location * 16;
583          }
584 
585          for (int i = 0; i < s->options->num_vertices_per_primitive; i++) {
586             nir_def *vtx_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
587             nir_def *addr = pervertex_lds_addr(b, vtx_idx, s->pervertex_lds_bytes);
588             nir_def *edge = nir_load_shared(b, 1, 32, addr, .base = edge_flag_offset);
589 
590             if (s->options->gfx_level >= GFX12)
591                mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 8 + i * 9));
592             else
593                mask = nir_ior(b, mask, nir_ishl_imm(b, edge, 9 + i * 10));
594          }
595          arg = nir_iand(b, arg, mask);
596       }
597 
598       ac_nir_export_primitive(b, arg, NULL);
599    }
600    nir_pop_if(b, if_gs_thread);
601 }
602 
603 static void
emit_ngg_nogs_prim_id_store_shared(nir_builder * b,lower_ngg_nogs_state * s)604 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *s)
605 {
606    nir_def *gs_thread =
607       s->gs_accepted_var ? nir_load_var(b, s->gs_accepted_var) : has_input_primitive(b);
608 
609    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
610    {
611       /* Copy Primitive IDs from GS threads to the LDS address
612        * corresponding to the ES thread of the provoking vertex.
613        * It will be exported as a per-vertex attribute.
614        */
615       nir_def *gs_vtx_indices[3];
616       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++)
617          gs_vtx_indices[i] = nir_load_var(b, s->gs_vtx_indices_vars[i]);
618 
619       nir_def *provoking_vertex = nir_load_provoking_vtx_in_prim_amd(b);
620       nir_def *provoking_vtx_idx = nir_select_from_ssa_def_array(
621          b, gs_vtx_indices, s->options->num_vertices_per_primitive, provoking_vertex);
622 
623       nir_def *prim_id = nir_load_primitive_id(b);
624       nir_def *addr = pervertex_lds_addr(b, provoking_vtx_idx, s->pervertex_lds_bytes);
625 
626       /* primitive id is always at last of a vertex */
627       nir_store_shared(b, prim_id, addr, .base = s->pervertex_lds_bytes - 4);
628    }
629    nir_pop_if(b, if_gs_thread);
630 }
631 
632 static void
emit_store_ngg_nogs_es_primitive_id(nir_builder * b,lower_ngg_nogs_state * s)633 emit_store_ngg_nogs_es_primitive_id(nir_builder *b, lower_ngg_nogs_state *s)
634 {
635    nir_def *prim_id = NULL;
636 
637    if (b->shader->info.stage == MESA_SHADER_VERTEX) {
638       /* LDS address where the primitive ID is stored */
639       nir_def *thread_id_in_threadgroup = nir_load_local_invocation_index(b);
640       nir_def *addr =
641          pervertex_lds_addr(b, thread_id_in_threadgroup, s->pervertex_lds_bytes);
642 
643       /* Load primitive ID from LDS */
644       prim_id = nir_load_shared(b, 1, 32, addr, .base = s->pervertex_lds_bytes - 4);
645    } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
646       /* Just use tess eval primitive ID, which is the same as the patch ID. */
647       prim_id = nir_load_primitive_id(b);
648    }
649 
650    s->out.outputs[VARYING_SLOT_PRIMITIVE_ID][0] = prim_id;
651 
652    /* Update outputs_written to reflect that the pass added a new output. */
653    b->shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
654 }
655 
656 static void
add_clipdist_bit(nir_builder * b,nir_def * dist,unsigned index,nir_variable * mask)657 add_clipdist_bit(nir_builder *b, nir_def *dist, unsigned index, nir_variable *mask)
658 {
659    nir_def *is_neg = nir_flt_imm(b, dist, 0);
660    nir_def *neg_mask = nir_ishl_imm(b, nir_b2i32(b, is_neg), index);
661    neg_mask = nir_ior(b, neg_mask, nir_load_var(b, mask));
662    nir_store_var(b, mask, neg_mask, 1);
663 }
664 
665 static bool
remove_culling_shader_output(nir_builder * b,nir_instr * instr,void * state)666 remove_culling_shader_output(nir_builder *b, nir_instr *instr, void *state)
667 {
668    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
669 
670    if (instr->type != nir_instr_type_intrinsic)
671       return false;
672 
673    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
674 
675    /* These are not allowed in VS / TES */
676    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
677           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
678 
679    /* We are only interested in output stores now */
680    if (intrin->intrinsic != nir_intrinsic_store_output)
681       return false;
682 
683    b->cursor = nir_before_instr(instr);
684 
685    /* no indirect output */
686    assert(nir_src_is_const(intrin->src[1]) && nir_src_as_uint(intrin->src[1]) == 0);
687 
688    unsigned writemask = nir_intrinsic_write_mask(intrin);
689    unsigned component = nir_intrinsic_component(intrin);
690    nir_def *store_val = intrin->src[0].ssa;
691 
692    /* Position output - store the value to a variable, remove output store */
693    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
694    switch (io_sem.location) {
695    case VARYING_SLOT_POS:
696       ac_nir_store_var_components(b, s->position_value_var, store_val, component, writemask);
697       break;
698    case VARYING_SLOT_CLIP_DIST0:
699    case VARYING_SLOT_CLIP_DIST1: {
700       unsigned base = io_sem.location == VARYING_SLOT_CLIP_DIST1 ? 4 : 0;
701       base += component;
702 
703       /* valid clipdist component mask */
704       unsigned mask = (s->options->clip_cull_dist_mask >> base) & writemask;
705       u_foreach_bit(i, mask) {
706          add_clipdist_bit(b, nir_channel(b, store_val, i), base + i,
707                           s->clipdist_neg_mask_var);
708          s->has_clipdist = true;
709       }
710       break;
711    }
712    case VARYING_SLOT_CLIP_VERTEX:
713       ac_nir_store_var_components(b, s->clip_vertex_var, store_val, component, writemask);
714       break;
715    default:
716       break;
717    }
718 
719    /* Remove all output stores */
720    nir_instr_remove(instr);
721    return true;
722 }
723 
724 static void
remove_culling_shader_outputs(nir_shader * culling_shader,lower_ngg_nogs_state * s)725 remove_culling_shader_outputs(nir_shader *culling_shader, lower_ngg_nogs_state *s)
726 {
727    nir_shader_instructions_pass(culling_shader, remove_culling_shader_output,
728                                 nir_metadata_control_flow, s);
729 
730    /* Remove dead code resulting from the deleted outputs. */
731    bool progress;
732    do {
733       progress = false;
734       NIR_PASS(progress, culling_shader, nir_opt_dead_write_vars);
735       NIR_PASS(progress, culling_shader, nir_opt_dce);
736       NIR_PASS(progress, culling_shader, nir_opt_dead_cf);
737    } while (progress);
738 }
739 
740 static void
rewrite_uses_to_var(nir_builder * b,nir_def * old_def,nir_variable * replacement_var,unsigned replacement_var_channel)741 rewrite_uses_to_var(nir_builder *b, nir_def *old_def, nir_variable *replacement_var, unsigned replacement_var_channel)
742 {
743    if (old_def->parent_instr->type == nir_instr_type_load_const)
744       return;
745 
746    b->cursor = nir_after_instr(old_def->parent_instr);
747    if (b->cursor.instr->type == nir_instr_type_phi)
748       b->cursor = nir_after_phis(old_def->parent_instr->block);
749 
750    nir_def *pos_val_rep = nir_load_var(b, replacement_var);
751    nir_def *replacement = nir_channel(b, pos_val_rep, replacement_var_channel);
752 
753    if (old_def->num_components > 1) {
754       /* old_def uses a swizzled vector component.
755        * There is no way to replace the uses of just a single vector component,
756        * so instead create a new vector and replace all uses of the old vector.
757        */
758       nir_def *old_def_elements[NIR_MAX_VEC_COMPONENTS] = {0};
759       for (unsigned j = 0; j < old_def->num_components; ++j)
760          old_def_elements[j] = nir_channel(b, old_def, j);
761       replacement = nir_vec(b, old_def_elements, old_def->num_components);
762    }
763 
764    nir_def_rewrite_uses_after(old_def, replacement, replacement->parent_instr);
765 }
766 
767 static bool
remove_extra_pos_output(nir_builder * b,nir_instr * instr,void * state)768 remove_extra_pos_output(nir_builder *b, nir_instr *instr, void *state)
769 {
770    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *) state;
771 
772    if (instr->type != nir_instr_type_intrinsic)
773       return false;
774 
775    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
776 
777    /* These are not allowed in VS / TES */
778    assert(intrin->intrinsic != nir_intrinsic_store_per_vertex_output &&
779           intrin->intrinsic != nir_intrinsic_load_per_vertex_input);
780 
781    /* We are only interested in output stores now */
782    if (intrin->intrinsic != nir_intrinsic_store_output)
783       return false;
784 
785    nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
786    if (io_sem.location != VARYING_SLOT_POS)
787       return false;
788 
789    b->cursor = nir_before_instr(instr);
790 
791    /* In case other outputs use what we calculated for pos,
792     * try to avoid calculating it again by rewriting the usages
793     * of the store components here.
794     */
795    nir_def *store_val = intrin->src[0].ssa;
796    unsigned store_pos_component = nir_intrinsic_component(intrin);
797 
798    nir_instr_remove(instr);
799 
800    if (store_val->parent_instr->type == nir_instr_type_alu) {
801       nir_alu_instr *alu = nir_instr_as_alu(store_val->parent_instr);
802       if (nir_op_is_vec_or_mov(alu->op)) {
803          /* Output store uses a vector, we can easily rewrite uses of each vector element. */
804 
805          unsigned num_vec_src = 0;
806          if (alu->op == nir_op_mov)
807             num_vec_src = 1;
808          else if (alu->op == nir_op_vec2)
809             num_vec_src = 2;
810          else if (alu->op == nir_op_vec3)
811             num_vec_src = 3;
812          else if (alu->op == nir_op_vec4)
813             num_vec_src = 4;
814          assert(num_vec_src);
815 
816          /* Remember the current components whose uses we wish to replace.
817           * This is needed because rewriting one source can affect the others too.
818           */
819          nir_def *vec_comps[NIR_MAX_VEC_COMPONENTS] = {0};
820          for (unsigned i = 0; i < num_vec_src; i++)
821             vec_comps[i] = alu->src[i].src.ssa;
822 
823          for (unsigned i = 0; i < num_vec_src; i++)
824             rewrite_uses_to_var(b, vec_comps[i], s->position_value_var, store_pos_component + i);
825       } else {
826          rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
827       }
828    } else {
829       rewrite_uses_to_var(b, store_val, s->position_value_var, store_pos_component);
830    }
831 
832    return true;
833 }
834 
835 static void
remove_extra_pos_outputs(nir_shader * shader,lower_ngg_nogs_state * s)836 remove_extra_pos_outputs(nir_shader *shader, lower_ngg_nogs_state *s)
837 {
838    nir_shader_instructions_pass(shader, remove_extra_pos_output,
839                                 nir_metadata_control_flow,
840                                 s);
841 }
842 
843 static bool
remove_compacted_arg(lower_ngg_nogs_state * s,nir_builder * b,unsigned idx)844 remove_compacted_arg(lower_ngg_nogs_state *s, nir_builder *b, unsigned idx)
845 {
846    nir_instr *store_instr = s->compact_arg_stores[idx];
847    if (!store_instr)
848       return false;
849 
850    /* Simply remove the store. */
851    nir_instr_remove(store_instr);
852 
853    /* Find the intrinsic that overwrites the shader arguments,
854     * and change its corresponding source.
855     * This will cause NIR's DCE to recognize the load and its phis as dead.
856     */
857    b->cursor = nir_before_instr(&s->overwrite_args->instr);
858    nir_def *undef_arg = nir_undef(b, 1, 32);
859    nir_def_rewrite_uses(s->overwrite_args->src[idx].ssa, undef_arg);
860 
861    s->compact_arg_stores[idx] = NULL;
862    return true;
863 }
864 
865 static bool
cleanup_culling_shader_after_dce(nir_shader * shader,nir_function_impl * function_impl,lower_ngg_nogs_state * s)866 cleanup_culling_shader_after_dce(nir_shader *shader,
867                                  nir_function_impl *function_impl,
868                                  lower_ngg_nogs_state *s)
869 {
870    bool uses_vs_vertex_id = false;
871    bool uses_vs_instance_id = false;
872    bool uses_tes_u = false;
873    bool uses_tes_v = false;
874    bool uses_tes_rel_patch_id = false;
875    bool uses_tes_patch_id = false;
876 
877    bool progress = false;
878    nir_builder b = nir_builder_create(function_impl);
879 
880    nir_foreach_block_reverse_safe(block, function_impl) {
881       nir_foreach_instr_reverse_safe(instr, block) {
882          if (instr->type != nir_instr_type_intrinsic)
883             continue;
884 
885          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
886 
887          switch (intrin->intrinsic) {
888          case nir_intrinsic_sendmsg_amd:
889             goto cleanup_culling_shader_after_dce_done;
890          case nir_intrinsic_load_vertex_id:
891          case nir_intrinsic_load_vertex_id_zero_base:
892             uses_vs_vertex_id = true;
893             break;
894          case nir_intrinsic_load_instance_id:
895             uses_vs_instance_id = true;
896             break;
897          case nir_intrinsic_load_input: {
898             const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
899             if (s->options->instance_rate_inputs & BITFIELD_BIT(io_sem.location))
900                uses_vs_instance_id = true;
901             else
902                uses_vs_vertex_id = true;
903             break;
904          }
905          case nir_intrinsic_load_tess_coord:
906             uses_tes_u = uses_tes_v = true;
907             break;
908          case nir_intrinsic_load_tess_rel_patch_id_amd:
909             uses_tes_rel_patch_id = true;
910             break;
911          case nir_intrinsic_load_primitive_id:
912             if (shader->info.stage == MESA_SHADER_TESS_EVAL)
913                uses_tes_patch_id = true;
914             break;
915          default:
916             break;
917          }
918       }
919    }
920 
921    cleanup_culling_shader_after_dce_done:
922 
923    if (shader->info.stage == MESA_SHADER_VERTEX) {
924       if (!uses_vs_vertex_id)
925          progress |= remove_compacted_arg(s, &b, 0);
926       if (!uses_vs_instance_id)
927          progress |= remove_compacted_arg(s, &b, 1);
928    } else if (shader->info.stage == MESA_SHADER_TESS_EVAL) {
929       if (!uses_tes_u)
930          progress |= remove_compacted_arg(s, &b, 0);
931       if (!uses_tes_v)
932          progress |= remove_compacted_arg(s, &b, 1);
933       if (!uses_tes_rel_patch_id)
934          progress |= remove_compacted_arg(s, &b, 3);
935       if (!uses_tes_patch_id)
936          progress |= remove_compacted_arg(s, &b, 2);
937    }
938 
939    return progress;
940 }
941 
942 /**
943  * Perform vertex compaction after culling.
944  *
945  * 1. Repack surviving ES invocations (this determines which lane will export which vertex)
946  * 2. Surviving ES vertex invocations store their data to LDS
947  * 3. Emit GS_ALLOC_REQ
948  * 4. Repacked invocations load the vertex data from LDS
949  * 5. GS threads update their vertex indices
950  */
951 static void
compact_vertices_after_culling(nir_builder * b,lower_ngg_nogs_state * s,nir_variable ** repacked_variables,nir_variable ** gs_vtxaddr_vars,nir_def * invocation_index,nir_def * es_vertex_lds_addr,nir_def * es_exporter_tid,nir_def * num_live_vertices_in_workgroup,unsigned pervertex_lds_bytes,unsigned num_repacked_variables)952 compact_vertices_after_culling(nir_builder *b,
953                                lower_ngg_nogs_state *s,
954                                nir_variable **repacked_variables,
955                                nir_variable **gs_vtxaddr_vars,
956                                nir_def *invocation_index,
957                                nir_def *es_vertex_lds_addr,
958                                nir_def *es_exporter_tid,
959                                nir_def *num_live_vertices_in_workgroup,
960                                unsigned pervertex_lds_bytes,
961                                unsigned num_repacked_variables)
962 {
963    nir_variable *es_accepted_var = s->es_accepted_var;
964    nir_variable *gs_accepted_var = s->gs_accepted_var;
965    nir_variable *position_value_var = s->position_value_var;
966    nir_variable *prim_exp_arg_var = s->prim_exp_arg_var;
967 
968    nir_if *if_es_accepted = nir_push_if(b, nir_load_var(b, es_accepted_var));
969    {
970       nir_def *exporter_addr = pervertex_lds_addr(b, es_exporter_tid, pervertex_lds_bytes);
971 
972       /* Store the exporter thread's index to the LDS space of the current thread so GS threads can load it */
973       nir_store_shared(b, nir_u2u8(b, es_exporter_tid), es_vertex_lds_addr, .base = lds_es_exporter_tid);
974 
975       /* Store the current thread's position output to the exporter thread's LDS space */
976       nir_def *pos = nir_load_var(b, position_value_var);
977       nir_store_shared(b, pos, exporter_addr, .base = lds_es_pos_x);
978 
979       /* Store the current thread's repackable arguments to the exporter thread's LDS space */
980       for (unsigned i = 0; i < num_repacked_variables; ++i) {
981          nir_def *arg_val = nir_load_var(b, repacked_variables[i]);
982          nir_intrinsic_instr *store = nir_store_shared(b, arg_val, exporter_addr, .base = lds_es_arg_0 + 4u * i);
983 
984          s->compact_arg_stores[i] = &store->instr;
985       }
986 
987       /* TES rel patch id does not cost extra dword */
988       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
989          nir_def *arg_val = nir_load_var(b, s->repacked_rel_patch_id);
990          nir_intrinsic_instr *store =
991             nir_store_shared(b, nir_u2u8(b, arg_val), exporter_addr,
992                              .base = lds_es_tes_rel_patch_id);
993 
994          s->compact_arg_stores[3] = &store->instr;
995       }
996    }
997    nir_pop_if(b, if_es_accepted);
998 
999    /* TODO: Consider adding a shortcut exit.
1000     * Waves that have no vertices and primitives left can s_endpgm right here.
1001     */
1002 
1003    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1004                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1005 
1006    nir_def *es_survived = nir_ilt(b, invocation_index, num_live_vertices_in_workgroup);
1007    nir_if *if_packed_es_thread = nir_push_if(b, es_survived);
1008    {
1009       /* Read position from the current ES thread's LDS space (written by the exported vertex's ES thread) */
1010       nir_def *exported_pos = nir_load_shared(b, 4, 32, es_vertex_lds_addr, .base = lds_es_pos_x);
1011       nir_store_var(b, position_value_var, exported_pos, 0xfu);
1012 
1013       /* Read the repacked arguments */
1014       for (unsigned i = 0; i < num_repacked_variables; ++i) {
1015          nir_def *arg_val = nir_load_shared(b, 1, 32, es_vertex_lds_addr, .base = lds_es_arg_0 + 4u * i);
1016          nir_store_var(b, repacked_variables[i], arg_val, 0x1u);
1017       }
1018 
1019       if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1020          nir_def *arg_val = nir_load_shared(b, 1, 8, es_vertex_lds_addr,
1021                                                 .base = lds_es_tes_rel_patch_id);
1022          nir_store_var(b, s->repacked_rel_patch_id, nir_u2u32(b, arg_val), 0x1u);
1023       }
1024    }
1025    nir_push_else(b, if_packed_es_thread);
1026    {
1027       nir_store_var(b, position_value_var, nir_undef(b, 4, 32), 0xfu);
1028       for (unsigned i = 0; i < num_repacked_variables; ++i)
1029          nir_store_var(b, repacked_variables[i], nir_undef(b, 1, 32), 0x1u);
1030    }
1031    nir_pop_if(b, if_packed_es_thread);
1032 
1033    nir_if *if_gs_accepted = nir_push_if(b, nir_load_var(b, gs_accepted_var));
1034    {
1035       nir_def *exporter_vtx_indices[3] = {0};
1036 
1037       /* Load the index of the ES threads that will export the current GS thread's vertices */
1038       for (unsigned v = 0; v < s->options->num_vertices_per_primitive; ++v) {
1039          nir_def *vtx_addr = nir_load_var(b, gs_vtxaddr_vars[v]);
1040          nir_def *exporter_vtx_idx = nir_load_shared(b, 1, 8, vtx_addr, .base = lds_es_exporter_tid);
1041          exporter_vtx_indices[v] = nir_u2u32(b, exporter_vtx_idx);
1042          nir_store_var(b, s->gs_vtx_indices_vars[v], exporter_vtx_indices[v], 0x1);
1043       }
1044 
1045       nir_def *prim_exp_arg =
1046          emit_pack_ngg_prim_exp_arg(b, s->options->num_vertices_per_primitive,
1047                                     exporter_vtx_indices, NULL, s->options->gfx_level);
1048       nir_store_var(b, prim_exp_arg_var, prim_exp_arg, 0x1u);
1049    }
1050    nir_pop_if(b, if_gs_accepted);
1051 
1052    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
1053 }
1054 
1055 static void
analyze_shader_before_culling_walk(nir_def * ssa,uint8_t flag,lower_ngg_nogs_state * s)1056 analyze_shader_before_culling_walk(nir_def *ssa,
1057                                    uint8_t flag,
1058                                    lower_ngg_nogs_state *s)
1059 {
1060    nir_instr *instr = ssa->parent_instr;
1061    uint8_t old_pass_flags = instr->pass_flags;
1062    instr->pass_flags |= flag;
1063 
1064    if (instr->pass_flags == old_pass_flags)
1065       return; /* Already visited. */
1066 
1067    switch (instr->type) {
1068    case nir_instr_type_intrinsic: {
1069       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1070 
1071       /* VS input loads and SSBO loads are actually VRAM reads on AMD HW. */
1072       switch (intrin->intrinsic) {
1073       case nir_intrinsic_load_input: {
1074          nir_io_semantics in_io_sem = nir_intrinsic_io_semantics(intrin);
1075          uint64_t in_mask = UINT64_C(1) << (uint64_t) in_io_sem.location;
1076          if (instr->pass_flags & nggc_passflag_used_by_pos)
1077             s->inputs_needed_by_pos |= in_mask;
1078          else if (instr->pass_flags & nggc_passflag_used_by_other)
1079             s->inputs_needed_by_others |= in_mask;
1080          break;
1081       }
1082       default:
1083          break;
1084       }
1085 
1086       break;
1087    }
1088    case nir_instr_type_alu: {
1089       nir_alu_instr *alu = nir_instr_as_alu(instr);
1090       unsigned num_srcs = nir_op_infos[alu->op].num_inputs;
1091 
1092       for (unsigned i = 0; i < num_srcs; ++i) {
1093          analyze_shader_before_culling_walk(alu->src[i].src.ssa, flag, s);
1094       }
1095 
1096       break;
1097    }
1098    case nir_instr_type_tex: {
1099       nir_tex_instr *tex = nir_instr_as_tex(instr);
1100       unsigned num_srcs = tex->num_srcs;
1101 
1102       for (unsigned i = 0; i < num_srcs; ++i) {
1103          analyze_shader_before_culling_walk(tex->src[i].src.ssa, flag, s);
1104       }
1105 
1106       break;
1107    }
1108    case nir_instr_type_phi: {
1109       nir_phi_instr *phi = nir_instr_as_phi(instr);
1110       nir_foreach_phi_src_safe(phi_src, phi) {
1111          analyze_shader_before_culling_walk(phi_src->src.ssa, flag, s);
1112       }
1113 
1114       break;
1115    }
1116    default:
1117       break;
1118    }
1119 }
1120 
1121 static void
analyze_shader_before_culling(nir_shader * shader,lower_ngg_nogs_state * s)1122 analyze_shader_before_culling(nir_shader *shader, lower_ngg_nogs_state *s)
1123 {
1124    /* LCSSA is needed to get correct results from divergence analysis. */
1125    nir_convert_to_lcssa(shader, true, true);
1126    /* We need divergence info for culling shaders. */
1127    nir_divergence_analysis(shader);
1128 
1129    nir_foreach_function_impl(impl, shader) {
1130       nir_foreach_block(block, impl) {
1131          nir_foreach_instr(instr, block) {
1132             instr->pass_flags = 0;
1133 
1134             if (instr->type != nir_instr_type_intrinsic)
1135                continue;
1136 
1137             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1138             if (intrin->intrinsic != nir_intrinsic_store_output)
1139                continue;
1140 
1141             nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
1142             nir_def *store_val = intrin->src[0].ssa;
1143             uint8_t flag = io_sem.location == VARYING_SLOT_POS ? nggc_passflag_used_by_pos : nggc_passflag_used_by_other;
1144             analyze_shader_before_culling_walk(store_val, flag, s);
1145          }
1146       }
1147    }
1148 }
1149 
1150 static nir_def *
find_reusable_ssa_def(nir_instr * instr)1151 find_reusable_ssa_def(nir_instr *instr)
1152 {
1153    /* Find instructions whose SSA definitions are used by both
1154     * the top and bottom parts of the shader (before and after culling).
1155     * Only in this case, it makes sense for the bottom part
1156     * to try to reuse these from the top part.
1157     */
1158    if ((instr->pass_flags & nggc_passflag_used_by_both) != nggc_passflag_used_by_both)
1159       return NULL;
1160 
1161    switch (instr->type) {
1162    case nir_instr_type_alu: {
1163       nir_alu_instr *alu = nir_instr_as_alu(instr);
1164       if (alu->def.divergent)
1165          return NULL;
1166       /* Ignore uniform floats because they regress VGPR usage too much */
1167       if (nir_op_infos[alu->op].output_type & nir_type_float)
1168          return NULL;
1169       return &alu->def;
1170    }
1171    case nir_instr_type_intrinsic: {
1172       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1173       if (!nir_intrinsic_can_reorder(intrin) ||
1174             !nir_intrinsic_infos[intrin->intrinsic].has_dest ||
1175             intrin->def.divergent)
1176          return NULL;
1177       return &intrin->def;
1178    }
1179    case nir_instr_type_phi: {
1180       nir_phi_instr *phi = nir_instr_as_phi(instr);
1181       if (phi->def.divergent)
1182          return NULL;
1183       return &phi->def;
1184    }
1185    default:
1186       return NULL;
1187    }
1188 }
1189 
1190 static const struct glsl_type *
glsl_uint_type_for_ssa(nir_def * ssa)1191 glsl_uint_type_for_ssa(nir_def *ssa)
1192 {
1193    enum glsl_base_type base_type = GLSL_TYPE_UINT;
1194    switch (ssa->bit_size) {
1195    case 8: base_type = GLSL_TYPE_UINT8; break;
1196    case 16: base_type = GLSL_TYPE_UINT16; break;
1197    case 32: base_type = GLSL_TYPE_UINT; break;
1198    case 64: base_type = GLSL_TYPE_UINT64; break;
1199    default: return NULL;
1200    }
1201 
1202    return ssa->num_components == 1
1203           ? glsl_scalar_type(base_type)
1204           : glsl_vector_type(base_type, ssa->num_components);
1205 }
1206 
1207 /**
1208  * Save the reusable SSA definitions to variables so that the
1209  * bottom shader part can reuse them from the top part.
1210  *
1211  * 1. We create a new function temporary variable for reusables,
1212  *    and insert a store+load.
1213  * 2. The shader is cloned (the top part is created), then the
1214  *    control flow is reinserted (for the bottom part.)
1215  * 3. For reusables, we delete the variable stores from the
1216  *    bottom part. This will make them use the variables from
1217  *    the top part and DCE the redundant instructions.
1218  */
1219 static void
save_reusable_variables(nir_builder * b,lower_ngg_nogs_state * s)1220 save_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1221 {
1222    ASSERTED int vec_ok = u_vector_init(&s->reusable_nondeferred_variables, 4, sizeof(reusable_nondeferred_variable));
1223    assert(vec_ok);
1224 
1225    /* Upper limit on reusable uniforms in order to reduce SGPR spilling. */
1226    unsigned remaining_reusable_uniforms = 48;
1227 
1228    nir_block *block = nir_start_block(b->impl);
1229    while (block) {
1230       /* Process the instructions in the current block. */
1231       nir_foreach_instr_safe(instr, block) {
1232          /* Determine if we can reuse the current SSA value.
1233           * When vertex compaction is used, it is possible that the same shader invocation
1234           * processes a different vertex in the top and bottom part of the shader.
1235           * Therefore, we only reuse uniform values.
1236           */
1237          nir_def *ssa = find_reusable_ssa_def(instr);
1238          if (!ssa)
1239             continue;
1240 
1241          /* Determine a suitable type for the SSA value. */
1242          const struct glsl_type *t = glsl_uint_type_for_ssa(ssa);
1243          if (!t)
1244             continue;
1245 
1246          if (!ssa->divergent) {
1247             if (remaining_reusable_uniforms < ssa->num_components)
1248                continue;
1249 
1250             remaining_reusable_uniforms -= ssa->num_components;
1251          }
1252 
1253          reusable_nondeferred_variable *saved = (reusable_nondeferred_variable *) u_vector_add(&s->reusable_nondeferred_variables);
1254          assert(saved);
1255 
1256          /* Create a new NIR variable where we store the reusable value.
1257           * Then, we reload the variable and replace the uses of the value
1258           * with the reloaded variable.
1259           */
1260          saved->var = nir_local_variable_create(b->impl, t, NULL);
1261          saved->ssa = ssa;
1262 
1263          b->cursor = instr->type == nir_instr_type_phi
1264                      ? nir_after_instr_and_phis(instr)
1265                      : nir_after_instr(instr);
1266          nir_store_var(b, saved->var, saved->ssa, BITFIELD_MASK(ssa->num_components));
1267          nir_def *reloaded = nir_load_var(b, saved->var);
1268          nir_def_rewrite_uses_after(ssa, reloaded, reloaded->parent_instr);
1269       }
1270 
1271       /* Look at the next CF node. */
1272       nir_cf_node *next_cf_node = nir_cf_node_next(&block->cf_node);
1273       if (next_cf_node) {
1274          /* It makes no sense to try to reuse things from within loops. */
1275          bool next_is_loop = next_cf_node->type == nir_cf_node_loop;
1276 
1277          /* Don't reuse if we're in divergent control flow.
1278           *
1279           * Thanks to vertex repacking, the same shader invocation may process a different vertex
1280           * in the top and bottom part, and it's even possible that this different vertex was initially
1281           * processed in a different wave. So the two parts may take a different divergent code path.
1282           * Therefore, these variables in divergent control flow may stay undefined.
1283           *
1284           * Note that this problem doesn't exist if vertices are not repacked or if the
1285           * workgroup only has a single wave.
1286           */
1287          bool next_is_divergent_if =
1288             next_cf_node->type == nir_cf_node_if &&
1289             nir_cf_node_as_if(next_cf_node)->condition.ssa->divergent;
1290 
1291          if (next_is_loop || next_is_divergent_if) {
1292             block = nir_cf_node_cf_tree_next(next_cf_node);
1293             continue;
1294          }
1295       }
1296 
1297       /* Go to the next block. */
1298       block = nir_block_cf_tree_next(block);
1299    }
1300 }
1301 
1302 /**
1303  * Reuses suitable variables from the top part of the shader,
1304  * by deleting their stores from the bottom part.
1305  */
1306 static void
apply_reusable_variables(nir_builder * b,lower_ngg_nogs_state * s)1307 apply_reusable_variables(nir_builder *b, lower_ngg_nogs_state *s)
1308 {
1309    if (!u_vector_length(&s->reusable_nondeferred_variables)) {
1310       u_vector_finish(&s->reusable_nondeferred_variables);
1311       return;
1312    }
1313 
1314    nir_foreach_block_reverse_safe(block, b->impl) {
1315       nir_foreach_instr_reverse_safe(instr, block) {
1316          if (instr->type != nir_instr_type_intrinsic)
1317             continue;
1318          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
1319 
1320          /* When we found any of these intrinsics, it means
1321           * we reached the top part and we must stop.
1322           */
1323          if (intrin->intrinsic == nir_intrinsic_sendmsg_amd)
1324             goto done;
1325 
1326          if (intrin->intrinsic != nir_intrinsic_store_deref)
1327             continue;
1328          nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
1329          if (deref->deref_type != nir_deref_type_var)
1330             continue;
1331 
1332          reusable_nondeferred_variable *saved;
1333          u_vector_foreach(saved, &s->reusable_nondeferred_variables) {
1334             if (saved->var == deref->var) {
1335                nir_instr_remove(instr);
1336             }
1337          }
1338       }
1339    }
1340 
1341    done:
1342    u_vector_finish(&s->reusable_nondeferred_variables);
1343 }
1344 
1345 static void
cull_primitive_accepted(nir_builder * b,void * state)1346 cull_primitive_accepted(nir_builder *b, void *state)
1347 {
1348    lower_ngg_nogs_state *s = (lower_ngg_nogs_state *)state;
1349 
1350    nir_store_var(b, s->gs_accepted_var, nir_imm_true(b), 0x1u);
1351 
1352    /* Store the accepted state to LDS for ES threads */
1353    for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx)
1354       nir_store_shared(b, nir_imm_intN_t(b, 1, 8), s->vtx_addr[vtx], .base = lds_es_vertex_accepted);
1355 }
1356 
1357 static void
clipdist_culling_es_part(nir_builder * b,lower_ngg_nogs_state * s,nir_def * es_vertex_lds_addr)1358 clipdist_culling_es_part(nir_builder *b, lower_ngg_nogs_state *s,
1359                          nir_def *es_vertex_lds_addr)
1360 {
1361    /* no gl_ClipDistance used but we have user defined clip plane */
1362    if (s->options->user_clip_plane_enable_mask && !s->has_clipdist) {
1363       /* use gl_ClipVertex if defined */
1364       nir_variable *clip_vertex_var =
1365          b->shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CLIP_VERTEX) ?
1366          s->clip_vertex_var : s->position_value_var;
1367       nir_def *clip_vertex = nir_load_var(b, clip_vertex_var);
1368 
1369       /* clip against user defined clip planes */
1370       for (unsigned i = 0; i < 8; i++) {
1371          if (!(s->options->user_clip_plane_enable_mask & BITFIELD_BIT(i)))
1372             continue;
1373 
1374          nir_def *plane = nir_load_user_clip_plane(b, .ucp_id = i);
1375          nir_def *dist = nir_fdot(b, clip_vertex, plane);
1376          add_clipdist_bit(b, dist, i, s->clipdist_neg_mask_var);
1377       }
1378 
1379       s->has_clipdist = true;
1380    }
1381 
1382    /* store clipdist_neg_mask to LDS for culling latter in gs thread */
1383    if (s->has_clipdist) {
1384       nir_def *mask = nir_load_var(b, s->clipdist_neg_mask_var);
1385       nir_store_shared(b, nir_u2u8(b, mask), es_vertex_lds_addr,
1386                        .base = lds_es_clipdist_neg_mask);
1387    }
1388 }
1389 
1390 static unsigned
ngg_nogs_get_culling_pervertex_lds_size(gl_shader_stage stage,bool uses_instance_id,bool uses_primitive_id,unsigned * num_repacked_variables)1391 ngg_nogs_get_culling_pervertex_lds_size(gl_shader_stage stage,
1392                                         bool uses_instance_id,
1393                                         bool uses_primitive_id,
1394                                         unsigned *num_repacked_variables)
1395 {
1396    /* Culling shaders must repack some variables because
1397     * the same shader invocation may process different vertices
1398     * before and after the culling algorithm.
1399     */
1400 
1401    unsigned num_repacked;
1402    if (stage == MESA_SHADER_VERTEX) {
1403       /* Vertex shaders repack:
1404        * - Vertex ID
1405        * - Instance ID (only if used)
1406        */
1407       num_repacked = uses_instance_id ? 2 : 1;
1408    } else {
1409       /* Tess eval shaders repack:
1410        * - U, V coordinates
1411        * - primitive ID (aka. patch id, only if used)
1412        * - relative patch id (not included here because doesn't need a dword)
1413        */
1414       assert(stage == MESA_SHADER_TESS_EVAL);
1415       num_repacked = uses_primitive_id ? 3 : 2;
1416    }
1417 
1418    if (num_repacked_variables)
1419       *num_repacked_variables = num_repacked;
1420 
1421    /* one odd dword to reduce LDS bank conflict */
1422    return (lds_es_arg_0 + num_repacked * 4u) | 4u;
1423 }
1424 
1425 static void
add_deferred_attribute_culling(nir_builder * b,nir_cf_list * original_extracted_cf,lower_ngg_nogs_state * s)1426 add_deferred_attribute_culling(nir_builder *b, nir_cf_list *original_extracted_cf, lower_ngg_nogs_state *s)
1427 {
1428    bool uses_instance_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_INSTANCE_ID);
1429    bool uses_tess_primitive_id = BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_PRIMITIVE_ID);
1430 
1431    unsigned num_repacked_variables;
1432    unsigned pervertex_lds_bytes =
1433       ngg_nogs_get_culling_pervertex_lds_size(b->shader->info.stage,
1434                                               uses_instance_id,
1435                                               uses_tess_primitive_id,
1436                                               &num_repacked_variables);
1437 
1438    nir_function_impl *impl = nir_shader_get_entrypoint(b->shader);
1439 
1440    /* Create some helper variables. */
1441    nir_variable *gs_vtxaddr_vars[3] = {
1442       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx0_addr"),
1443       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx1_addr"),
1444       nir_local_variable_create(impl, glsl_uint_type(), "gs_vtx2_addr"),
1445    };
1446 
1447    nir_variable *repacked_variables[3] = {
1448       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_0"),
1449       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_1"),
1450       nir_local_variable_create(impl, glsl_uint_type(), "repacked_var_2"),
1451    };
1452 
1453    /* Relative patch ID is a special case because it doesn't need an extra dword, repack separately. */
1454    s->repacked_rel_patch_id = nir_local_variable_create(impl, glsl_uint_type(), "repacked_rel_patch_id");
1455 
1456    if (s->options->clip_cull_dist_mask ||
1457        s->options->user_clip_plane_enable_mask) {
1458       s->clip_vertex_var =
1459          nir_local_variable_create(impl, glsl_vec4_type(), "clip_vertex");
1460       s->clipdist_neg_mask_var =
1461          nir_local_variable_create(impl, glsl_uint_type(), "clipdist_neg_mask");
1462 
1463       /* init mask to 0 */
1464       nir_store_var(b, s->clipdist_neg_mask_var, nir_imm_int(b, 0), 1);
1465    }
1466 
1467    /* Top part of the culling shader (aka. position shader part)
1468     *
1469     * We clone the full ES shader and emit it here, but we only really care
1470     * about its position output, so we delete every other output from this part.
1471     * The position output is stored into a temporary variable, and reloaded later.
1472     */
1473 
1474    nir_def *es_thread = has_input_vertex(b);
1475    nir_if *if_es_thread = nir_push_if(b, es_thread);
1476    {
1477       /* Initialize the position output variable to zeroes, in case not all VS/TES invocations store the output.
1478        * The spec doesn't require it, but we use (0, 0, 0, 1) because some games rely on that.
1479        */
1480       nir_store_var(b, s->position_value_var, nir_imm_vec4(b, 0.0f, 0.0f, 0.0f, 1.0f), 0xfu);
1481 
1482       /* Now reinsert a clone of the shader code */
1483       struct hash_table *remap_table = _mesa_pointer_hash_table_create(NULL);
1484       nir_cf_list_clone_and_reinsert(original_extracted_cf, &if_es_thread->cf_node, b->cursor, remap_table);
1485       _mesa_hash_table_destroy(remap_table, NULL);
1486       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
1487 
1488       /* Remember the current thread's shader arguments */
1489       if (b->shader->info.stage == MESA_SHADER_VERTEX) {
1490          nir_store_var(b, repacked_variables[0], nir_load_vertex_id_zero_base(b), 0x1u);
1491          if (uses_instance_id)
1492             nir_store_var(b, repacked_variables[1], nir_load_instance_id(b), 0x1u);
1493       } else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL) {
1494          nir_store_var(b, s->repacked_rel_patch_id, nir_load_tess_rel_patch_id_amd(b), 0x1u);
1495          nir_def *tess_coord = nir_load_tess_coord(b);
1496          nir_store_var(b, repacked_variables[0], nir_channel(b, tess_coord, 0), 0x1u);
1497          nir_store_var(b, repacked_variables[1], nir_channel(b, tess_coord, 1), 0x1u);
1498          if (uses_tess_primitive_id)
1499             nir_store_var(b, repacked_variables[2], nir_load_primitive_id(b), 0x1u);
1500       } else {
1501          unreachable("Should be VS or TES.");
1502       }
1503    }
1504    nir_pop_if(b, if_es_thread);
1505 
1506    nir_store_var(b, s->es_accepted_var, es_thread, 0x1u);
1507    nir_def *gs_thread = has_input_primitive(b);
1508    nir_store_var(b, s->gs_accepted_var, gs_thread, 0x1u);
1509 
1510    /* Remove all non-position outputs, and put the position output into the variable. */
1511    nir_metadata_preserve(impl, nir_metadata_none);
1512    remove_culling_shader_outputs(b->shader, s);
1513    b->cursor = nir_after_impl(impl);
1514 
1515    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
1516 
1517    /* Run culling algorithms if culling is enabled.
1518     *
1519     * NGG culling can be enabled or disabled in runtime.
1520     * This is determined by a SGPR shader argument which is accessed
1521     * by the following NIR intrinsic.
1522     */
1523 
1524    nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
1525    {
1526       nir_def *invocation_index = nir_load_local_invocation_index(b);
1527       nir_def *es_vertex_lds_addr = pervertex_lds_addr(b, invocation_index, pervertex_lds_bytes);
1528 
1529       /* ES invocations store their vertex data to LDS for GS threads to read. */
1530       if_es_thread = nir_push_if(b, es_thread);
1531       if_es_thread->control = nir_selection_control_divergent_always_taken;
1532       {
1533          /* Store position components that are relevant to culling in LDS */
1534          nir_def *pre_cull_pos = nir_load_var(b, s->position_value_var);
1535          nir_def *pre_cull_w = nir_channel(b, pre_cull_pos, 3);
1536          nir_store_shared(b, pre_cull_w, es_vertex_lds_addr, .base = lds_es_pos_w);
1537          nir_def *pre_cull_x_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 0), pre_cull_w);
1538          nir_def *pre_cull_y_div_w = nir_fdiv(b, nir_channel(b, pre_cull_pos, 1), pre_cull_w);
1539          nir_store_shared(b, nir_vec2(b, pre_cull_x_div_w, pre_cull_y_div_w), es_vertex_lds_addr, .base = lds_es_pos_x);
1540 
1541          /* Clear out the ES accepted flag in LDS */
1542          nir_store_shared(b, nir_imm_zero(b, 1, 8), es_vertex_lds_addr, .align_mul = 4, .base = lds_es_vertex_accepted);
1543 
1544          /* For clipdist culling */
1545          clipdist_culling_es_part(b, s, es_vertex_lds_addr);
1546       }
1547       nir_pop_if(b, if_es_thread);
1548 
1549       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1550                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1551 
1552       nir_store_var(b, s->gs_accepted_var, nir_imm_false(b), 0x1u);
1553       nir_store_var(b, s->prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
1554 
1555       /* GS invocations load the vertex data and perform the culling. */
1556       nir_if *if_gs_thread = nir_push_if(b, gs_thread);
1557       {
1558          /* Load vertex indices from input VGPRs */
1559          nir_def *vtx_idx[3] = {0};
1560          for (unsigned vertex = 0; vertex < s->options->num_vertices_per_primitive;
1561               ++vertex)
1562             vtx_idx[vertex] = nir_load_var(b, s->gs_vtx_indices_vars[vertex]);
1563 
1564          nir_def *pos[3][4] = {0};
1565 
1566          /* Load W positions of vertices first because the culling code will use these first */
1567          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1568             s->vtx_addr[vtx] = pervertex_lds_addr(b, vtx_idx[vtx], pervertex_lds_bytes);
1569             pos[vtx][3] = nir_load_shared(b, 1, 32, s->vtx_addr[vtx], .base = lds_es_pos_w);
1570             nir_store_var(b, gs_vtxaddr_vars[vtx], s->vtx_addr[vtx], 0x1u);
1571          }
1572 
1573          /* Load the X/W, Y/W positions of vertices */
1574          for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1575             nir_def *xy = nir_load_shared(b, 2, 32, s->vtx_addr[vtx], .base = lds_es_pos_x);
1576             pos[vtx][0] = nir_channel(b, xy, 0);
1577             pos[vtx][1] = nir_channel(b, xy, 1);
1578          }
1579 
1580          nir_def *accepted_by_clipdist;
1581          if (s->has_clipdist) {
1582             nir_def *clipdist_neg_mask = nir_imm_intN_t(b, 0xff, 8);
1583             for (unsigned vtx = 0; vtx < s->options->num_vertices_per_primitive; ++vtx) {
1584                nir_def *mask =
1585                   nir_load_shared(b, 1, 8, s->vtx_addr[vtx],
1586                                   .base = lds_es_clipdist_neg_mask);
1587                clipdist_neg_mask = nir_iand(b, clipdist_neg_mask, mask);
1588             }
1589             /* primitive is culled if any plane's clipdist of all vertices are negative */
1590             accepted_by_clipdist = nir_ieq_imm(b, clipdist_neg_mask, 0);
1591          } else {
1592             accepted_by_clipdist = nir_imm_true(b);
1593          }
1594 
1595          /* See if the current primitive is accepted */
1596          ac_nir_cull_primitive(b, accepted_by_clipdist, pos,
1597                                s->options->num_vertices_per_primitive,
1598                                cull_primitive_accepted, s);
1599       }
1600       nir_pop_if(b, if_gs_thread);
1601 
1602       nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
1603                             .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
1604 
1605       nir_store_var(b, s->es_accepted_var, nir_imm_false(b), 0x1u);
1606 
1607       /* ES invocations load their accepted flag from LDS. */
1608       if_es_thread = nir_push_if(b, es_thread);
1609       if_es_thread->control = nir_selection_control_divergent_always_taken;
1610       {
1611          nir_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, .base = lds_es_vertex_accepted, .align_mul = 4u);
1612          nir_def *accepted_bool = nir_ine_imm(b, nir_u2u32(b, accepted), 0);
1613          nir_store_var(b, s->es_accepted_var, accepted_bool, 0x1u);
1614       }
1615       nir_pop_if(b, if_es_thread);
1616 
1617       nir_def *es_accepted = nir_load_var(b, s->es_accepted_var);
1618 
1619       /* Repack the vertices that survived the culling. */
1620       wg_repack_result rep = repack_invocations_in_workgroup(b, es_accepted, lds_scratch_base,
1621                                                              s->max_num_waves,
1622                                                              s->options->wave_size);
1623       nir_def *num_live_vertices_in_workgroup = rep.num_repacked_invocations;
1624       nir_def *es_exporter_tid = rep.repacked_invocation_index;
1625 
1626       /* If all vertices are culled, set primitive count to 0 as well. */
1627       nir_def *num_exported_prims = nir_load_workgroup_num_input_primitives_amd(b);
1628       nir_def *fully_culled = nir_ieq_imm(b, num_live_vertices_in_workgroup, 0u);
1629       num_exported_prims = nir_bcsel(b, fully_culled, nir_imm_int(b, 0u), num_exported_prims);
1630       nir_store_var(b, s->gs_exported_var, nir_iand(b, nir_inot(b, fully_culled), has_input_primitive(b)), 0x1u);
1631 
1632       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1633       {
1634          /* Tell the final vertex and primitive count to the HW. */
1635          if (s->options->gfx_level == GFX10) {
1636             alloc_vertices_and_primitives_gfx10_workaround(
1637                b, num_live_vertices_in_workgroup, num_exported_prims);
1638          } else {
1639             alloc_vertices_and_primitives(
1640                b, num_live_vertices_in_workgroup, num_exported_prims);
1641          }
1642       }
1643       nir_pop_if(b, if_wave_0);
1644 
1645       /* Vertex compaction. */
1646       compact_vertices_after_culling(b, s,
1647                                      repacked_variables, gs_vtxaddr_vars,
1648                                      invocation_index, es_vertex_lds_addr,
1649                                      es_exporter_tid, num_live_vertices_in_workgroup,
1650                                      pervertex_lds_bytes, num_repacked_variables);
1651    }
1652    nir_push_else(b, if_cull_en);
1653    {
1654       /* When culling is disabled, we do the same as we would without culling. */
1655       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
1656       {
1657          nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
1658          nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
1659          alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
1660       }
1661       nir_pop_if(b, if_wave_0);
1662       nir_store_var(b, s->prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, s), 0x1u);
1663    }
1664    nir_pop_if(b, if_cull_en);
1665 
1666    /* Update shader arguments.
1667     *
1668     * The registers which hold information about the subgroup's
1669     * vertices and primitives are updated here, so the rest of the shader
1670     * doesn't need to worry about the culling.
1671     *
1672     * These "overwrite" intrinsics must be at top level control flow,
1673     * otherwise they can mess up the backend (eg. ACO's SSA).
1674     *
1675     * TODO:
1676     * A cleaner solution would be to simply replace all usages of these args
1677     * with the load of the variables.
1678     * However, this wouldn't work right now because the backend uses the arguments
1679     * for purposes not expressed in NIR, eg. VS input loads, etc.
1680     * This can change if VS input loads and other stuff are lowered to eg. load_buffer_amd.
1681     */
1682 
1683    if (b->shader->info.stage == MESA_SHADER_VERTEX)
1684       s->overwrite_args =
1685          nir_overwrite_vs_arguments_amd(b,
1686             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]));
1687    else if (b->shader->info.stage == MESA_SHADER_TESS_EVAL)
1688       s->overwrite_args =
1689          nir_overwrite_tes_arguments_amd(b,
1690             nir_load_var(b, repacked_variables[0]), nir_load_var(b, repacked_variables[1]),
1691             nir_load_var(b, repacked_variables[2]), nir_load_var(b, s->repacked_rel_patch_id));
1692    else
1693       unreachable("Should be VS or TES.");
1694 }
1695 
1696 static void
ngg_nogs_store_edgeflag_to_lds(nir_builder * b,lower_ngg_nogs_state * s)1697 ngg_nogs_store_edgeflag_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1698 {
1699    if (!s->out.outputs[VARYING_SLOT_EDGE][0])
1700       return;
1701 
1702    /* clamp user edge flag to 1 for latter bit operations */
1703    nir_def *edgeflag = s->out.outputs[VARYING_SLOT_EDGE][0];
1704    edgeflag = nir_umin(b, edgeflag, nir_imm_int(b, 1));
1705 
1706    /* user edge flag is stored at the beginning of a vertex if streamout is not enabled */
1707    unsigned offset = 0;
1708    if (s->streamout_enabled) {
1709       unsigned packed_location =
1710          util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(VARYING_SLOT_EDGE));
1711       offset = packed_location * 16;
1712    }
1713 
1714    nir_def *tid = nir_load_local_invocation_index(b);
1715    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1716 
1717    nir_store_shared(b, edgeflag, addr, .base = offset);
1718 }
1719 
1720 static void
ngg_nogs_store_xfb_outputs_to_lds(nir_builder * b,lower_ngg_nogs_state * s)1721 ngg_nogs_store_xfb_outputs_to_lds(nir_builder *b, lower_ngg_nogs_state *s)
1722 {
1723    nir_xfb_info *info = b->shader->xfb_info;
1724 
1725    uint64_t xfb_outputs = 0;
1726    unsigned xfb_outputs_16bit = 0;
1727    uint8_t xfb_mask[VARYING_SLOT_MAX] = {0};
1728    uint8_t xfb_mask_16bit_lo[16] = {0};
1729    uint8_t xfb_mask_16bit_hi[16] = {0};
1730 
1731    /* Get XFB output mask for each slot. */
1732    for (int i = 0; i < info->output_count; i++) {
1733       nir_xfb_output_info *out = info->outputs + i;
1734 
1735       if (out->location < VARYING_SLOT_VAR0_16BIT) {
1736          xfb_outputs |= BITFIELD64_BIT(out->location);
1737          xfb_mask[out->location] |= out->component_mask;
1738       } else {
1739          unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
1740          xfb_outputs_16bit |= BITFIELD_BIT(index);
1741 
1742          if (out->high_16bits)
1743             xfb_mask_16bit_hi[index] |= out->component_mask;
1744          else
1745             xfb_mask_16bit_lo[index] |= out->component_mask;
1746       }
1747    }
1748 
1749    nir_def *tid = nir_load_local_invocation_index(b);
1750    nir_def *addr = pervertex_lds_addr(b, tid, s->pervertex_lds_bytes);
1751 
1752    u_foreach_bit64(slot, xfb_outputs) {
1753       uint64_t outputs_written = b->shader->info.outputs_written;
1754       if (s->skip_primitive_id)
1755          outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
1756       unsigned packed_location =
1757          util_bitcount64(outputs_written & BITFIELD64_MASK(slot));
1758 
1759       unsigned mask = xfb_mask[slot];
1760 
1761       /* Clear unused components. */
1762       for (unsigned i = 0; i < 4; i++) {
1763          if (!s->out.outputs[slot][i])
1764             mask &= ~BITFIELD_BIT(i);
1765       }
1766 
1767       while (mask) {
1768          int start, count;
1769          u_bit_scan_consecutive_range(&mask, &start, &count);
1770          /* Outputs here are sure to be 32bit.
1771           *
1772           * 64bit outputs have been lowered to two 32bit. As 16bit outputs:
1773           *   Vulkan does not allow streamout outputs less than 32bit.
1774           *   OpenGL puts 16bit outputs in VARYING_SLOT_VAR0_16BIT.
1775           */
1776          nir_def *store_val = nir_vec(b, &s->out.outputs[slot][start], (unsigned)count);
1777          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1778       }
1779    }
1780 
1781    unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
1782    u_foreach_bit64(slot, xfb_outputs_16bit) {
1783       unsigned packed_location = num_32bit_outputs +
1784          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
1785 
1786       unsigned mask_lo = xfb_mask_16bit_lo[slot];
1787       unsigned mask_hi = xfb_mask_16bit_hi[slot];
1788 
1789       /* Clear unused components. */
1790       for (unsigned i = 0; i < 4; i++) {
1791          if (!s->out.outputs_16bit_lo[slot][i])
1792             mask_lo &= ~BITFIELD_BIT(i);
1793          if (!s->out.outputs_16bit_hi[slot][i])
1794             mask_hi &= ~BITFIELD_BIT(i);
1795       }
1796 
1797       nir_def **outputs_lo = s->out.outputs_16bit_lo[slot];
1798       nir_def **outputs_hi = s->out.outputs_16bit_hi[slot];
1799       nir_def *undef = nir_undef(b, 1, 16);
1800 
1801       unsigned mask = mask_lo | mask_hi;
1802       while (mask) {
1803          int start, count;
1804          u_bit_scan_consecutive_range(&mask, &start, &count);
1805 
1806          nir_def *values[4] = {0};
1807          for (int c = start; c < start + count; ++c) {
1808             nir_def *lo = mask_lo & BITFIELD_BIT(c) ? outputs_lo[c] : undef;
1809             nir_def *hi = mask_hi & BITFIELD_BIT(c) ? outputs_hi[c] : undef;
1810 
1811             /* extend 8/16 bit to 32 bit, 64 bit has been lowered */
1812             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
1813          }
1814 
1815          nir_def *store_val = nir_vec(b, values, (unsigned)count);
1816          nir_store_shared(b, store_val, addr, .base = packed_location * 16 + start * 4);
1817       }
1818    }
1819 }
1820 
1821 static nir_def *
write_values_to_lanes(nir_builder * b,nir_def ** values,unsigned lane_mask)1822 write_values_to_lanes(nir_builder *b, nir_def **values, unsigned lane_mask)
1823 {
1824    nir_def *lanes = nir_imm_int(b, 0);
1825 
1826    u_foreach_bit(i, lane_mask) {
1827       lanes = nir_write_invocation_amd(b, lanes, values[i], nir_imm_int(b, i));
1828    }
1829    return lanes;
1830 }
1831 
1832 static void
ngg_build_streamout_buffer_info(nir_builder * b,nir_xfb_info * info,enum amd_gfx_level gfx_level,bool has_xfb_prim_query,bool use_gfx12_xfb_intrinsic,nir_def * scratch_base,nir_def * tid_in_tg,nir_def * gen_prim[4],nir_def * prim_stride_ret[4],nir_def * so_buffer_ret[4],nir_def * buffer_offsets_ret[4],nir_def * emit_prim_ret[4])1833 ngg_build_streamout_buffer_info(nir_builder *b,
1834                                 nir_xfb_info *info,
1835                                 enum amd_gfx_level gfx_level,
1836                                 bool has_xfb_prim_query,
1837                                 bool use_gfx12_xfb_intrinsic,
1838                                 nir_def *scratch_base,
1839                                 nir_def *tid_in_tg,
1840                                 nir_def *gen_prim[4],
1841                                 nir_def *prim_stride_ret[4],
1842                                 nir_def *so_buffer_ret[4],
1843                                 nir_def *buffer_offsets_ret[4],
1844                                 nir_def *emit_prim_ret[4])
1845 {
1846    nir_def *undef = nir_undef(b, 1, 32);
1847 
1848    /* For radeonsi which pass this value by arg when VS. Streamout need accurate
1849     * num-vert-per-prim for writing correct amount of data to buffer.
1850     */
1851    nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
1852    for (unsigned buffer = 0; buffer < 4; buffer++) {
1853       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
1854          continue;
1855 
1856       assert(info->buffers[buffer].stride);
1857 
1858       prim_stride_ret[buffer] =
1859          nir_imul_imm(b, num_vert_per_prim, info->buffers[buffer].stride);
1860       so_buffer_ret[buffer] = nir_load_streamout_buffer_amd(b, .base = buffer);
1861    }
1862 
1863    nir_if *if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
1864    {
1865       nir_def *workgroup_buffer_sizes[4];
1866       for (unsigned buffer = 0; buffer < 4; buffer++) {
1867          if (info->buffers_written & BITFIELD_BIT(buffer)) {
1868             nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
1869             /* In radeonsi, we may not know if a feedback buffer has been bound when
1870              * compile time, so have to check buffer size in runtime to disable the
1871              * GDS update for unbind buffer to prevent the case that previous draw
1872              * compiled with streamout but does not bind feedback buffer miss update
1873              * GDS which will affect current draw's streamout.
1874              */
1875             nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
1876             nir_def *inc_buffer_size =
1877                nir_imul(b, gen_prim[info->buffer_to_stream[buffer]], prim_stride_ret[buffer]);
1878             workgroup_buffer_sizes[buffer] =
1879                nir_bcsel(b, buffer_valid, inc_buffer_size, nir_imm_int(b, 0));
1880          } else
1881             workgroup_buffer_sizes[buffer] = undef;
1882       }
1883 
1884       nir_def *buffer_offsets = NULL, *xfb_state_address = NULL, *xfb_voffset = NULL;
1885 
1886       /* Get current global offset of buffer and increase by amount of
1887        * workgroup buffer size. This is an ordered operation sorted by
1888        * ordered_id; Each buffer info is in a channel of a vec4.
1889        */
1890       if (gfx_level >= GFX12) {
1891          nir_pop_if(b, if_invocation_0);
1892 
1893          for (unsigned buffer = 0; buffer < 4; buffer++)
1894             workgroup_buffer_sizes[buffer] = nir_if_phi(b, workgroup_buffer_sizes[buffer], undef);
1895 
1896          /* These must be set after nir_pop_if and phis. */
1897          xfb_state_address = nir_load_xfb_state_address_gfx12_amd(b);
1898          xfb_voffset = nir_imul_imm(b, tid_in_tg, 8);
1899 
1900          nir_if *if_4lanes = nir_push_if(b, nir_ult_imm(b, tid_in_tg, 4));
1901          {
1902             /* Move workgroup buffer sizes from SGPRs to the first 4 lanes. */
1903             nir_def *workgroup_buffer_size_per_lane =
1904                write_values_to_lanes(b, workgroup_buffer_sizes, info->buffers_written);
1905             nir_def *ordered_id = nir_load_ordered_id_amd(b);
1906 
1907             /* The atomic value for the 4 lanes is:
1908              *    lane 0: uvec2(ordered_id, workgroup_buffer_size0)
1909              *    lane 1: uvec2(ordered_id, workgroup_buffer_size1)
1910              *    lane 2: uvec2(ordered_id, workgroup_buffer_size2)
1911              *    lane 3: uvec2(ordered_id, workgroup_buffer_size3)
1912              */
1913             nir_def *atomic_src = nir_pack_64_2x32_split(b, ordered_id,
1914                                                          workgroup_buffer_size_per_lane);
1915 
1916             /* The memory layout of the xfb state is:
1917              *    struct {
1918              *       unsigned ordered_id;
1919              *       unsigned dwords_written0;
1920              *       unsigned ordered_id;
1921              *       unsigned dwords_written1;
1922              *       unsigned ordered_id;
1923              *       unsigned dwords_written2;
1924              *       unsigned ordered_id;
1925              *       unsigned dwords_written3;
1926              *    };
1927              *
1928              * Notes:
1929              * - global_atomic_ordered_add_b64 is semantically a 64-bit atomic, requiring 8-byte
1930              *   address alignment, even though it operates on a pair of 32-bit values.
1931              * - The whole structure is updated at once by issuing the atomic from 4 lanes
1932              *   with 8-byte address increments.
1933              * - The whole structure should be entirely within one 64B block of memory
1934              *   for performance. (the address bits above 64B should not differ between lanes)
1935              */
1936             nir_def *buffer_offset_per_lane;
1937 
1938             /* The gfx12 intrinsic inserts hand-written assembly producing better code than current
1939              * LLVM.
1940              */
1941             if (use_gfx12_xfb_intrinsic) {
1942                buffer_offset_per_lane =
1943                   nir_ordered_add_loop_gfx12_amd(b, xfb_state_address, xfb_voffset, ordered_id,
1944                                                  atomic_src);
1945             } else {
1946                /* The NIR version of the above using nir_atomic_op_ordered_add_gfx12_amd. */
1947                enum { NUM_ATOMICS_IN_FLIGHT = 6 };
1948 
1949                nir_variable *result_ring[NUM_ATOMICS_IN_FLIGHT] = {0};
1950                for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++)
1951                   result_ring[i] = nir_local_variable_create(b->impl, glsl_uint64_t_type(), "result");
1952 
1953                /* Issue the first N-1 atomics. The shader must not wait because we want them to be
1954                 * pipelined. It will only wait for the oldest atomic in the NIR loop.
1955                 */
1956                for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT - 1; i++) {
1957                   nir_store_var(b, result_ring[i],
1958                                 nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1959                                                       .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1960                }
1961 
1962                nir_variable *buffer_offset_per_lane_var =
1963                   nir_local_variable_create(b->impl, glsl_uint_type(), "buffer_offset_per_lane");
1964 
1965                nir_loop *loop = nir_push_loop(b);
1966                {
1967                   for (unsigned i = 0; i < NUM_ATOMICS_IN_FLIGHT; i++) {
1968                      int issue_index = (NUM_ATOMICS_IN_FLIGHT - 1 + i) % NUM_ATOMICS_IN_FLIGHT;
1969                      int read_index = i;
1970 
1971                      /* Issue (or repeat) the atomic. */
1972                      nir_store_var(b, result_ring[issue_index],
1973                                    nir_global_atomic_amd(b, 64, xfb_state_address, atomic_src, xfb_voffset,
1974                                                          .atomic_op = nir_atomic_op_ordered_add_gfx12_amd), 0x1);
1975 
1976                      /* Break if the oldest atomic succeeded in incrementing the offsets. */
1977                      nir_def *oldest_result = nir_load_var(b, result_ring[read_index]);
1978                      nir_def *loaded_ordered_id = nir_unpack_64_2x32_split_x(b, oldest_result);
1979                      nir_def *loaded_dwords_written = nir_unpack_64_2x32_split_y(b, oldest_result);
1980 
1981                      /* Debug: Write the vec4 into a shader log ring buffer. */
1982 #if 0
1983                      ac_nir_store_debug_log_amd(b, nir_vec4(b, nir_u2u32(b, xfb_state_address),
1984                                                             ordered_id, loaded_ordered_id,
1985                                                             loaded_dwords_written));
1986 #endif
1987 
1988                      /* This results in better code than using ballot with LLVM. */
1989                      loaded_ordered_id = nir_read_invocation(b, loaded_ordered_id, nir_imm_int(b, 0));
1990 
1991                      nir_if *if_break = nir_push_if(b, nir_ieq(b, loaded_ordered_id, ordered_id));
1992                      {
1993                         nir_store_var(b, buffer_offset_per_lane_var, loaded_dwords_written, 0x1);
1994                         nir_jump(b, nir_jump_break);
1995                      }
1996                      nir_pop_if(b, if_break);
1997                   }
1998                }
1999                nir_pop_loop(b, loop);
2000 
2001                buffer_offset_per_lane = nir_load_var(b, buffer_offset_per_lane_var);
2002             }
2003 
2004             /* Move the buffer offsets from the 4 lanes to lane 0. */
2005             nir_def *offset[4] = {undef, undef, undef, undef};
2006 
2007             for (unsigned buffer = 0; buffer < 4; buffer++) {
2008                if (info->buffers_written & BITFIELD_BIT(buffer)) {
2009                   if (!buffer) {
2010                      offset[buffer] = buffer_offset_per_lane;
2011                   } else {
2012                      offset[buffer] = nir_quad_swizzle_amd(b, buffer_offset_per_lane,
2013                                                            .swizzle_mask = BITFIELD_BIT(buffer));
2014                   }
2015                }
2016             }
2017             buffer_offsets = nir_vec(b, offset, 4);
2018          }
2019          nir_pop_if(b, if_4lanes);
2020          buffer_offsets = nir_if_phi(b, buffer_offsets, nir_undef(b, 4, 32));
2021 
2022          if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2023       } else {
2024          nir_def *ordered_id = nir_load_ordered_id_amd(b);
2025          buffer_offsets =
2026             nir_ordered_xfb_counter_add_gfx11_amd(b, ordered_id,
2027                                                   nir_vec(b, workgroup_buffer_sizes, 4),
2028                                                   /* mask of buffers to update */
2029                                                   .write_mask = info->buffers_written);
2030       }
2031 
2032       nir_def *emit_prim[4];
2033       memcpy(emit_prim, gen_prim, 4 * sizeof(nir_def *));
2034 
2035       nir_def *any_overflow = nir_imm_false(b);
2036       nir_def *overflow_amount[4] = {undef, undef, undef, undef};
2037 
2038       for (unsigned buffer = 0; buffer < 4; buffer++) {
2039          if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2040             continue;
2041 
2042          nir_def *buffer_size = nir_channel(b, so_buffer_ret[buffer], 2);
2043 
2044          /* Only consider overflow for valid feedback buffers because
2045           * otherwise the ordered operation above (GDS atomic return) might
2046           * return non-zero offsets for invalid buffers.
2047           */
2048          nir_def *buffer_valid = nir_ine_imm(b, buffer_size, 0);
2049          nir_def *buffer_offset = nir_channel(b, buffer_offsets, buffer);
2050          buffer_offset = nir_bcsel(b, buffer_valid, buffer_offset, nir_imm_int(b, 0));
2051 
2052          nir_def *remain_size = nir_isub(b, buffer_size, buffer_offset);
2053          nir_def *remain_prim = nir_idiv(b, remain_size, prim_stride_ret[buffer]);
2054          nir_def *overflow = nir_ilt(b, buffer_size, buffer_offset);
2055 
2056          any_overflow = nir_ior(b, any_overflow, overflow);
2057          overflow_amount[buffer] = nir_imax(b, nir_imm_int(b, 0),
2058                                             nir_isub(b, buffer_offset, buffer_size));
2059 
2060          unsigned stream = info->buffer_to_stream[buffer];
2061          /* when previous workgroup overflow, we can't emit any primitive */
2062          emit_prim[stream] = nir_bcsel(
2063             b, overflow, nir_imm_int(b, 0),
2064             /* we can emit part primitives, limited by smallest buffer */
2065             nir_imin(b, emit_prim[stream], remain_prim));
2066 
2067          /* Save to LDS for being accessed by other waves in this workgroup. */
2068          nir_store_shared(b, buffer_offset, scratch_base, .base = buffer * 4);
2069       }
2070 
2071       /* We have to fix up the streamout offsets if we overflowed because they determine
2072        * the vertex count for DrawTransformFeedback.
2073        */
2074       if (gfx_level >= GFX12) {
2075          nir_pop_if(b, if_invocation_0);
2076 
2077          any_overflow = nir_if_phi(b, any_overflow, nir_undef(b, 1, 1));
2078          for (unsigned buffer = 0; buffer < 4; buffer++)
2079             overflow_amount[buffer] = nir_if_phi(b, overflow_amount[buffer], undef);
2080          for (unsigned stream = 0; stream < 4; stream++) {
2081             if (emit_prim[stream])
2082                emit_prim[stream] = nir_if_phi(b, emit_prim[stream], undef);
2083          }
2084 
2085          nir_if *if_any_overflow_4_lanes =
2086             nir_push_if(b, nir_iand(b, any_overflow, nir_ult_imm(b, tid_in_tg, 4)));
2087          {
2088             /* Move overflow amounts from SGPRs to the first 4 lanes. */
2089             nir_def *overflow_amount_per_lane =
2090                write_values_to_lanes(b, overflow_amount, info->buffers_written);
2091 
2092             nir_global_atomic_amd(b, 32, xfb_state_address, nir_ineg(b, overflow_amount_per_lane),
2093                                   xfb_voffset, .base = 4, .atomic_op = nir_atomic_op_iadd);
2094          }
2095          nir_pop_if(b, if_any_overflow_4_lanes);
2096 
2097          if_invocation_0 = nir_push_if(b, nir_ieq_imm(b, tid_in_tg, 0));
2098       } else {
2099          nir_if *if_any_overflow = nir_push_if(b, any_overflow);
2100          nir_xfb_counter_sub_gfx11_amd(b, nir_vec(b, overflow_amount, 4),
2101                                        /* mask of buffers to update */
2102                                        .write_mask = info->buffers_written);
2103          nir_pop_if(b, if_any_overflow);
2104       }
2105 
2106       /* Save to LDS for being accessed by other waves in this workgroup. */
2107       for (unsigned stream = 0; stream < 4; stream++) {
2108          if (!(info->streams_written & BITFIELD_BIT(stream)))
2109             continue;
2110 
2111          nir_store_shared(b, emit_prim[stream], scratch_base, .base = 16 + stream * 4);
2112       }
2113 
2114       /* Update shader query. */
2115       if (has_xfb_prim_query) {
2116          nir_if *if_shader_query = nir_push_if(b, nir_load_prim_xfb_query_enabled_amd(b));
2117          {
2118             for (unsigned stream = 0; stream < 4; stream++) {
2119                if (info->streams_written & BITFIELD_BIT(stream))
2120                   nir_atomic_add_xfb_prim_count_amd(b, emit_prim[stream], .stream_id = stream);
2121             }
2122          }
2123          nir_pop_if(b, if_shader_query);
2124       }
2125    }
2126    nir_pop_if(b, if_invocation_0);
2127 
2128    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2129                       .memory_scope = SCOPE_WORKGROUP,
2130                       .memory_semantics = NIR_MEMORY_ACQ_REL,
2131                       .memory_modes = nir_var_mem_shared);
2132 
2133    /* Fetch the per-buffer offsets in all waves. */
2134    for (unsigned buffer = 0; buffer < 4; buffer++) {
2135       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2136          continue;
2137 
2138       buffer_offsets_ret[buffer] =
2139          nir_load_shared(b, 1, 32, scratch_base, .base = buffer * 4);
2140    }
2141 
2142    /* Fetch the per-stream emit prim in all waves. */
2143    for (unsigned stream = 0; stream < 4; stream++) {
2144       if (!(info->streams_written & BITFIELD_BIT(stream)))
2145             continue;
2146 
2147       emit_prim_ret[stream] =
2148          nir_load_shared(b, 1, 32, scratch_base, .base = 16 + stream * 4);
2149    }
2150 }
2151 
2152 static void
ngg_build_streamout_vertex(nir_builder * b,nir_xfb_info * info,unsigned stream,nir_def * so_buffer[4],nir_def * buffer_offsets[4],nir_def * vtx_buffer_idx,nir_def * vtx_lds_addr,ac_nir_prerast_out * pr_out,bool skip_primitive_id)2153 ngg_build_streamout_vertex(nir_builder *b, nir_xfb_info *info,
2154                            unsigned stream, nir_def *so_buffer[4],
2155                            nir_def *buffer_offsets[4],
2156                            nir_def *vtx_buffer_idx, nir_def *vtx_lds_addr,
2157                            ac_nir_prerast_out *pr_out,
2158                            bool skip_primitive_id)
2159 {
2160    nir_def *vtx_buffer_offsets[4];
2161    for (unsigned buffer = 0; buffer < 4; buffer++) {
2162       if (!(info->buffers_written & BITFIELD_BIT(buffer)))
2163          continue;
2164 
2165       nir_def *offset = nir_imul_imm(b, vtx_buffer_idx, info->buffers[buffer].stride);
2166       vtx_buffer_offsets[buffer] = nir_iadd(b, buffer_offsets[buffer], offset);
2167    }
2168 
2169    for (unsigned i = 0; i < info->output_count; i++) {
2170       nir_xfb_output_info *out = info->outputs + i;
2171       if (!out->component_mask || info->buffer_to_stream[out->buffer] != stream)
2172          continue;
2173 
2174       unsigned base;
2175       if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2176          base =
2177             util_bitcount64(b->shader->info.outputs_written) +
2178             util_bitcount(b->shader->info.outputs_written_16bit &
2179                           BITFIELD_MASK(out->location - VARYING_SLOT_VAR0_16BIT));
2180       } else {
2181          uint64_t outputs_written = b->shader->info.outputs_written;
2182          if (skip_primitive_id)
2183             outputs_written &= ~VARYING_BIT_PRIMITIVE_ID;
2184 
2185          base =
2186             util_bitcount64(outputs_written &
2187                             BITFIELD64_MASK(out->location));
2188       }
2189 
2190       unsigned offset = (base * 4 + out->component_offset) * 4;
2191       unsigned count = util_bitcount(out->component_mask);
2192 
2193       assert(u_bit_consecutive(out->component_offset, count) == out->component_mask);
2194 
2195       nir_def *out_data =
2196          nir_load_shared(b, count, 32, vtx_lds_addr, .base = offset);
2197 
2198       /* Up-scaling 16bit outputs to 32bit.
2199        *
2200        * OpenGL ES will put 16bit medium precision varyings to VARYING_SLOT_VAR0_16BIT.
2201        * We need to up-scaling them to 32bit when streamout to buffer.
2202        *
2203        * Vulkan does not allow 8/16bit varyings to be streamout.
2204        */
2205       if (out->location >= VARYING_SLOT_VAR0_16BIT) {
2206          unsigned index = out->location - VARYING_SLOT_VAR0_16BIT;
2207          nir_def *values[4];
2208 
2209          for (int j = 0; j < count; j++) {
2210             unsigned c = out->component_offset + j;
2211             nir_def *v = nir_channel(b, out_data, j);
2212             nir_alu_type t;
2213 
2214             if (out->high_16bits) {
2215                v = nir_unpack_32_2x16_split_y(b, v);
2216                t = pr_out->types_16bit_hi[index][c];
2217             } else {
2218                v = nir_unpack_32_2x16_split_x(b, v);
2219                t = pr_out->types_16bit_lo[index][c];
2220             }
2221 
2222             t = nir_alu_type_get_base_type(t);
2223             values[j] = nir_convert_to_bit_size(b, v, t, 32);
2224          }
2225 
2226          out_data = nir_vec(b, values, count);
2227       }
2228 
2229       nir_def *zero = nir_imm_int(b, 0);
2230       nir_store_buffer_amd(b, out_data, so_buffer[out->buffer],
2231                            vtx_buffer_offsets[out->buffer],
2232                            zero, zero,
2233                            .base = out->offset,
2234                            .memory_modes = nir_var_mem_ssbo,
2235                            .access = ACCESS_NON_TEMPORAL);
2236    }
2237 }
2238 
2239 static void
ngg_nogs_build_streamout(nir_builder * b,lower_ngg_nogs_state * s)2240 ngg_nogs_build_streamout(nir_builder *b, lower_ngg_nogs_state *s)
2241 {
2242    nir_xfb_info *info = b->shader->xfb_info;
2243 
2244    nir_def *lds_scratch_base = nir_load_lds_ngg_scratch_base_amd(b);
2245 
2246    /* Get global buffer offset where this workgroup will stream out data to. */
2247    nir_def *generated_prim = nir_load_workgroup_num_input_primitives_amd(b);
2248    nir_def *gen_prim_per_stream[4] = {generated_prim, 0, 0, 0};
2249    nir_def *emit_prim_per_stream[4] = {0};
2250    nir_def *buffer_offsets[4] = {0};
2251    nir_def *so_buffer[4] = {0};
2252    nir_def *prim_stride[4] = {0};
2253    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2254    ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
2255                                    s->options->use_gfx12_xfb_intrinsic, lds_scratch_base, tid_in_tg,
2256                                    gen_prim_per_stream, prim_stride,
2257                                    so_buffer, buffer_offsets,
2258                                    emit_prim_per_stream);
2259 
2260    /* Write out primitive data */
2261    nir_if *if_emit = nir_push_if(b, nir_ilt(b, tid_in_tg, emit_prim_per_stream[0]));
2262    {
2263       unsigned vtx_lds_stride = (b->shader->num_outputs * 4 + 1) * 4;
2264       nir_def *num_vert_per_prim = nir_load_num_vertices_per_primitive_amd(b);
2265       nir_def *vtx_buffer_idx = nir_imul(b, tid_in_tg, num_vert_per_prim);
2266 
2267       for (unsigned i = 0; i < s->options->num_vertices_per_primitive; i++) {
2268          nir_if *if_valid_vertex =
2269             nir_push_if(b, nir_igt_imm(b, num_vert_per_prim, i));
2270          {
2271             nir_def *vtx_lds_idx = nir_load_var(b, s->gs_vtx_indices_vars[i]);
2272             nir_def *vtx_lds_addr = pervertex_lds_addr(b, vtx_lds_idx, vtx_lds_stride);
2273             ngg_build_streamout_vertex(b, info, 0, so_buffer, buffer_offsets,
2274                                        nir_iadd_imm(b, vtx_buffer_idx, i),
2275                                        vtx_lds_addr, &s->out, s->skip_primitive_id);
2276          }
2277          nir_pop_if(b, if_valid_vertex);
2278       }
2279    }
2280    nir_pop_if(b, if_emit);
2281 
2282    /* Wait streamout memory ops done before export primitive, otherwise it
2283     * may not finish when shader ends.
2284     *
2285     * If a shader has no param exports, rasterization can start before
2286     * the shader finishes and thus memory stores might not finish before
2287     * the pixel shader starts.
2288     *
2289     * TODO: we only need this when no param exports.
2290     *
2291     * TODO: not sure if we need this barrier when late prim export, as I
2292     *       can't observe test fail without this barrier.
2293     */
2294    nir_scoped_memory_barrier(b, SCOPE_DEVICE, NIR_MEMORY_RELEASE, nir_var_mem_ssbo);
2295 }
2296 
2297 static unsigned
ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,unsigned shader_num_outputs,bool streamout_enabled,bool export_prim_id,bool has_user_edgeflags)2298 ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
2299                                 unsigned shader_num_outputs,
2300                                 bool streamout_enabled,
2301                                 bool export_prim_id,
2302                                 bool has_user_edgeflags)
2303 {
2304    unsigned pervertex_lds_bytes = 0;
2305 
2306    if (streamout_enabled) {
2307       /* The extra dword is used to avoid LDS bank conflicts and store the primitive id.
2308        * TODO: only alloc space for outputs that really need streamout.
2309        */
2310       pervertex_lds_bytes = (shader_num_outputs * 4 + 1) * 4;
2311    }
2312 
2313    bool need_prim_id_store_shared = export_prim_id && stage == MESA_SHADER_VERTEX;
2314    if (need_prim_id_store_shared || has_user_edgeflags) {
2315       unsigned size = 0;
2316       if (need_prim_id_store_shared)
2317          size += 4;
2318       if (has_user_edgeflags)
2319          size += 4;
2320 
2321       /* pad to odd dwords to avoid LDS bank conflict */
2322       size |= 4;
2323 
2324       pervertex_lds_bytes = MAX2(pervertex_lds_bytes, size);
2325    }
2326 
2327    return pervertex_lds_bytes;
2328 }
2329 
2330 static void
ngg_nogs_gather_outputs(nir_builder * b,struct exec_list * cf_list,lower_ngg_nogs_state * s)2331 ngg_nogs_gather_outputs(nir_builder *b, struct exec_list *cf_list, lower_ngg_nogs_state *s)
2332 {
2333    /* Assume:
2334     * - the shader used nir_lower_io_to_temporaries
2335     * - 64-bit outputs are lowered
2336     * - no indirect indexing is present
2337     */
2338    struct nir_cf_node *first_node =
2339       exec_node_data(nir_cf_node, exec_list_get_head(cf_list), node);
2340 
2341    for (nir_block *block = nir_cf_node_cf_tree_first(first_node); block != NULL;
2342         block = nir_block_cf_tree_next(block)) {
2343       nir_foreach_instr_safe (instr, block) {
2344          if (instr->type != nir_instr_type_intrinsic)
2345             continue;
2346 
2347          nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
2348          if (intrin->intrinsic != nir_intrinsic_store_output)
2349             continue;
2350 
2351          ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2352          nir_instr_remove(instr);
2353       }
2354    }
2355 }
2356 
2357 static unsigned
gather_vs_outputs(nir_builder * b,vs_output * outputs,const uint8_t * param_offsets,nir_def * (* data)[4],nir_def * (* data_16bit_lo)[4],nir_def * (* data_16bit_hi)[4])2358 gather_vs_outputs(nir_builder *b, vs_output *outputs,
2359                   const uint8_t *param_offsets,
2360                   nir_def *(*data)[4],
2361                   nir_def *(*data_16bit_lo)[4],
2362                   nir_def *(*data_16bit_hi)[4])
2363 {
2364    unsigned num_outputs = 0;
2365    u_foreach_bit64 (slot, b->shader->info.outputs_written) {
2366       if (param_offsets[slot] > AC_EXP_PARAM_OFFSET_31)
2367          continue;
2368 
2369       nir_def **output = data[slot];
2370 
2371       /* skip output if no one written before */
2372       if (!output[0] && !output[1] && !output[2] && !output[3])
2373          continue;
2374 
2375       outputs[num_outputs].slot = slot;
2376       for (int i = 0; i < 4; i++) {
2377          outputs[num_outputs].chan[i] = output[i];
2378       }
2379       num_outputs++;
2380    }
2381 
2382    u_foreach_bit (i, b->shader->info.outputs_written_16bit) {
2383       unsigned slot = VARYING_SLOT_VAR0_16BIT + i;
2384       if (param_offsets[slot] > AC_EXP_PARAM_OFFSET_31)
2385          continue;
2386 
2387       nir_def **output_lo = data_16bit_lo[i];
2388       nir_def **output_hi = data_16bit_hi[i];
2389 
2390       /* skip output if no one written before */
2391       if (!output_lo[0] && !output_lo[1] && !output_lo[2] && !output_lo[3] &&
2392           !output_hi[0] && !output_hi[1] && !output_hi[2] && !output_hi[3])
2393          continue;
2394 
2395       vs_output *output = &outputs[num_outputs++];
2396       output->slot = slot;
2397 
2398       nir_def *undef = nir_undef(b, 1, 16);
2399       for (int j = 0; j < 4; j++) {
2400          nir_def *lo = output_lo[j] ? output_lo[j] : undef;
2401          nir_def *hi = output_hi[j] ? output_hi[j] : undef;
2402          if (output_lo[j] || output_hi[j])
2403             output->chan[j] = nir_pack_32_2x16_split(b, lo, hi);
2404          else
2405             output->chan[j] = NULL;
2406       }
2407    }
2408 
2409    return num_outputs;
2410 }
2411 
2412 static void
create_vertex_param_phis(nir_builder * b,unsigned num_outputs,vs_output * outputs)2413 create_vertex_param_phis(nir_builder *b, unsigned num_outputs, vs_output *outputs)
2414 {
2415    nir_def *undef = nir_undef(b, 1, 32); /* inserted at the start of the shader */
2416 
2417    for (unsigned i = 0; i < num_outputs; i++) {
2418       for (unsigned j = 0; j < 4; j++) {
2419          if (outputs[i].chan[j])
2420             outputs[i].chan[j] = nir_if_phi(b, outputs[i].chan[j], undef);
2421       }
2422    }
2423 }
2424 
2425 static void
export_vertex_params_gfx11(nir_builder * b,nir_def * export_tid,nir_def * num_export_threads,unsigned num_outputs,vs_output * outputs,const uint8_t * vs_output_param_offset)2426 export_vertex_params_gfx11(nir_builder *b, nir_def *export_tid, nir_def *num_export_threads,
2427                            unsigned num_outputs, vs_output *outputs,
2428                            const uint8_t *vs_output_param_offset)
2429 {
2430    nir_def *attr_rsrc = nir_load_ring_attr_amd(b);
2431 
2432    /* We should always store full vec4s in groups of 8 lanes for the best performance even if
2433     * some of them are garbage or have unused components, so align the number of export threads
2434     * to 8.
2435     */
2436    num_export_threads = nir_iand_imm(b, nir_iadd_imm(b, num_export_threads, 7), ~7);
2437    if (!export_tid)
2438       nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, num_export_threads));
2439    else
2440       nir_push_if(b, nir_ult(b, export_tid, num_export_threads));
2441 
2442    nir_def *attr_offset = nir_load_ring_attr_offset_amd(b);
2443    nir_def *vindex = nir_load_local_invocation_index(b);
2444    nir_def *voffset = nir_imm_int(b, 0);
2445    nir_def *undef = nir_undef(b, 1, 32);
2446 
2447    uint32_t exported_params = 0;
2448 
2449    for (unsigned i = 0; i < num_outputs; i++) {
2450       gl_varying_slot slot = outputs[i].slot;
2451       unsigned offset = vs_output_param_offset[slot];
2452 
2453       /* Since vs_output_param_offset[] can map multiple varying slots to
2454        * the same param export index (that's radeonsi-specific behavior),
2455        * we need to do this so as not to emit duplicated exports.
2456        */
2457       if (exported_params & BITFIELD_BIT(offset))
2458          continue;
2459 
2460       nir_def *comp[4];
2461       for (unsigned j = 0; j < 4; j++)
2462          comp[j] = outputs[i].chan[j] ? outputs[i].chan[j] : undef;
2463       nir_store_buffer_amd(b, nir_vec(b, comp, 4), attr_rsrc, voffset, attr_offset, vindex,
2464                            .base = offset * 16,
2465                            .memory_modes = nir_var_shader_out,
2466                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
2467       exported_params |= BITFIELD_BIT(offset);
2468    }
2469 
2470    nir_pop_if(b, NULL);
2471 }
2472 
must_wait_attr_ring(enum amd_gfx_level gfx_level,bool has_param_exports)2473 static bool must_wait_attr_ring(enum amd_gfx_level gfx_level, bool has_param_exports)
2474 {
2475    return (gfx_level == GFX11 || gfx_level == GFX11_5) && has_param_exports;
2476 }
2477 
2478 static void
export_pos0_wait_attr_ring(nir_builder * b,nir_if * if_es_thread,nir_def * outputs[VARYING_SLOT_MAX][4],const ac_nir_lower_ngg_options * options)2479 export_pos0_wait_attr_ring(nir_builder *b, nir_if *if_es_thread, nir_def *outputs[VARYING_SLOT_MAX][4], const ac_nir_lower_ngg_options *options)
2480 {
2481    b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2482 
2483    /* Create phi for the position output values. */
2484    vs_output pos_output = {
2485       .slot = VARYING_SLOT_POS,
2486       .chan = {
2487          outputs[VARYING_SLOT_POS][0],
2488          outputs[VARYING_SLOT_POS][1],
2489          outputs[VARYING_SLOT_POS][2],
2490          outputs[VARYING_SLOT_POS][3],
2491       },
2492    };
2493    create_vertex_param_phis(b, 1, &pos_output);
2494 
2495    b->cursor = nir_after_cf_list(&b->impl->body);
2496 
2497    /* Wait for attribute stores to finish. */
2498    nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
2499                   .memory_scope = SCOPE_DEVICE,
2500                   .memory_semantics = NIR_MEMORY_RELEASE,
2501                   .memory_modes = nir_var_mem_ssbo | nir_var_shader_out | nir_var_mem_global | nir_var_image);
2502 
2503    /* Export just the pos0 output. */
2504    nir_if *if_export_empty_pos = nir_push_if(b, if_es_thread->condition.ssa);
2505    {
2506       nir_def *pos_output_array[VARYING_SLOT_MAX][4] = {0};
2507       memcpy(pos_output_array[VARYING_SLOT_POS], pos_output.chan, sizeof(pos_output.chan));
2508 
2509       ac_nir_export_position(b, options->gfx_level,
2510                              options->clip_cull_dist_mask,
2511                              !options->has_param_exports,
2512                              options->force_vrs, true,
2513                              VARYING_BIT_POS, pos_output_array, NULL);
2514    }
2515    nir_pop_if(b, if_export_empty_pos);
2516 }
2517 
2518 static void
nogs_export_vertex_params(nir_builder * b,nir_function_impl * impl,nir_if * if_es_thread,nir_def * num_es_threads,lower_ngg_nogs_state * s)2519 nogs_export_vertex_params(nir_builder *b, nir_function_impl *impl,
2520                           nir_if *if_es_thread, nir_def *num_es_threads,
2521                           lower_ngg_nogs_state *s)
2522 {
2523    if (!s->options->has_param_exports)
2524       return;
2525 
2526    if (s->options->gfx_level >= GFX11) {
2527       /* Export varyings for GFX11+ */
2528       vs_output outputs[64];
2529       const unsigned num_outputs =
2530          gather_vs_outputs(b, outputs,
2531                            s->options->vs_output_param_offset,
2532                            s->out.outputs,
2533                            s->out.outputs_16bit_lo,
2534                            s->out.outputs_16bit_hi);
2535 
2536       if (!num_outputs)
2537          return;
2538 
2539       b->cursor = nir_after_cf_node(&if_es_thread->cf_node);
2540       create_vertex_param_phis(b, num_outputs, outputs);
2541 
2542       b->cursor = nir_after_impl(impl);
2543       if (!num_es_threads)
2544          num_es_threads = nir_load_merged_wave_info_amd(b);
2545 
2546       export_vertex_params_gfx11(b, NULL, num_es_threads, num_outputs, outputs,
2547                                  s->options->vs_output_param_offset);
2548    } else {
2549       ac_nir_export_parameters(b, s->options->vs_output_param_offset,
2550                                  b->shader->info.outputs_written,
2551                                  b->shader->info.outputs_written_16bit,
2552                                  s->out.outputs, s->out.outputs_16bit_lo,
2553                                  s->out.outputs_16bit_hi);
2554    }
2555 }
2556 
2557 void
ac_nir_lower_ngg_nogs(nir_shader * shader,const ac_nir_lower_ngg_options * options)2558 ac_nir_lower_ngg_nogs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
2559 {
2560    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
2561    assert(impl);
2562    assert(options->max_workgroup_size && options->wave_size);
2563    assert(!(options->can_cull && options->passthrough));
2564 
2565    nir_variable *position_value_var = nir_local_variable_create(impl, glsl_vec4_type(), "position_value");
2566    nir_variable *prim_exp_arg_var = nir_local_variable_create(impl, glsl_uint_type(), "prim_exp_arg");
2567    nir_variable *es_accepted_var =
2568       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "es_accepted") : NULL;
2569    nir_variable *gs_accepted_var =
2570       options->can_cull ? nir_local_variable_create(impl, glsl_bool_type(), "gs_accepted") : NULL;
2571    nir_variable *gs_exported_var = nir_local_variable_create(impl, glsl_bool_type(), "gs_exported");
2572 
2573    bool streamout_enabled = shader->xfb_info && !options->disable_streamout;
2574    bool has_user_edgeflags =
2575       options->use_edgeflags && (shader->info.outputs_written & VARYING_BIT_EDGE);
2576    /* streamout need to be done before either prim or vertex export. Because when no
2577     * param export, rasterization can start right after prim and vertex export,
2578     * which left streamout buffer writes un-finished.
2579     *
2580     * Always use late prim export when user edge flags are enabled.
2581     * This is because edge flags are written by ES threads but they
2582     * are exported by GS threads as part of th primitive export.
2583     */
2584    bool early_prim_export =
2585       options->early_prim_export && !(streamout_enabled || has_user_edgeflags);
2586 
2587    lower_ngg_nogs_state state = {
2588       .options = options,
2589       .early_prim_export = early_prim_export,
2590       .streamout_enabled = streamout_enabled,
2591       .position_value_var = position_value_var,
2592       .prim_exp_arg_var = prim_exp_arg_var,
2593       .es_accepted_var = es_accepted_var,
2594       .gs_accepted_var = gs_accepted_var,
2595       .gs_exported_var = gs_exported_var,
2596       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
2597       .has_user_edgeflags = has_user_edgeflags,
2598       .skip_primitive_id = streamout_enabled && options->export_primitive_id,
2599    };
2600 
2601    const bool need_prim_id_store_shared =
2602       options->export_primitive_id && shader->info.stage == MESA_SHADER_VERTEX;
2603 
2604    if (options->export_primitive_id) {
2605       shader->info.outputs_written |= VARYING_BIT_PRIMITIVE_ID;
2606    }
2607 
2608    nir_builder builder = nir_builder_create(impl);
2609    nir_builder *b = &builder; /* This is to avoid the & */
2610 
2611    if (options->can_cull) {
2612       analyze_shader_before_culling(shader, &state);
2613       save_reusable_variables(b, &state);
2614    }
2615 
2616    nir_cf_list extracted;
2617    nir_cf_extract(&extracted, nir_before_impl(impl),
2618                   nir_after_impl(impl));
2619    b->cursor = nir_before_impl(impl);
2620 
2621    ngg_nogs_init_vertex_indices_vars(b, impl, &state);
2622 
2623    /* Emit primitives generated query code here, so that
2624     * it executes before culling and isn't in the extracted CF.
2625     */
2626    nogs_prim_gen_query(b, &state);
2627 
2628    /* Whether a shader invocation should export a primitive,
2629     * initialize to all invocations that have an input primitive.
2630     */
2631    nir_store_var(b, gs_exported_var, has_input_primitive(b), 0x1u);
2632 
2633    if (!options->can_cull) {
2634       /* Newer chips can use PRIMGEN_PASSTHRU_NO_MSG to skip gs_alloc_req for NGG passthrough. */
2635       if (!(options->passthrough && options->family >= CHIP_NAVI23)) {
2636          /* Allocate export space on wave 0 - confirm to the HW that we want to use all possible space */
2637          nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
2638          {
2639             nir_def *vtx_cnt = nir_load_workgroup_num_input_vertices_amd(b);
2640             nir_def *prim_cnt = nir_load_workgroup_num_input_primitives_amd(b);
2641             alloc_vertices_and_primitives(b, vtx_cnt, prim_cnt);
2642          }
2643          nir_pop_if(b, if_wave_0);
2644       }
2645 
2646       /* Take care of early primitive export, otherwise just pack the primitive export argument */
2647       if (state.early_prim_export)
2648          emit_ngg_nogs_prim_export(b, &state, NULL);
2649       else
2650          nir_store_var(b, prim_exp_arg_var, emit_ngg_nogs_prim_exp_arg(b, &state), 0x1u);
2651    } else {
2652       add_deferred_attribute_culling(b, &extracted, &state);
2653       b->cursor = nir_after_impl(impl);
2654 
2655       if (state.early_prim_export)
2656          emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, state.prim_exp_arg_var));
2657 
2658       /* Wait for culling to finish using LDS. */
2659       if (need_prim_id_store_shared || has_user_edgeflags) {
2660          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
2661                                .memory_scope = SCOPE_WORKGROUP,
2662                                .memory_semantics = NIR_MEMORY_ACQ_REL,
2663                                .memory_modes = nir_var_mem_shared);
2664       }
2665    }
2666 
2667    /* determine the LDS vertex stride */
2668    state.pervertex_lds_bytes =
2669       ngg_nogs_get_pervertex_lds_size(shader->info.stage,
2670                                       shader->num_outputs,
2671                                       state.streamout_enabled,
2672                                       options->export_primitive_id,
2673                                       state.has_user_edgeflags);
2674 
2675    if (need_prim_id_store_shared) {
2676       emit_ngg_nogs_prim_id_store_shared(b, &state);
2677 
2678       /* Wait for GS threads to store primitive ID in LDS. */
2679       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
2680                             .memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
2681    }
2682 
2683    nir_def *es_thread =
2684       options->can_cull ? nir_load_var(b, es_accepted_var) : has_input_vertex(b);
2685 
2686    /* Calculate the bit count here instead of below for lower SGPR usage and better ALU
2687     * scheduling.
2688     */
2689    nir_def *num_es_threads = NULL;
2690    if (state.options->gfx_level >= GFX11 && options->can_cull) {
2691       nir_def *es_accepted_mask =
2692          nir_ballot(b, 1, options->wave_size, nir_load_var(b, es_accepted_var));
2693       num_es_threads = nir_bit_count(b, es_accepted_mask);
2694    }
2695 
2696    nir_if *if_es_thread = nir_push_if(b, es_thread);
2697    {
2698       /* Run the actual shader */
2699       nir_cf_reinsert(&extracted, b->cursor);
2700       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2701 
2702       if (options->export_primitive_id)
2703          emit_store_ngg_nogs_es_primitive_id(b, &state);
2704    }
2705    nir_pop_if(b, if_es_thread);
2706 
2707    if (options->can_cull) {
2708       /* Replace uniforms. */
2709       apply_reusable_variables(b, &state);
2710 
2711       /* Remove the redundant position output. */
2712       remove_extra_pos_outputs(shader, &state);
2713 
2714       /* After looking at the performance in apps eg. Doom Eternal, and The Witcher 3,
2715        * it seems that it's best to put the position export always at the end, and
2716        * then let ACO schedule it up (slightly) only when early prim export is used.
2717        */
2718       b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2719 
2720       nir_def *pos_val = nir_load_var(b, state.position_value_var);
2721       for (int i = 0; i < 4; i++)
2722          state.out.outputs[VARYING_SLOT_POS][i] = nir_channel(b, pos_val, i);
2723    }
2724 
2725    /* Gather outputs data and types */
2726    ngg_nogs_gather_outputs(b, &if_es_thread->then_list, &state);
2727    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2728 
2729    if (state.has_user_edgeflags)
2730       ngg_nogs_store_edgeflag_to_lds(b, &state);
2731 
2732    if (state.streamout_enabled) {
2733       /* TODO: support culling after streamout. */
2734       assert(!options->can_cull);
2735 
2736       ngg_nogs_store_xfb_outputs_to_lds(b, &state);
2737 
2738       b->cursor = nir_after_impl(impl);
2739       ngg_nogs_build_streamout(b, &state);
2740    }
2741 
2742    /* Take care of late primitive export */
2743    if (!state.early_prim_export) {
2744       b->cursor = nir_after_impl(impl);
2745       emit_ngg_nogs_prim_export(b, &state, nir_load_var(b, prim_exp_arg_var));
2746    }
2747 
2748    uint64_t export_outputs = shader->info.outputs_written | VARYING_BIT_POS;
2749    if (options->kill_pointsize)
2750       export_outputs &= ~VARYING_BIT_PSIZ;
2751    if (options->kill_layer)
2752       export_outputs &= ~VARYING_BIT_LAYER;
2753 
2754    const bool wait_attr_ring = must_wait_attr_ring(options->gfx_level, options->has_param_exports);
2755    if (wait_attr_ring)
2756       export_outputs &= ~VARYING_BIT_POS;
2757 
2758    b->cursor = nir_after_cf_list(&if_es_thread->then_list);
2759 
2760    ac_nir_export_position(b, options->gfx_level,
2761                           options->clip_cull_dist_mask,
2762                           !options->has_param_exports,
2763                           options->force_vrs, !wait_attr_ring,
2764                           export_outputs, state.out.outputs, NULL);
2765 
2766    nogs_export_vertex_params(b, impl, if_es_thread, num_es_threads, &state);
2767 
2768    if (wait_attr_ring)
2769       export_pos0_wait_attr_ring(b, if_es_thread, state.out.outputs, options);
2770 
2771    nir_metadata_preserve(impl, nir_metadata_none);
2772    nir_validate_shader(shader, "after emitting NGG VS/TES");
2773 
2774    /* Cleanup */
2775    nir_opt_dead_write_vars(shader);
2776    nir_lower_vars_to_ssa(shader);
2777    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
2778    nir_lower_alu_to_scalar(shader, NULL, NULL);
2779    nir_lower_phis_to_scalar(shader, true);
2780 
2781    if (options->can_cull) {
2782       /* It's beneficial to redo these opts after splitting the shader. */
2783       nir_opt_sink(shader, nir_move_load_input | nir_move_const_undef | nir_move_copies);
2784       nir_opt_move(shader, nir_move_load_input | nir_move_copies | nir_move_const_undef);
2785    }
2786 
2787    bool progress;
2788    do {
2789       progress = false;
2790       NIR_PASS(progress, shader, nir_opt_undef);
2791       NIR_PASS(progress, shader, nir_opt_dce);
2792       NIR_PASS(progress, shader, nir_opt_dead_cf);
2793 
2794       if (options->can_cull)
2795          progress |= cleanup_culling_shader_after_dce(shader, b->impl, &state);
2796    } while (progress);
2797 }
2798 
2799 /**
2800  * Return the address of the LDS storage reserved for the N'th vertex,
2801  * where N is in emit order, meaning:
2802  * - during the finale, N is the invocation_index (within the workgroup)
2803  * - during vertex emit, i.e. while the API GS shader invocation is running,
2804  *   N = invocation_index * gs_max_out_vertices + emit_idx
2805  *   where emit_idx is the vertex index in the current API GS invocation.
2806  *
2807  * Goals of the LDS memory layout:
2808  * 1. Eliminate bank conflicts on write for geometry shaders that have all emits
2809  *    in uniform control flow
2810  * 2. Eliminate bank conflicts on read for export if, additionally, there is no
2811  *    culling
2812  * 3. Agnostic to the number of waves (since we don't know it before compiling)
2813  * 4. Allow coalescing of LDS instructions (ds_write_b128 etc.)
2814  * 5. Avoid wasting memory.
2815  *
2816  * We use an AoS layout due to point 4 (this also helps point 3). In an AoS
2817  * layout, elimination of bank conflicts requires that each vertex occupy an
2818  * odd number of dwords. We use the additional dword to store the output stream
2819  * index as well as a flag to indicate whether this vertex ends a primitive
2820  * for rasterization.
2821  *
2822  * Swizzling is required to satisfy points 1 and 2 simultaneously.
2823  *
2824  * Vertices are stored in export order (gsthread * gs_max_out_vertices + emitidx).
2825  * Indices are swizzled in groups of 32, which ensures point 1 without
2826  * disturbing point 2.
2827  *
2828  * \return an LDS pointer to type {[N x i32], [4 x i8]}
2829  */
2830 static nir_def *
ngg_gs_out_vertex_addr(nir_builder * b,nir_def * out_vtx_idx,lower_ngg_gs_state * s)2831 ngg_gs_out_vertex_addr(nir_builder *b, nir_def *out_vtx_idx, lower_ngg_gs_state *s)
2832 {
2833    unsigned write_stride_2exp = ffs(MAX2(b->shader->info.gs.vertices_out, 1)) - 1;
2834 
2835    /* gs_max_out_vertices = 2^(write_stride_2exp) * some odd number */
2836    if (write_stride_2exp) {
2837       nir_def *row = nir_ushr_imm(b, out_vtx_idx, 5);
2838       nir_def *swizzle = nir_iand_imm(b, row, (1u << write_stride_2exp) - 1u);
2839       out_vtx_idx = nir_ixor(b, out_vtx_idx, swizzle);
2840    }
2841 
2842    nir_def *out_vtx_offs = nir_imul_imm(b, out_vtx_idx, s->lds_bytes_per_gs_out_vertex);
2843    return nir_iadd_nuw(b, out_vtx_offs, s->lds_addr_gs_out_vtx);
2844 }
2845 
2846 static nir_def *
ngg_gs_emit_vertex_addr(nir_builder * b,nir_def * gs_vtx_idx,lower_ngg_gs_state * s)2847 ngg_gs_emit_vertex_addr(nir_builder *b, nir_def *gs_vtx_idx, lower_ngg_gs_state *s)
2848 {
2849    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
2850    nir_def *gs_out_vtx_base = nir_imul_imm(b, tid_in_tg, b->shader->info.gs.vertices_out);
2851    nir_def *out_vtx_idx = nir_iadd_nuw(b, gs_out_vtx_base, gs_vtx_idx);
2852 
2853    return ngg_gs_out_vertex_addr(b, out_vtx_idx, s);
2854 }
2855 
2856 static void
ngg_gs_clear_primflags(nir_builder * b,nir_def * num_vertices,unsigned stream,lower_ngg_gs_state * s)2857 ngg_gs_clear_primflags(nir_builder *b, nir_def *num_vertices, unsigned stream, lower_ngg_gs_state *s)
2858 {
2859    char name[32];
2860    snprintf(name, sizeof(name), "clear_primflag_idx_%u", stream);
2861    nir_variable *clear_primflag_idx_var = nir_local_variable_create(b->impl, glsl_uint_type(), name);
2862 
2863    nir_def *zero_u8 = nir_imm_zero(b, 1, 8);
2864    nir_store_var(b, clear_primflag_idx_var, num_vertices, 0x1u);
2865 
2866    nir_loop *loop = nir_push_loop(b);
2867    {
2868       nir_def *clear_primflag_idx = nir_load_var(b, clear_primflag_idx_var);
2869       nir_if *if_break = nir_push_if(b, nir_uge_imm(b, clear_primflag_idx, b->shader->info.gs.vertices_out));
2870       {
2871          nir_jump(b, nir_jump_break);
2872       }
2873       nir_push_else(b, if_break);
2874       {
2875          nir_def *emit_vtx_addr = ngg_gs_emit_vertex_addr(b, clear_primflag_idx, s);
2876          nir_store_shared(b, zero_u8, emit_vtx_addr, .base = s->lds_offs_primflags + stream);
2877          nir_store_var(b, clear_primflag_idx_var, nir_iadd_imm_nuw(b, clear_primflag_idx, 1), 0x1u);
2878       }
2879       nir_pop_if(b, if_break);
2880    }
2881    nir_pop_loop(b, loop);
2882 }
2883 
2884 static bool
lower_ngg_gs_store_output(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)2885 lower_ngg_gs_store_output(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2886 {
2887    ac_nir_gather_prerast_store_output_info(b, intrin, &s->out);
2888    nir_instr_remove(&intrin->instr);
2889    return true;
2890 }
2891 
2892 static unsigned
gs_output_component_mask_with_stream(ac_nir_prerast_per_output_info * info,unsigned stream)2893 gs_output_component_mask_with_stream(ac_nir_prerast_per_output_info *info, unsigned stream)
2894 {
2895    unsigned mask = info->components_mask;
2896    if (!mask)
2897       return 0;
2898 
2899    /* clear component when not requested stream */
2900    for (int i = 0; i < 4; i++) {
2901       if (((info->stream >> (i * 2)) & 3) != stream)
2902          mask &= ~(1 << i);
2903    }
2904 
2905    return mask;
2906 }
2907 
2908 static bool
lower_ngg_gs_emit_vertex_with_counter(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)2909 lower_ngg_gs_emit_vertex_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
2910 {
2911    b->cursor = nir_before_instr(&intrin->instr);
2912 
2913    unsigned stream = nir_intrinsic_stream_id(intrin);
2914    if (!(b->shader->info.gs.active_stream_mask & (1 << stream))) {
2915       nir_instr_remove(&intrin->instr);
2916       return true;
2917    }
2918 
2919    nir_def *gs_emit_vtx_idx = intrin->src[0].ssa;
2920    nir_def *current_vtx_per_prim = intrin->src[1].ssa;
2921    nir_def *gs_emit_vtx_addr = ngg_gs_emit_vertex_addr(b, gs_emit_vtx_idx, s);
2922 
2923    /* Store generic 32-bit outputs to LDS.
2924     * In case of packed 16-bit, we assume that has been already packed into 32 bit slots by now.
2925     */
2926    u_foreach_bit64(slot, b->shader->info.outputs_written) {
2927       const unsigned packed_location = util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
2928       unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], stream);
2929 
2930       nir_def **output = s->out.outputs[slot];
2931       nir_def *undef = nir_undef(b, 1, 32);
2932 
2933       while (mask) {
2934          int start, count;
2935          u_bit_scan_consecutive_range(&mask, &start, &count);
2936          nir_def *values[4] = {0};
2937          for (int c = start; c < start + count; ++c) {
2938             if (!output[c]) {
2939                /* The shader hasn't written this output. */
2940                values[c - start] = undef;
2941             } else {
2942                assert(output[c]->bit_size == 32);
2943                values[c - start] = output[c];
2944             }
2945          }
2946 
2947          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2948          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2949                           .base = packed_location * 16 + start * 4,
2950                           .align_mul = 4);
2951       }
2952 
2953       /* Clear all outputs (they are undefined after emit_vertex) */
2954       memset(s->out.outputs[slot], 0, sizeof(s->out.outputs[slot]));
2955    }
2956 
2957    const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
2958 
2959    /* Store dedicated 16-bit outputs to LDS. */
2960    u_foreach_bit(slot, b->shader->info.outputs_written_16bit) {
2961       const unsigned packed_location = num_32bit_outputs +
2962          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(slot));
2963 
2964       const unsigned mask_lo = gs_output_component_mask_with_stream(s->out.infos_16bit_lo + slot, stream);
2965       const unsigned mask_hi = gs_output_component_mask_with_stream(s->out.infos_16bit_hi + slot, stream);
2966       unsigned mask = mask_lo | mask_hi;
2967 
2968       nir_def **output_lo = s->out.outputs_16bit_lo[slot];
2969       nir_def **output_hi = s->out.outputs_16bit_hi[slot];
2970       nir_def *undef = nir_undef(b, 1, 16);
2971 
2972       while (mask) {
2973          int start, count;
2974          u_bit_scan_consecutive_range(&mask, &start, &count);
2975          nir_def *values[4] = {0};
2976          for (int c = start; c < start + count; ++c) {
2977             nir_def *lo = output_lo[c] ? output_lo[c] : undef;
2978             nir_def *hi = output_hi[c] ? output_hi[c] : undef;
2979 
2980             values[c - start] = nir_pack_32_2x16_split(b, lo, hi);
2981          }
2982 
2983          nir_def *store_val = nir_vec(b, values, (unsigned)count);
2984          nir_store_shared(b, store_val, gs_emit_vtx_addr,
2985                           .base = packed_location * 16 + start * 4,
2986                           .align_mul = 4);
2987       }
2988 
2989       /* Clear all outputs (they are undefined after emit_vertex) */
2990       memset(s->out.outputs_16bit_lo[slot], 0, sizeof(s->out.outputs_16bit_lo[slot]));
2991       memset(s->out.outputs_16bit_hi[slot], 0, sizeof(s->out.outputs_16bit_hi[slot]));
2992    }
2993 
2994    /* Calculate and store per-vertex primitive flags based on vertex counts:
2995     * - bit 0: whether this vertex finishes a primitive (a real primitive, not the strip)
2996     * - bit 1: whether the primitive index is odd (if we are emitting triangle strips, otherwise always 0)
2997     *          only set when the vertex also finishes the primitive
2998     * - bit 2: whether vertex is live (if culling is enabled: set after culling, otherwise always 1)
2999     */
3000 
3001    nir_def *vertex_live_flag =
3002       !stream && s->options->can_cull
3003          ? nir_ishl_imm(b, nir_b2i32(b, nir_inot(b, nir_load_cull_any_enabled_amd(b))), 2)
3004          : nir_imm_int(b, 0b100);
3005 
3006    nir_def *completes_prim = nir_ige_imm(b, current_vtx_per_prim, s->num_vertices_per_primitive - 1);
3007    nir_def *complete_flag = nir_b2i32(b, completes_prim);
3008 
3009    nir_def *prim_flag = nir_ior(b, vertex_live_flag, complete_flag);
3010    if (s->num_vertices_per_primitive == 3) {
3011       nir_def *odd = nir_iand(b, current_vtx_per_prim, complete_flag);
3012       nir_def *odd_flag = nir_ishl_imm(b, odd, 1);
3013       prim_flag = nir_ior(b, prim_flag, odd_flag);
3014    }
3015 
3016    nir_store_shared(b, nir_u2u8(b, prim_flag), gs_emit_vtx_addr,
3017                     .base = s->lds_offs_primflags + stream,
3018                     .align_mul = 4, .align_offset = stream);
3019 
3020    nir_instr_remove(&intrin->instr);
3021    return true;
3022 }
3023 
3024 static bool
lower_ngg_gs_end_primitive_with_counter(nir_builder * b,nir_intrinsic_instr * intrin,UNUSED lower_ngg_gs_state * s)3025 lower_ngg_gs_end_primitive_with_counter(nir_builder *b, nir_intrinsic_instr *intrin, UNUSED lower_ngg_gs_state *s)
3026 {
3027    b->cursor = nir_before_instr(&intrin->instr);
3028 
3029    /* These are not needed, we can simply remove them */
3030    nir_instr_remove(&intrin->instr);
3031    return true;
3032 }
3033 
3034 static bool
lower_ngg_gs_set_vertex_and_primitive_count(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_gs_state * s)3035 lower_ngg_gs_set_vertex_and_primitive_count(nir_builder *b, nir_intrinsic_instr *intrin, lower_ngg_gs_state *s)
3036 {
3037    b->cursor = nir_before_instr(&intrin->instr);
3038 
3039    unsigned stream = nir_intrinsic_stream_id(intrin);
3040    if (stream > 0 && !(b->shader->info.gs.active_stream_mask & (1 << stream))) {
3041       nir_instr_remove(&intrin->instr);
3042       return true;
3043    }
3044 
3045    s->vertex_count[stream] = intrin->src[0].ssa;
3046    s->primitive_count[stream] = intrin->src[1].ssa;
3047 
3048    /* Clear the primitive flags of non-emitted vertices */
3049    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < b->shader->info.gs.vertices_out)
3050       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
3051 
3052    nir_instr_remove(&intrin->instr);
3053    return true;
3054 }
3055 
3056 static bool
lower_ngg_gs_intrinsic(nir_builder * b,nir_instr * instr,void * state)3057 lower_ngg_gs_intrinsic(nir_builder *b, nir_instr *instr, void *state)
3058 {
3059    lower_ngg_gs_state *s = (lower_ngg_gs_state *) state;
3060 
3061    if (instr->type != nir_instr_type_intrinsic)
3062       return false;
3063 
3064    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
3065 
3066    if (intrin->intrinsic == nir_intrinsic_store_output)
3067       return lower_ngg_gs_store_output(b, intrin, s);
3068    else if (intrin->intrinsic == nir_intrinsic_emit_vertex_with_counter)
3069       return lower_ngg_gs_emit_vertex_with_counter(b, intrin, s);
3070    else if (intrin->intrinsic == nir_intrinsic_end_primitive_with_counter)
3071       return lower_ngg_gs_end_primitive_with_counter(b, intrin, s);
3072    else if (intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count)
3073       return lower_ngg_gs_set_vertex_and_primitive_count(b, intrin, s);
3074 
3075    return false;
3076 }
3077 
3078 static void
lower_ngg_gs_intrinsics(nir_shader * shader,lower_ngg_gs_state * s)3079 lower_ngg_gs_intrinsics(nir_shader *shader, lower_ngg_gs_state *s)
3080 {
3081    nir_shader_instructions_pass(shader, lower_ngg_gs_intrinsic, nir_metadata_none, s);
3082 }
3083 
3084 static void
ngg_gs_export_primitives(nir_builder * b,nir_def * max_num_out_prims,nir_def * tid_in_tg,nir_def * exporter_tid_in_tg,nir_def * primflag_0,lower_ngg_gs_state * s)3085 ngg_gs_export_primitives(nir_builder *b, nir_def *max_num_out_prims, nir_def *tid_in_tg,
3086                          nir_def *exporter_tid_in_tg, nir_def *primflag_0,
3087                          lower_ngg_gs_state *s)
3088 {
3089    nir_if *if_prim_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_prims));
3090 
3091    /* Only bit 0 matters here - set it to 1 when the primitive should be null */
3092    nir_def *is_null_prim = nir_ixor(b, primflag_0, nir_imm_int(b, -1u));
3093 
3094    nir_def *vtx_indices[3] = {0};
3095    vtx_indices[s->num_vertices_per_primitive - 1] = exporter_tid_in_tg;
3096    if (s->num_vertices_per_primitive >= 2)
3097       vtx_indices[s->num_vertices_per_primitive - 2] = nir_iadd_imm(b, exporter_tid_in_tg, -1);
3098    if (s->num_vertices_per_primitive == 3)
3099       vtx_indices[s->num_vertices_per_primitive - 3] = nir_iadd_imm(b, exporter_tid_in_tg, -2);
3100 
3101    if (s->num_vertices_per_primitive == 3) {
3102       /* API GS outputs triangle strips, but NGG HW understands triangles.
3103        * We already know the triangles due to how we set the primitive flags, but we need to
3104        * make sure the vertex order is so that the front/back is correct, and the provoking vertex is kept.
3105        */
3106 
3107       nir_def *is_odd = nir_ubfe_imm(b, primflag_0, 1, 1);
3108       nir_def *provoking_vertex_index = nir_load_provoking_vtx_in_prim_amd(b);
3109       nir_def *provoking_vertex_first = nir_ieq_imm(b, provoking_vertex_index, 0);
3110 
3111       vtx_indices[0] = nir_bcsel(b, provoking_vertex_first, vtx_indices[0],
3112                                  nir_iadd(b, vtx_indices[0], is_odd));
3113       vtx_indices[1] = nir_bcsel(b, provoking_vertex_first,
3114                                  nir_iadd(b, vtx_indices[1], is_odd),
3115                                  nir_isub(b, vtx_indices[1], is_odd));
3116       vtx_indices[2] = nir_bcsel(b, provoking_vertex_first,
3117                                  nir_isub(b, vtx_indices[2], is_odd), vtx_indices[2]);
3118    }
3119 
3120    nir_def *arg = emit_pack_ngg_prim_exp_arg(b, s->num_vertices_per_primitive, vtx_indices,
3121                                              is_null_prim, s->options->gfx_level);
3122    ac_nir_export_primitive(b, arg, NULL);
3123    nir_pop_if(b, if_prim_export_thread);
3124 }
3125 
3126 static void
ngg_gs_export_vertices(nir_builder * b,nir_def * max_num_out_vtx,nir_def * tid_in_tg,nir_def * out_vtx_lds_addr,lower_ngg_gs_state * s)3127 ngg_gs_export_vertices(nir_builder *b, nir_def *max_num_out_vtx, nir_def *tid_in_tg,
3128                        nir_def *out_vtx_lds_addr, lower_ngg_gs_state *s)
3129 {
3130    nir_if *if_vtx_export_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3131    nir_def *exported_out_vtx_lds_addr = out_vtx_lds_addr;
3132 
3133    if (!s->output_compile_time_known) {
3134       /* Vertex compaction.
3135        * The current thread will export a vertex that was live in another invocation.
3136        * Load the index of the vertex that the current thread will have to export.
3137        */
3138       nir_def *exported_vtx_idx = nir_load_shared(b, 1, 8, out_vtx_lds_addr, .base = s->lds_offs_primflags + 1);
3139       exported_out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, nir_u2u32(b, exported_vtx_idx), s);
3140    }
3141 
3142    u_foreach_bit64(slot, b->shader->info.outputs_written) {
3143       const unsigned packed_location =
3144          util_bitcount64((b->shader->info.outputs_written & BITFIELD64_MASK(slot)));
3145 
3146       unsigned mask = gs_output_component_mask_with_stream(&s->out.infos[slot], 0);
3147 
3148       while (mask) {
3149          int start, count;
3150          u_bit_scan_consecutive_range(&mask, &start, &count);
3151          nir_def *load =
3152             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3153                             .base = packed_location * 16 + start * 4,
3154                             .align_mul = 4);
3155 
3156          for (int i = 0; i < count; i++)
3157             s->out.outputs[slot][start + i] = nir_channel(b, load, i);
3158       }
3159    }
3160 
3161    const unsigned num_32bit_outputs = util_bitcount64(b->shader->info.outputs_written);
3162 
3163    /* Dedicated 16-bit outputs. */
3164    u_foreach_bit(i, b->shader->info.outputs_written_16bit) {
3165       const unsigned packed_location = num_32bit_outputs +
3166          util_bitcount(b->shader->info.outputs_written_16bit & BITFIELD_MASK(i));
3167 
3168       const unsigned mask_lo = gs_output_component_mask_with_stream(&s->out.infos_16bit_lo[i], 0);
3169       const unsigned mask_hi = gs_output_component_mask_with_stream(&s->out.infos_16bit_hi[i], 0);
3170       unsigned mask = mask_lo | mask_hi;
3171 
3172       while (mask) {
3173          int start, count;
3174          u_bit_scan_consecutive_range(&mask, &start, &count);
3175          nir_def *load =
3176             nir_load_shared(b, count, 32, exported_out_vtx_lds_addr,
3177                             .base = packed_location * 16 + start * 4,
3178                             .align_mul = 4);
3179 
3180          for (int j = 0; j < count; j++) {
3181             nir_def *val = nir_channel(b, load, j);
3182             unsigned comp = start + j;
3183 
3184             if (mask_lo & BITFIELD_BIT(comp))
3185                s->out.outputs_16bit_lo[i][comp] = nir_unpack_32_2x16_split_x(b, val);
3186 
3187             if (mask_hi & BITFIELD_BIT(comp))
3188                s->out.outputs_16bit_hi[i][comp] = nir_unpack_32_2x16_split_y(b, val);
3189          }
3190       }
3191    }
3192 
3193    uint64_t export_outputs = b->shader->info.outputs_written | VARYING_BIT_POS;
3194    if (s->options->kill_pointsize)
3195       export_outputs &= ~VARYING_BIT_PSIZ;
3196    if (s->options->kill_layer)
3197       export_outputs &= ~VARYING_BIT_LAYER;
3198 
3199    const bool wait_attr_ring = must_wait_attr_ring(s->options->gfx_level, s->options->has_param_exports);
3200    if (wait_attr_ring)
3201       export_outputs &= ~VARYING_BIT_POS;
3202 
3203    ac_nir_export_position(b, s->options->gfx_level,
3204                           s->options->clip_cull_dist_mask,
3205                           !s->options->has_param_exports,
3206                           s->options->force_vrs, !wait_attr_ring,
3207                           export_outputs, s->out.outputs, NULL);
3208 
3209    nir_pop_if(b, if_vtx_export_thread);
3210 
3211    if (s->options->has_param_exports) {
3212       b->cursor = nir_after_cf_list(&if_vtx_export_thread->then_list);
3213 
3214       if (s->options->gfx_level >= GFX11) {
3215          vs_output outputs[64];
3216          unsigned num_outputs = gather_vs_outputs(b, outputs,
3217                                                   s->options->vs_output_param_offset,
3218                                                   s->out.outputs, s->out.outputs_16bit_lo,
3219                                                   s->out.outputs_16bit_hi);
3220 
3221          if (num_outputs) {
3222             b->cursor = nir_after_impl(s->impl);
3223             create_vertex_param_phis(b, num_outputs, outputs);
3224 
3225             export_vertex_params_gfx11(b, tid_in_tg, max_num_out_vtx, num_outputs, outputs,
3226                                        s->options->vs_output_param_offset);
3227          }
3228       } else {
3229          ac_nir_export_parameters(b, s->options->vs_output_param_offset,
3230                                   b->shader->info.outputs_written,
3231                                   b->shader->info.outputs_written_16bit,
3232                                   s->out.outputs, s->out.outputs_16bit_lo,
3233                                   s->out.outputs_16bit_hi);
3234       }
3235    }
3236 
3237    if (wait_attr_ring)
3238       export_pos0_wait_attr_ring(b, if_vtx_export_thread, s->out.outputs, s->options);
3239 }
3240 
3241 static void
ngg_gs_setup_vertex_compaction(nir_builder * b,nir_def * vertex_live,nir_def * tid_in_tg,nir_def * exporter_tid_in_tg,lower_ngg_gs_state * s)3242 ngg_gs_setup_vertex_compaction(nir_builder *b, nir_def *vertex_live, nir_def *tid_in_tg,
3243                                nir_def *exporter_tid_in_tg, lower_ngg_gs_state *s)
3244 {
3245    assert(vertex_live->bit_size == 1);
3246    nir_if *if_vertex_live = nir_push_if(b, vertex_live);
3247    {
3248       /* Setup the vertex compaction.
3249        * Save the current thread's id for the thread which will export the current vertex.
3250        * We reuse stream 1 of the primitive flag of the other thread's vertex for storing this.
3251        */
3252 
3253       nir_def *exporter_lds_addr = ngg_gs_out_vertex_addr(b, exporter_tid_in_tg, s);
3254       nir_def *tid_in_tg_u8 = nir_u2u8(b, tid_in_tg);
3255       nir_store_shared(b, tid_in_tg_u8, exporter_lds_addr, .base = s->lds_offs_primflags + 1);
3256    }
3257    nir_pop_if(b, if_vertex_live);
3258 }
3259 
3260 static nir_def *
ngg_gs_load_out_vtx_primflag(nir_builder * b,unsigned stream,nir_def * tid_in_tg,nir_def * vtx_lds_addr,nir_def * max_num_out_vtx,lower_ngg_gs_state * s)3261 ngg_gs_load_out_vtx_primflag(nir_builder *b, unsigned stream, nir_def *tid_in_tg,
3262                              nir_def *vtx_lds_addr, nir_def *max_num_out_vtx,
3263                              lower_ngg_gs_state *s)
3264 {
3265    nir_def *zero = nir_imm_int(b, 0);
3266 
3267    nir_if *if_outvtx_thread = nir_push_if(b, nir_ilt(b, tid_in_tg, max_num_out_vtx));
3268    nir_def *primflag = nir_load_shared(b, 1, 8, vtx_lds_addr,
3269                                            .base = s->lds_offs_primflags + stream);
3270    primflag = nir_u2u32(b, primflag);
3271    nir_pop_if(b, if_outvtx_thread);
3272 
3273    return nir_if_phi(b, primflag, zero);
3274 }
3275 
3276 static void
ngg_gs_out_prim_all_vtxptr(nir_builder * b,nir_def * last_vtxidx,nir_def * last_vtxptr,nir_def * last_vtx_primflag,lower_ngg_gs_state * s,nir_def * vtxptr[3])3277 ngg_gs_out_prim_all_vtxptr(nir_builder *b, nir_def *last_vtxidx, nir_def *last_vtxptr,
3278                            nir_def *last_vtx_primflag, lower_ngg_gs_state *s,
3279                            nir_def *vtxptr[3])
3280 {
3281    unsigned last_vtx = s->num_vertices_per_primitive - 1;
3282    vtxptr[last_vtx]= last_vtxptr;
3283 
3284    bool primitive_is_triangle = s->num_vertices_per_primitive == 3;
3285    nir_def *is_odd = primitive_is_triangle ?
3286       nir_ubfe_imm(b, last_vtx_primflag, 1, 1) : NULL;
3287 
3288    for (unsigned i = 0; i < s->num_vertices_per_primitive - 1; i++) {
3289       nir_def *vtxidx = nir_iadd_imm(b, last_vtxidx, -(last_vtx - i));
3290 
3291       /* Need to swap vertex 0 and vertex 1 when vertex 2 index is odd to keep
3292        * CW/CCW order for correct front/back face culling.
3293        */
3294       if (primitive_is_triangle)
3295          vtxidx = i == 0 ? nir_iadd(b, vtxidx, is_odd) : nir_isub(b, vtxidx, is_odd);
3296 
3297       vtxptr[i] = ngg_gs_out_vertex_addr(b, vtxidx, s);
3298    }
3299 }
3300 
3301 static nir_def *
ngg_gs_cull_primitive(nir_builder * b,nir_def * tid_in_tg,nir_def * max_vtxcnt,nir_def * out_vtx_lds_addr,nir_def * out_vtx_primflag_0,lower_ngg_gs_state * s)3302 ngg_gs_cull_primitive(nir_builder *b, nir_def *tid_in_tg, nir_def *max_vtxcnt,
3303                       nir_def *out_vtx_lds_addr, nir_def *out_vtx_primflag_0,
3304                       lower_ngg_gs_state *s)
3305 {
3306    /* we haven't enabled point culling, if enabled this function could be further optimized */
3307    assert(s->num_vertices_per_primitive > 1);
3308 
3309    /* save the primflag so that we don't need to load it from LDS again */
3310    nir_variable *primflag_var = nir_local_variable_create(s->impl, glsl_uint_type(), "primflag");
3311    nir_store_var(b, primflag_var, out_vtx_primflag_0, 1);
3312 
3313    /* last bit of primflag indicate if this is the final vertex of a primitive */
3314    nir_def *is_end_prim_vtx = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag_0, 1));
3315    nir_def *has_output_vertex = nir_ilt(b, tid_in_tg, max_vtxcnt);
3316    nir_def *prim_enable = nir_iand(b, is_end_prim_vtx, has_output_vertex);
3317 
3318    nir_if *if_prim_enable = nir_push_if(b, prim_enable);
3319    {
3320       /* Calculate the LDS address of every vertex in the current primitive. */
3321       nir_def *vtxptr[3];
3322       ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr, out_vtx_primflag_0, s, vtxptr);
3323 
3324       /* Load the positions from LDS. */
3325       nir_def *pos[3][4];
3326       for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3327          /* VARYING_SLOT_POS == 0, so base won't count packed location */
3328          pos[i][3] = nir_load_shared(b, 1, 32, vtxptr[i], .base = 12); /* W */
3329          nir_def *xy = nir_load_shared(b, 2, 32, vtxptr[i], .base = 0, .align_mul = 4);
3330          pos[i][0] = nir_channel(b, xy, 0);
3331          pos[i][1] = nir_channel(b, xy, 1);
3332 
3333          pos[i][0] = nir_fdiv(b, pos[i][0], pos[i][3]);
3334          pos[i][1] = nir_fdiv(b, pos[i][1], pos[i][3]);
3335       }
3336 
3337       /* TODO: support clipdist culling in GS */
3338       nir_def *accepted_by_clipdist = nir_imm_true(b);
3339 
3340       nir_def *accepted = ac_nir_cull_primitive(
3341          b, accepted_by_clipdist, pos, s->num_vertices_per_primitive, NULL, NULL);
3342 
3343       nir_if *if_rejected = nir_push_if(b, nir_inot(b, accepted));
3344       {
3345          /* clear the primflag if rejected */
3346          nir_store_shared(b, nir_imm_zero(b, 1, 8), out_vtx_lds_addr,
3347                           .base = s->lds_offs_primflags);
3348 
3349          nir_store_var(b, primflag_var, nir_imm_int(b, 0), 1);
3350       }
3351       nir_pop_if(b, if_rejected);
3352    }
3353    nir_pop_if(b, if_prim_enable);
3354 
3355    /* Wait for LDS primflag access done. */
3356    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3357                          .memory_scope = SCOPE_WORKGROUP,
3358                          .memory_semantics = NIR_MEMORY_ACQ_REL,
3359                          .memory_modes = nir_var_mem_shared);
3360 
3361    /* only dead vertex need a chance to relive */
3362    nir_def *vtx_is_dead = nir_ieq_imm(b, nir_load_var(b, primflag_var), 0);
3363    nir_def *vtx_update_primflag = nir_iand(b, vtx_is_dead, has_output_vertex);
3364    nir_if *if_update_primflag = nir_push_if(b, vtx_update_primflag);
3365    {
3366       /* get succeeding vertices' primflag to detect this vertex's liveness */
3367       for (unsigned i = 1; i < s->num_vertices_per_primitive; i++) {
3368          nir_def *vtxidx = nir_iadd_imm(b, tid_in_tg, i);
3369          nir_def *not_overflow = nir_ilt(b, vtxidx, max_vtxcnt);
3370          nir_if *if_not_overflow = nir_push_if(b, not_overflow);
3371          {
3372             nir_def *vtxptr = ngg_gs_out_vertex_addr(b, vtxidx, s);
3373             nir_def *vtx_primflag =
3374                nir_load_shared(b, 1, 8, vtxptr, .base = s->lds_offs_primflags);
3375             vtx_primflag = nir_u2u32(b, vtx_primflag);
3376 
3377             /* if succeeding vertex is alive end of primitive vertex, need to set current
3378              * thread vertex's liveness flag (bit 2)
3379              */
3380             nir_def *has_prim = nir_i2b(b, nir_iand_imm(b, vtx_primflag, 1));
3381             nir_def *vtx_live_flag =
3382                nir_bcsel(b, has_prim, nir_imm_int(b, 0b100), nir_imm_int(b, 0));
3383 
3384             /* update this vertex's primflag */
3385             nir_def *primflag = nir_load_var(b, primflag_var);
3386             primflag = nir_ior(b, primflag, vtx_live_flag);
3387             nir_store_var(b, primflag_var, primflag, 1);
3388          }
3389          nir_pop_if(b, if_not_overflow);
3390       }
3391    }
3392    nir_pop_if(b, if_update_primflag);
3393 
3394    return nir_load_var(b, primflag_var);
3395 }
3396 
3397 static void
ngg_gs_build_streamout(nir_builder * b,lower_ngg_gs_state * s)3398 ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *s)
3399 {
3400    nir_xfb_info *info = b->shader->xfb_info;
3401 
3402    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3403    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3404    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3405    nir_def *prim_live[4] = {0};
3406    nir_def *gen_prim[4] = {0};
3407    nir_def *export_seq[4] = {0};
3408    nir_def *out_vtx_primflag[4] = {0};
3409    for (unsigned stream = 0; stream < 4; stream++) {
3410       if (!(info->streams_written & BITFIELD_BIT(stream)))
3411          continue;
3412 
3413       out_vtx_primflag[stream] =
3414          ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3415 
3416       /* Check bit 0 of primflag for primitive alive, it's set for every last
3417        * vertex of a primitive.
3418        */
3419       prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
3420 
3421       unsigned scratch_stride = ALIGN(s->max_num_waves, 4);
3422       nir_def *scratch_base =
3423          nir_iadd_imm(b, s->lds_addr_gs_scratch, stream * scratch_stride);
3424 
3425       /* We want to export primitives to streamout buffer in sequence,
3426        * but not all vertices are alive or mark end of a primitive, so
3427        * there're "holes". We don't need continuous invocations to write
3428        * primitives to streamout buffer like final vertex export, so
3429        * just repack to get the sequence (export_seq) is enough, no need
3430        * to do compaction.
3431        *
3432        * Use separate scratch space for each stream to avoid barrier.
3433        * TODO: we may further reduce barriers by writing to all stream
3434        * LDS at once, then we only need one barrier instead of one each
3435        * stream..
3436        */
3437       wg_repack_result rep =
3438          repack_invocations_in_workgroup(b, prim_live[stream], scratch_base,
3439                                          s->max_num_waves, s->options->wave_size);
3440 
3441       /* nir_intrinsic_set_vertex_and_primitive_count can also get primitive count of
3442        * current wave, but still need LDS to sum all wave's count to get workgroup count.
3443        * And we need repack to export primitive to streamout buffer anyway, so do here.
3444        */
3445       gen_prim[stream] = rep.num_repacked_invocations;
3446       export_seq[stream] = rep.repacked_invocation_index;
3447    }
3448 
3449    /* Workgroup barrier: wait for LDS scratch reads finish. */
3450    nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
3451                       .memory_scope = SCOPE_WORKGROUP,
3452                       .memory_semantics = NIR_MEMORY_ACQ_REL,
3453                       .memory_modes = nir_var_mem_shared);
3454 
3455    /* Get global buffer offset where this workgroup will stream out data to. */
3456    nir_def *emit_prim[4] = {0};
3457    nir_def *buffer_offsets[4] = {0};
3458    nir_def *so_buffer[4] = {0};
3459    nir_def *prim_stride[4] = {0};
3460    ngg_build_streamout_buffer_info(b, info, s->options->gfx_level, s->options->has_xfb_prim_query,
3461                                    s->options->use_gfx12_xfb_intrinsic, s->lds_addr_gs_scratch, tid_in_tg,
3462                                    gen_prim, prim_stride, so_buffer, buffer_offsets, emit_prim);
3463 
3464    for (unsigned stream = 0; stream < 4; stream++) {
3465       if (!(info->streams_written & BITFIELD_BIT(stream)))
3466          continue;
3467 
3468       nir_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
3469       nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
3470       {
3471          /* Get streamout buffer vertex index for the first vertex of this primitive. */
3472          nir_def *vtx_buffer_idx =
3473             nir_imul_imm(b, export_seq[stream], s->num_vertices_per_primitive);
3474 
3475          /* Get all vertices' lds address of this primitive. */
3476          nir_def *exported_vtx_lds_addr[3];
3477          ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
3478                                     out_vtx_primflag[stream], s,
3479                                     exported_vtx_lds_addr);
3480 
3481          /* Write all vertices of this primitive to streamout buffer. */
3482          for (unsigned i = 0; i < s->num_vertices_per_primitive; i++) {
3483             ngg_build_streamout_vertex(b, info, stream, so_buffer,
3484                                        buffer_offsets,
3485                                        nir_iadd_imm(b, vtx_buffer_idx, i),
3486                                        exported_vtx_lds_addr[i],
3487                                        &s->out, false);
3488          }
3489       }
3490       nir_pop_if(b, if_emit);
3491    }
3492 }
3493 
3494 static void
ngg_gs_finale(nir_builder * b,lower_ngg_gs_state * s)3495 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
3496 {
3497    nir_def *tid_in_tg = nir_load_local_invocation_index(b);
3498    nir_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
3499    nir_def *max_prmcnt = max_vtxcnt; /* They are currently practically the same; both RADV and RadeonSI do this. */
3500    nir_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, s);
3501 
3502    if (s->output_compile_time_known) {
3503       /* When the output is compile-time known, the GS writes all possible vertices and primitives it can.
3504        * The gs_alloc_req needs to happen on one wave only, otherwise the HW hangs.
3505        */
3506       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3507       alloc_vertices_and_primitives(b, max_vtxcnt, max_prmcnt);
3508       nir_pop_if(b, if_wave_0);
3509    }
3510 
3511    /* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
3512 
3513    nir_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
3514 
3515    if (s->output_compile_time_known) {
3516       ngg_gs_export_primitives(b, max_vtxcnt, tid_in_tg, tid_in_tg, out_vtx_primflag_0, s);
3517       ngg_gs_export_vertices(b, max_vtxcnt, tid_in_tg, out_vtx_lds_addr, s);
3518       return;
3519    }
3520 
3521    /* cull primitives */
3522    if (s->options->can_cull) {
3523       nir_if *if_cull_en = nir_push_if(b, nir_load_cull_any_enabled_amd(b));
3524 
3525       /* culling code will update the primflag */
3526       nir_def *updated_primflag =
3527          ngg_gs_cull_primitive(b, tid_in_tg, max_vtxcnt, out_vtx_lds_addr,
3528                                out_vtx_primflag_0, s);
3529 
3530       nir_pop_if(b, if_cull_en);
3531 
3532       out_vtx_primflag_0 = nir_if_phi(b, updated_primflag, out_vtx_primflag_0);
3533    }
3534 
3535    /* When the output vertex count is not known at compile time:
3536     * There may be gaps between invocations that have live vertices, but NGG hardware
3537     * requires that the invocations that export vertices are packed (ie. compact).
3538     * To ensure this, we need to repack invocations that have a live vertex.
3539     */
3540    nir_def *vertex_live = nir_ine_imm(b, out_vtx_primflag_0, 0);
3541    wg_repack_result rep = repack_invocations_in_workgroup(b, vertex_live, s->lds_addr_gs_scratch,
3542                                                           s->max_num_waves, s->options->wave_size);
3543 
3544    nir_def *workgroup_num_vertices = rep.num_repacked_invocations;
3545    nir_def *exporter_tid_in_tg = rep.repacked_invocation_index;
3546 
3547    /* When the workgroup emits 0 total vertices, we also must export 0 primitives (otherwise the HW can hang). */
3548    nir_def *any_output = nir_ine_imm(b, workgroup_num_vertices, 0);
3549    max_prmcnt = nir_bcsel(b, any_output, max_prmcnt, nir_imm_int(b, 0));
3550 
3551    /* Allocate export space. We currently don't compact primitives, just use the maximum number. */
3552    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
3553    {
3554       if (s->options->gfx_level == GFX10)
3555          alloc_vertices_and_primitives_gfx10_workaround(b, workgroup_num_vertices, max_prmcnt);
3556       else
3557          alloc_vertices_and_primitives(b, workgroup_num_vertices, max_prmcnt);
3558    }
3559    nir_pop_if(b, if_wave_0);
3560 
3561    /* Vertex compaction. This makes sure there are no gaps between threads that export vertices. */
3562    ngg_gs_setup_vertex_compaction(b, vertex_live, tid_in_tg, exporter_tid_in_tg, s);
3563 
3564    /* Workgroup barrier: wait for all LDS stores to finish. */
3565    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3566                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3567 
3568    ngg_gs_export_primitives(b, max_prmcnt, tid_in_tg, exporter_tid_in_tg, out_vtx_primflag_0, s);
3569    ngg_gs_export_vertices(b, workgroup_num_vertices, tid_in_tg, out_vtx_lds_addr, s);
3570 }
3571 
3572 void
ac_nir_lower_ngg_gs(nir_shader * shader,const ac_nir_lower_ngg_options * options)3573 ac_nir_lower_ngg_gs(nir_shader *shader, const ac_nir_lower_ngg_options *options)
3574 {
3575    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
3576    assert(impl);
3577 
3578    lower_ngg_gs_state state = {
3579       .options = options,
3580       .impl = impl,
3581       .max_num_waves = DIV_ROUND_UP(options->max_workgroup_size, options->wave_size),
3582       .lds_offs_primflags = options->gs_out_vtx_bytes,
3583       .lds_bytes_per_gs_out_vertex = options->gs_out_vtx_bytes + 4u,
3584       .streamout_enabled = shader->xfb_info && !options->disable_streamout,
3585    };
3586 
3587    if (!options->can_cull) {
3588       nir_gs_count_vertices_and_primitives(shader, state.const_out_vtxcnt,
3589                                            state.const_out_prmcnt, NULL, 4u);
3590       state.output_compile_time_known =
3591          state.const_out_vtxcnt[0] == shader->info.gs.vertices_out &&
3592          state.const_out_prmcnt[0] != -1;
3593    }
3594 
3595    if (shader->info.gs.output_primitive == MESA_PRIM_POINTS)
3596       state.num_vertices_per_primitive = 1;
3597    else if (shader->info.gs.output_primitive == MESA_PRIM_LINE_STRIP)
3598       state.num_vertices_per_primitive = 2;
3599    else if (shader->info.gs.output_primitive == MESA_PRIM_TRIANGLE_STRIP)
3600       state.num_vertices_per_primitive = 3;
3601    else
3602       unreachable("Invalid GS output primitive.");
3603 
3604    /* Extract the full control flow. It is going to be wrapped in an if statement. */
3605    nir_cf_list extracted;
3606    nir_cf_extract(&extracted, nir_before_impl(impl),
3607                   nir_after_impl(impl));
3608 
3609    nir_builder builder = nir_builder_at(nir_before_impl(impl));
3610    nir_builder *b = &builder; /* This is to avoid the & */
3611 
3612    /* Workgroup barrier: wait for ES threads */
3613    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3614                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3615 
3616    state.lds_addr_gs_out_vtx = nir_load_lds_ngg_gs_out_vertex_base_amd(b);
3617    state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
3618 
3619    /* Wrap the GS control flow. */
3620    nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
3621 
3622    nir_cf_reinsert(&extracted, b->cursor);
3623    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3624    nir_pop_if(b, if_gs_thread);
3625 
3626    /* Workgroup barrier: wait for all GS threads to finish */
3627    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
3628                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
3629 
3630    if (state.streamout_enabled)
3631       ngg_gs_build_streamout(b, &state);
3632 
3633    /* Lower the GS intrinsics */
3634    lower_ngg_gs_intrinsics(shader, &state);
3635 
3636    if (!state.vertex_count[0]) {
3637       fprintf(stderr, "Could not find set_vertex_and_primitive_count for stream 0. This would hang your GPU.");
3638       abort();
3639    }
3640 
3641    /* Emit shader queries */
3642    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
3643    ac_nir_gs_shader_query(b,
3644                           state.options->has_gen_prim_query,
3645                           state.options->has_gs_invocations_query,
3646                           state.options->has_gs_primitives_query,
3647                           state.num_vertices_per_primitive,
3648                           state.options->wave_size,
3649                           state.vertex_count,
3650                           state.primitive_count);
3651 
3652    b->cursor = nir_after_impl(impl);
3653 
3654    /* Emit the finale sequence */
3655    ngg_gs_finale(b, &state);
3656    nir_validate_shader(shader, "after emitting NGG GS");
3657 
3658    /* Cleanup */
3659    nir_lower_vars_to_ssa(shader);
3660    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
3661    nir_metadata_preserve(impl, nir_metadata_none);
3662 }
3663 
3664 unsigned
ac_ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,unsigned shader_num_outputs,bool streamout_enabled,bool export_prim_id,bool has_user_edgeflags,bool can_cull,bool uses_instance_id,bool uses_primitive_id)3665 ac_ngg_nogs_get_pervertex_lds_size(gl_shader_stage stage,
3666                                    unsigned shader_num_outputs,
3667                                    bool streamout_enabled,
3668                                    bool export_prim_id,
3669                                    bool has_user_edgeflags,
3670                                    bool can_cull,
3671                                    bool uses_instance_id,
3672                                    bool uses_primitive_id)
3673 {
3674    /* for culling time lds layout only */
3675    unsigned culling_pervertex_lds_bytes = can_cull ?
3676       ngg_nogs_get_culling_pervertex_lds_size(
3677          stage, uses_instance_id, uses_primitive_id, NULL) : 0;
3678 
3679    unsigned pervertex_lds_bytes =
3680       ngg_nogs_get_pervertex_lds_size(stage, shader_num_outputs, streamout_enabled,
3681                                       export_prim_id, has_user_edgeflags);
3682 
3683    return MAX2(culling_pervertex_lds_bytes, pervertex_lds_bytes);
3684 }
3685 
3686 unsigned
ac_ngg_get_scratch_lds_size(gl_shader_stage stage,unsigned workgroup_size,unsigned wave_size,bool streamout_enabled,bool can_cull)3687 ac_ngg_get_scratch_lds_size(gl_shader_stage stage,
3688                             unsigned workgroup_size,
3689                             unsigned wave_size,
3690                             bool streamout_enabled,
3691                             bool can_cull)
3692 {
3693    unsigned scratch_lds_size = 0;
3694    unsigned max_num_waves = DIV_ROUND_UP(workgroup_size, wave_size);
3695 
3696    if (stage == MESA_SHADER_VERTEX || stage == MESA_SHADER_TESS_EVAL) {
3697       if (streamout_enabled) {
3698          /* 4 dwords for 4 streamout buffer offset, 1 dword for emit prim count */
3699          scratch_lds_size = 20;
3700       } else if (can_cull) {
3701          scratch_lds_size = ALIGN(max_num_waves, 4u);
3702       }
3703    } else {
3704       assert(stage == MESA_SHADER_GEOMETRY);
3705 
3706       scratch_lds_size = ALIGN(max_num_waves, 4u);
3707       /* streamout take 8 dwords for buffer offset and emit vertex per stream */
3708       if (streamout_enabled)
3709          scratch_lds_size = MAX2(scratch_lds_size, 32);
3710    }
3711 
3712    return scratch_lds_size;
3713 }
3714 
3715 static void
ms_store_prim_indices(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)3716 ms_store_prim_indices(nir_builder *b,
3717                       nir_intrinsic_instr *intrin,
3718                       lower_ngg_ms_state *s)
3719 {
3720    /* EXT_mesh_shader primitive indices: array of vectors.
3721     * They don't count as per-primitive outputs, but the array is indexed
3722     * by the primitive index, so they are practically per-primitive.
3723     */
3724    assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
3725    assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
3726 
3727    const unsigned component_offset = nir_intrinsic_component(intrin);
3728    nir_def *store_val = intrin->src[0].ssa;
3729    assert(store_val->num_components <= 3);
3730 
3731    if (store_val->num_components > s->vertices_per_prim)
3732       store_val = nir_trim_vector(b, store_val, s->vertices_per_prim);
3733 
3734    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
3735       for (unsigned c = 0; c < store_val->num_components; ++c) {
3736          const unsigned i = VARYING_SLOT_PRIMITIVE_INDICES * 4 + c + component_offset;
3737          nir_store_var(b, s->out_variables[i], nir_channel(b, store_val, c), 0x1);
3738       }
3739       return;
3740    }
3741 
3742    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3743    nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
3744 
3745    /* The max vertex count is 256, so these indices always fit 8 bits.
3746     * To reduce LDS use, store these as a flat array of 8-bit values.
3747     */
3748    nir_store_shared(b, nir_u2u8(b, store_val), offset, .base = s->layout.lds.indices_addr + component_offset);
3749 }
3750 
3751 static void
ms_store_cull_flag(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)3752 ms_store_cull_flag(nir_builder *b,
3753                    nir_intrinsic_instr *intrin,
3754                    lower_ngg_ms_state *s)
3755 {
3756    /* EXT_mesh_shader cull primitive: per-primitive bool. */
3757    assert(nir_src_is_const(*nir_get_io_offset_src(intrin)));
3758    assert(nir_src_as_uint(*nir_get_io_offset_src(intrin)) == 0);
3759    assert(nir_intrinsic_component(intrin) == 0);
3760    assert(nir_intrinsic_write_mask(intrin) == 1);
3761 
3762    nir_def *store_val = intrin->src[0].ssa;
3763 
3764    assert(store_val->num_components == 1);
3765    assert(store_val->bit_size == 1);
3766 
3767    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE)) {
3768       nir_store_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4], nir_b2i32(b, store_val), 0x1);
3769       return;
3770    }
3771 
3772    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3773    nir_def *offset = nir_imul_imm(b, arr_index, s->vertices_per_prim);
3774 
3775    /* To reduce LDS use, store these as an array of 8-bit values. */
3776    nir_store_shared(b, nir_b2i8(b, store_val), offset, .base = s->layout.lds.cull_flags_addr);
3777 }
3778 
3779 static nir_def *
ms_arrayed_output_base_addr(nir_builder * b,nir_def * arr_index,unsigned mapped_location,unsigned num_arrayed_outputs)3780 ms_arrayed_output_base_addr(nir_builder *b,
3781                             nir_def *arr_index,
3782                             unsigned mapped_location,
3783                             unsigned num_arrayed_outputs)
3784 {
3785    /* Address offset of the array item (vertex or primitive). */
3786    unsigned arr_index_stride = num_arrayed_outputs * 16u;
3787    nir_def *arr_index_off = nir_imul_imm(b, arr_index, arr_index_stride);
3788 
3789    /* IO address offset within the vertex or primitive data. */
3790    unsigned io_offset = mapped_location * 16u;
3791    nir_def *io_off = nir_imm_int(b, io_offset);
3792 
3793    return nir_iadd_nuw(b, arr_index_off, io_off);
3794 }
3795 
3796 static void
update_ms_output_info_slot(lower_ngg_ms_state * s,unsigned slot,unsigned base_off,uint32_t components_mask)3797 update_ms_output_info_slot(lower_ngg_ms_state *s,
3798                            unsigned slot, unsigned base_off,
3799                            uint32_t components_mask)
3800 {
3801    while (components_mask) {
3802       s->output_info[slot + base_off].components_mask |= components_mask & 0xF;
3803 
3804       components_mask >>= 4;
3805       base_off++;
3806    }
3807 }
3808 
3809 static void
update_ms_output_info(const nir_io_semantics io_sem,const nir_src * base_offset_src,const uint32_t write_mask,const unsigned component_offset,const unsigned bit_size,const ms_out_part * out,lower_ngg_ms_state * s)3810 update_ms_output_info(const nir_io_semantics io_sem,
3811                       const nir_src *base_offset_src,
3812                       const uint32_t write_mask,
3813                       const unsigned component_offset,
3814                       const unsigned bit_size,
3815                       const ms_out_part *out,
3816                       lower_ngg_ms_state *s)
3817 {
3818    uint32_t write_mask_32 = util_widen_mask(write_mask, DIV_ROUND_UP(bit_size, 32));
3819    uint32_t components_mask = write_mask_32 << component_offset;
3820 
3821    if (nir_src_is_const(*base_offset_src)) {
3822       /* Simply mark the components of the current slot as used. */
3823       unsigned base_off = nir_src_as_uint(*base_offset_src);
3824       update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
3825    } else {
3826       /* Indirect offset: mark the components of all slots as used. */
3827       for (unsigned base_off = 0; base_off < io_sem.num_slots; ++base_off)
3828          update_ms_output_info_slot(s, io_sem.location, base_off, components_mask);
3829    }
3830 }
3831 
3832 static const ms_out_part *
ms_get_out_layout_part(unsigned location,shader_info * info,ms_out_mode * out_mode,lower_ngg_ms_state * s)3833 ms_get_out_layout_part(unsigned location,
3834                        shader_info *info,
3835                        ms_out_mode *out_mode,
3836                        lower_ngg_ms_state *s)
3837 {
3838    uint64_t mask = BITFIELD64_BIT(location);
3839 
3840    if (info->per_primitive_outputs & mask) {
3841       if (mask & s->layout.lds.prm_attr.mask) {
3842          *out_mode = ms_out_mode_lds;
3843          return &s->layout.lds.prm_attr;
3844       } else if (mask & s->layout.scratch_ring.prm_attr.mask) {
3845          *out_mode = ms_out_mode_scratch_ring;
3846          return &s->layout.scratch_ring.prm_attr;
3847       } else if (mask & s->layout.attr_ring.prm_attr.mask) {
3848          *out_mode = ms_out_mode_attr_ring;
3849          return &s->layout.attr_ring.prm_attr;
3850       } else if (mask & s->layout.var.prm_attr.mask) {
3851          *out_mode = ms_out_mode_var;
3852          return &s->layout.var.prm_attr;
3853       }
3854    } else {
3855       if (mask & s->layout.lds.vtx_attr.mask) {
3856          *out_mode = ms_out_mode_lds;
3857          return &s->layout.lds.vtx_attr;
3858       } else if (mask & s->layout.scratch_ring.vtx_attr.mask) {
3859          *out_mode = ms_out_mode_scratch_ring;
3860          return &s->layout.scratch_ring.vtx_attr;
3861       } else if (mask & s->layout.attr_ring.vtx_attr.mask) {
3862          *out_mode = ms_out_mode_attr_ring;
3863          return &s->layout.attr_ring.vtx_attr;
3864       } else if (mask & s->layout.var.vtx_attr.mask) {
3865          *out_mode = ms_out_mode_var;
3866          return &s->layout.var.vtx_attr;
3867       }
3868    }
3869 
3870    unreachable("Couldn't figure out mesh shader output mode.");
3871 }
3872 
3873 static void
ms_store_arrayed_output(nir_builder * b,nir_src * base_off_src,nir_def * store_val,nir_def * arr_index,const nir_io_semantics io_sem,const unsigned component_offset,const unsigned write_mask,lower_ngg_ms_state * s)3874 ms_store_arrayed_output(nir_builder *b,
3875                         nir_src *base_off_src,
3876                         nir_def *store_val,
3877                         nir_def *arr_index,
3878                         const nir_io_semantics io_sem,
3879                         const unsigned component_offset,
3880                         const unsigned write_mask,
3881                         lower_ngg_ms_state *s)
3882 {
3883    ms_out_mode out_mode;
3884    const ms_out_part *out = ms_get_out_layout_part(io_sem.location, &b->shader->info, &out_mode, s);
3885    update_ms_output_info(io_sem, base_off_src, write_mask, component_offset, store_val->bit_size, out, s);
3886 
3887    bool hi_16b = io_sem.high_16bits;
3888    bool lo_16b = !hi_16b && store_val->bit_size == 16;
3889 
3890    unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, io_sem.location));
3891    unsigned num_outputs = util_bitcount64(out->mask);
3892    unsigned const_off = out->addr + component_offset * 4 + (hi_16b ? 2 : 0);
3893 
3894    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
3895    nir_def *base_offset = base_off_src->ssa;
3896    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16u);
3897    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
3898 
3899    if (out_mode == ms_out_mode_lds) {
3900       nir_store_shared(b, store_val, addr, .base = const_off,
3901                      .write_mask = write_mask, .align_mul = 16,
3902                      .align_offset = const_off % 16);
3903    } else if (out_mode == ms_out_mode_scratch_ring) {
3904       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
3905       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
3906       nir_def *zero = nir_imm_int(b, 0);
3907       nir_store_buffer_amd(b, store_val, ring, addr, off, zero,
3908                            .base = const_off,
3909                            .write_mask = write_mask,
3910                            .memory_modes = nir_var_shader_out,
3911                            .access = ACCESS_COHERENT);
3912    } else if (out_mode == ms_out_mode_attr_ring) {
3913       /* GFX11+: Store params straight to the attribute ring.
3914        *
3915        * Even though the access pattern may not be the most optimal,
3916        * this is still much better than reserving LDS and losing waves.
3917        * (Also much better than storing and reloading from the scratch ring.)
3918        */
3919       unsigned param_offset = s->vs_output_param_offset[io_sem.location];
3920       nir_def *ring = nir_load_ring_attr_amd(b);
3921       nir_def *soffset = nir_load_ring_attr_offset_amd(b);
3922       nir_store_buffer_amd(b, store_val, ring, base_addr_off, soffset, arr_index,
3923                            .base = const_off + param_offset * 16,
3924                            .write_mask = write_mask,
3925                            .memory_modes = nir_var_shader_out,
3926                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
3927    } else if (out_mode == ms_out_mode_var) {
3928       unsigned write_mask_32 = write_mask;
3929       if (store_val->bit_size > 32) {
3930          /* Split 64-bit store values to 32-bit components. */
3931          store_val = nir_bitcast_vector(b, store_val, 32);
3932          /* Widen the write mask so it is in 32-bit components. */
3933          write_mask_32 = util_widen_mask(write_mask, store_val->bit_size / 32);
3934       }
3935 
3936       u_foreach_bit(comp, write_mask_32) {
3937          unsigned idx = io_sem.location * 4 + comp + component_offset;
3938          nir_def *val = nir_channel(b, store_val, comp);
3939          nir_def *v = nir_load_var(b, s->out_variables[idx]);
3940 
3941          if (lo_16b) {
3942             nir_def *var_hi = nir_unpack_32_2x16_split_y(b, v);
3943             val = nir_pack_32_2x16_split(b, val, var_hi);
3944          } else if (hi_16b) {
3945             nir_def *var_lo = nir_unpack_32_2x16_split_x(b, v);
3946             val = nir_pack_32_2x16_split(b, var_lo, val);
3947          }
3948 
3949          nir_store_var(b, s->out_variables[idx], val, 0x1);
3950       }
3951    } else {
3952       unreachable("Invalid MS output mode for store");
3953    }
3954 }
3955 
3956 static void
ms_store_arrayed_output_intrin(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)3957 ms_store_arrayed_output_intrin(nir_builder *b,
3958                                nir_intrinsic_instr *intrin,
3959                                lower_ngg_ms_state *s)
3960 {
3961    const nir_io_semantics io_sem = nir_intrinsic_io_semantics(intrin);
3962 
3963    if (io_sem.location == VARYING_SLOT_PRIMITIVE_INDICES) {
3964       ms_store_prim_indices(b, intrin, s);
3965       return;
3966    } else if (io_sem.location == VARYING_SLOT_CULL_PRIMITIVE) {
3967       ms_store_cull_flag(b, intrin, s);
3968       return;
3969    }
3970 
3971    unsigned component_offset = nir_intrinsic_component(intrin);
3972    unsigned write_mask = nir_intrinsic_write_mask(intrin);
3973 
3974    nir_def *store_val = intrin->src[0].ssa;
3975    nir_def *arr_index = nir_get_io_arrayed_index_src(intrin)->ssa;
3976    nir_src *base_off_src = nir_get_io_offset_src(intrin);
3977 
3978    if (store_val->bit_size < 32) {
3979       /* Split 16-bit output stores to ensure each 16-bit component is stored
3980        * in the correct location, without overwriting the other 16 bits there.
3981        */
3982       u_foreach_bit(c, write_mask) {
3983          nir_def *store_component = nir_channel(b, store_val, c);
3984          ms_store_arrayed_output(b, base_off_src, store_component, arr_index, io_sem, c + component_offset, 1, s);
3985       }
3986    } else {
3987       ms_store_arrayed_output(b, base_off_src, store_val, arr_index, io_sem, component_offset, write_mask, s);
3988    }
3989 }
3990 
3991 static nir_def *
ms_load_arrayed_output(nir_builder * b,nir_def * arr_index,nir_def * base_offset,unsigned location,unsigned component_offset,unsigned num_components,unsigned load_bit_size,lower_ngg_ms_state * s)3992 ms_load_arrayed_output(nir_builder *b,
3993                        nir_def *arr_index,
3994                        nir_def *base_offset,
3995                        unsigned location,
3996                        unsigned component_offset,
3997                        unsigned num_components,
3998                        unsigned load_bit_size,
3999                        lower_ngg_ms_state *s)
4000 {
4001    ms_out_mode out_mode;
4002    const ms_out_part *out = ms_get_out_layout_part(location, &b->shader->info, &out_mode, s);
4003 
4004    unsigned component_addr_off = component_offset * 4;
4005    unsigned num_outputs = util_bitcount64(out->mask);
4006    unsigned const_off = out->addr + component_offset * 4;
4007 
4008    /* Use compacted location instead of the original semantic location. */
4009    unsigned mapped_location = util_bitcount64(out->mask & u_bit_consecutive64(0, location));
4010 
4011    nir_def *base_addr = ms_arrayed_output_base_addr(b, arr_index, mapped_location, num_outputs);
4012    nir_def *base_addr_off = nir_imul_imm(b, base_offset, 16);
4013    nir_def *addr = nir_iadd_nuw(b, base_addr, base_addr_off);
4014 
4015    if (out_mode == ms_out_mode_lds) {
4016       return nir_load_shared(b, num_components, load_bit_size, addr, .align_mul = 16,
4017                              .align_offset = component_addr_off % 16,
4018                              .base = const_off);
4019    } else if (out_mode == ms_out_mode_scratch_ring) {
4020       nir_def *ring = nir_load_ring_mesh_scratch_amd(b);
4021       nir_def *off = nir_load_ring_mesh_scratch_offset_amd(b);
4022       nir_def *zero = nir_imm_int(b, 0);
4023       return nir_load_buffer_amd(b, num_components, load_bit_size, ring, addr, off, zero,
4024                                  .base = const_off,
4025                                  .memory_modes = nir_var_shader_out,
4026                                  .access = ACCESS_COHERENT);
4027    } else if (out_mode == ms_out_mode_var) {
4028       assert(load_bit_size == 32);
4029       nir_def *arr[8] = {0};
4030       for (unsigned comp = 0; comp < num_components; ++comp) {
4031          unsigned idx = location * 4 + comp + component_addr_off;
4032          arr[comp] = nir_load_var(b, s->out_variables[idx]);
4033       }
4034       return nir_vec(b, arr, num_components);
4035    } else {
4036       unreachable("Invalid MS output mode for load");
4037    }
4038 }
4039 
4040 static nir_def *
lower_ms_load_workgroup_index(nir_builder * b,UNUSED nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)4041 lower_ms_load_workgroup_index(nir_builder *b,
4042                               UNUSED nir_intrinsic_instr *intrin,
4043                               lower_ngg_ms_state *s)
4044 {
4045    return s->workgroup_index;
4046 }
4047 
4048 static nir_def *
lower_ms_set_vertex_and_primitive_count(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)4049 lower_ms_set_vertex_and_primitive_count(nir_builder *b,
4050                                         nir_intrinsic_instr *intrin,
4051                                         lower_ngg_ms_state *s)
4052 {
4053    /* If either the number of vertices or primitives is zero, set both of them to zero. */
4054    nir_def *num_vtx = nir_read_first_invocation(b, intrin->src[0].ssa);
4055    nir_def *num_prm = nir_read_first_invocation(b, intrin->src[1].ssa);
4056    nir_def *zero = nir_imm_int(b, 0);
4057    nir_def *is_either_zero = nir_ieq(b, nir_umin(b, num_vtx, num_prm), zero);
4058    num_vtx = nir_bcsel(b, is_either_zero, zero, num_vtx);
4059    num_prm = nir_bcsel(b, is_either_zero, zero, num_prm);
4060 
4061    nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
4062    nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
4063 
4064    return NIR_LOWER_INSTR_PROGRESS_REPLACE;
4065 }
4066 
4067 static nir_def *
update_ms_barrier(nir_builder * b,nir_intrinsic_instr * intrin,lower_ngg_ms_state * s)4068 update_ms_barrier(nir_builder *b,
4069                          nir_intrinsic_instr *intrin,
4070                          lower_ngg_ms_state *s)
4071 {
4072    /* Output loads and stores are lowered to shared memory access,
4073     * so we have to update the barriers to also reflect this.
4074     */
4075    unsigned mem_modes = nir_intrinsic_memory_modes(intrin);
4076    if (mem_modes & nir_var_shader_out)
4077       mem_modes |= nir_var_mem_shared;
4078    else
4079       return NULL;
4080 
4081    nir_intrinsic_set_memory_modes(intrin, mem_modes);
4082 
4083    return NIR_LOWER_INSTR_PROGRESS;
4084 }
4085 
4086 static nir_def *
lower_ms_intrinsic(nir_builder * b,nir_instr * instr,void * state)4087 lower_ms_intrinsic(nir_builder *b, nir_instr *instr, void *state)
4088 {
4089    lower_ngg_ms_state *s = (lower_ngg_ms_state *) state;
4090 
4091    if (instr->type != nir_instr_type_intrinsic)
4092       return NULL;
4093 
4094    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4095 
4096    switch (intrin->intrinsic) {
4097    case nir_intrinsic_store_per_vertex_output:
4098    case nir_intrinsic_store_per_primitive_output:
4099       ms_store_arrayed_output_intrin(b, intrin, s);
4100       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
4101    case nir_intrinsic_barrier:
4102       return update_ms_barrier(b, intrin, s);
4103    case nir_intrinsic_load_workgroup_index:
4104       return lower_ms_load_workgroup_index(b, intrin, s);
4105    case nir_intrinsic_set_vertex_and_primitive_count:
4106       return lower_ms_set_vertex_and_primitive_count(b, intrin, s);
4107    default:
4108       unreachable("Not a lowerable mesh shader intrinsic.");
4109    }
4110 }
4111 
4112 static bool
filter_ms_intrinsic(const nir_instr * instr,UNUSED const void * s)4113 filter_ms_intrinsic(const nir_instr *instr,
4114                     UNUSED const void *s)
4115 {
4116    if (instr->type != nir_instr_type_intrinsic)
4117       return false;
4118 
4119    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4120    return intrin->intrinsic == nir_intrinsic_store_output ||
4121           intrin->intrinsic == nir_intrinsic_load_output ||
4122           intrin->intrinsic == nir_intrinsic_store_per_vertex_output ||
4123           intrin->intrinsic == nir_intrinsic_store_per_primitive_output ||
4124           intrin->intrinsic == nir_intrinsic_barrier ||
4125           intrin->intrinsic == nir_intrinsic_load_workgroup_index ||
4126           intrin->intrinsic == nir_intrinsic_set_vertex_and_primitive_count;
4127 }
4128 
4129 static void
lower_ms_intrinsics(nir_shader * shader,lower_ngg_ms_state * s)4130 lower_ms_intrinsics(nir_shader *shader, lower_ngg_ms_state *s)
4131 {
4132    nir_shader_lower_instructions(shader, filter_ms_intrinsic, lower_ms_intrinsic, s);
4133 }
4134 
4135 static void
ms_emit_arrayed_outputs(nir_builder * b,nir_def * invocation_index,uint64_t mask,lower_ngg_ms_state * s)4136 ms_emit_arrayed_outputs(nir_builder *b,
4137                         nir_def *invocation_index,
4138                         uint64_t mask,
4139                         lower_ngg_ms_state *s)
4140 {
4141    nir_def *zero = nir_imm_int(b, 0);
4142 
4143    u_foreach_bit64(slot, mask) {
4144       /* Should not occur here, handled separately. */
4145       assert(slot != VARYING_SLOT_PRIMITIVE_COUNT && slot != VARYING_SLOT_PRIMITIVE_INDICES);
4146 
4147       unsigned component_mask = s->output_info[slot].components_mask;
4148 
4149       while (component_mask) {
4150          int start_comp = 0, num_components = 1;
4151          u_bit_scan_consecutive_range(&component_mask, &start_comp, &num_components);
4152 
4153          nir_def *load =
4154             ms_load_arrayed_output(b, invocation_index, zero, slot, start_comp,
4155                                    num_components, 32, s);
4156 
4157          for (int i = 0; i < num_components; i++)
4158             s->outputs[slot][start_comp + i] = nir_channel(b, load, i);
4159       }
4160    }
4161 }
4162 
4163 static void
ms_create_same_invocation_vars(nir_builder * b,lower_ngg_ms_state * s)4164 ms_create_same_invocation_vars(nir_builder *b, lower_ngg_ms_state *s)
4165 {
4166    /* Initialize NIR variables for same-invocation outputs. */
4167    uint64_t same_invocation_output_mask = s->layout.var.prm_attr.mask | s->layout.var.vtx_attr.mask;
4168 
4169    u_foreach_bit64(slot, same_invocation_output_mask) {
4170       for (unsigned comp = 0; comp < 4; ++comp) {
4171          unsigned idx = slot * 4 + comp;
4172          s->out_variables[idx] = nir_local_variable_create(b->impl, glsl_uint_type(), "ms_var_output");
4173       }
4174    }
4175 }
4176 
4177 static void
ms_emit_legacy_workgroup_index(nir_builder * b,lower_ngg_ms_state * s)4178 ms_emit_legacy_workgroup_index(nir_builder *b, lower_ngg_ms_state *s)
4179 {
4180    /* Workgroup ID should have been lowered to workgroup index. */
4181    assert(!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_ID));
4182 
4183    /* No need to do anything if the shader doesn't use the workgroup index. */
4184    if (!BITSET_TEST(b->shader->info.system_values_read, SYSTEM_VALUE_WORKGROUP_INDEX))
4185       return;
4186 
4187    b->cursor = nir_before_impl(b->impl);
4188 
4189    /* Legacy fast launch mode (FAST_LAUNCH=1):
4190     *
4191     * The HW doesn't support a proper workgroup index for vertex processing stages,
4192     * so we use the vertex ID which is equivalent to the index of the current workgroup
4193     * within the current dispatch.
4194     *
4195     * Due to the register programming of mesh shaders, this value is only filled for
4196     * the first invocation of the first wave. To let other waves know, we use LDS.
4197     */
4198    nir_def *workgroup_index = nir_load_vertex_id_zero_base(b);
4199 
4200    if (s->api_workgroup_size <= s->wave_size) {
4201       /* API workgroup is small, so we don't need to use LDS. */
4202       s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
4203       return;
4204    }
4205 
4206    unsigned workgroup_index_lds_addr = s->layout.lds.workgroup_info_addr + lds_ms_wg_index;
4207 
4208    nir_def *zero = nir_imm_int(b, 0);
4209    nir_def *dont_care = nir_undef(b, 1, 32);
4210    nir_def *loaded_workgroup_index = NULL;
4211 
4212    /* Use elect to make sure only 1 invocation uses LDS. */
4213    nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4214    {
4215       nir_def *wave_id = nir_load_subgroup_id(b);
4216       nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, wave_id, 0));
4217       {
4218          nir_store_shared(b, workgroup_index, zero, .base = workgroup_index_lds_addr);
4219          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4220                                .memory_scope = SCOPE_WORKGROUP,
4221                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4222                                .memory_modes = nir_var_mem_shared);
4223       }
4224       nir_push_else(b, if_wave_0);
4225       {
4226          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4227                                .memory_scope = SCOPE_WORKGROUP,
4228                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4229                                .memory_modes = nir_var_mem_shared);
4230          loaded_workgroup_index = nir_load_shared(b, 1, 32, zero, .base = workgroup_index_lds_addr);
4231       }
4232       nir_pop_if(b, if_wave_0);
4233 
4234       workgroup_index = nir_if_phi(b, workgroup_index, loaded_workgroup_index);
4235    }
4236    nir_pop_if(b, if_elected);
4237 
4238    workgroup_index = nir_if_phi(b, workgroup_index, dont_care);
4239    s->workgroup_index = nir_read_first_invocation(b, workgroup_index);
4240 }
4241 
4242 static void
set_ms_final_output_counts(nir_builder * b,lower_ngg_ms_state * s,nir_def ** out_num_prm,nir_def ** out_num_vtx)4243 set_ms_final_output_counts(nir_builder *b,
4244                            lower_ngg_ms_state *s,
4245                            nir_def **out_num_prm,
4246                            nir_def **out_num_vtx)
4247 {
4248    /* The spec allows the numbers to be divergent, and in that case we need to
4249     * use the values from the first invocation. Also the HW requires us to set
4250     * both to 0 if either was 0.
4251     *
4252     * These are already done by the lowering.
4253     */
4254    nir_def *num_prm = nir_load_var(b, s->primitive_count_var);
4255    nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
4256 
4257    if (s->hw_workgroup_size <= s->wave_size) {
4258       /* Single-wave mesh shader workgroup. */
4259       alloc_vertices_and_primitives(b, num_vtx, num_prm);
4260       *out_num_prm = num_prm;
4261       *out_num_vtx = num_vtx;
4262       return;
4263    }
4264 
4265    /* Multi-wave mesh shader workgroup:
4266     * We need to use LDS to distribute the correct values to the other waves.
4267     *
4268     * TODO:
4269     * If we can prove that the values are workgroup-uniform, we can skip this
4270     * and just use whatever the current wave has. However, NIR divergence analysis
4271     * currently doesn't support this.
4272     */
4273 
4274    nir_def *zero = nir_imm_int(b, 0);
4275 
4276    nir_if *if_wave_0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
4277    {
4278       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4279       {
4280          nir_store_shared(b, nir_vec2(b, num_prm, num_vtx), zero,
4281                           .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
4282       }
4283       nir_pop_if(b, if_elected);
4284 
4285       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4286                             .memory_scope = SCOPE_WORKGROUP,
4287                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4288                             .memory_modes = nir_var_mem_shared);
4289 
4290       alloc_vertices_and_primitives(b, num_vtx, num_prm);
4291    }
4292    nir_push_else(b, if_wave_0);
4293    {
4294       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4295                             .memory_scope = SCOPE_WORKGROUP,
4296                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4297                             .memory_modes = nir_var_mem_shared);
4298 
4299       nir_def *prm_vtx = NULL;
4300       nir_def *dont_care_2x32 = nir_undef(b, 2, 32);
4301       nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4302       {
4303          prm_vtx = nir_load_shared(b, 2, 32, zero,
4304                                    .base = s->layout.lds.workgroup_info_addr + lds_ms_num_prims);
4305       }
4306       nir_pop_if(b, if_elected);
4307 
4308       prm_vtx = nir_if_phi(b, prm_vtx, dont_care_2x32);
4309       num_prm = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 0));
4310       num_vtx = nir_read_first_invocation(b, nir_channel(b, prm_vtx, 1));
4311 
4312       nir_store_var(b, s->primitive_count_var, num_prm, 0x1);
4313       nir_store_var(b, s->vertex_count_var, num_vtx, 0x1);
4314    }
4315    nir_pop_if(b, if_wave_0);
4316 
4317    *out_num_prm = nir_load_var(b, s->primitive_count_var);
4318    *out_num_vtx = nir_load_var(b, s->vertex_count_var);
4319 }
4320 
4321 static void
ms_emit_attribute_ring_output_stores(nir_builder * b,const uint64_t outputs_mask,nir_def * idx,lower_ngg_ms_state * s)4322 ms_emit_attribute_ring_output_stores(nir_builder *b, const uint64_t outputs_mask,
4323                                      nir_def *idx, lower_ngg_ms_state *s)
4324 {
4325    if (!outputs_mask)
4326       return;
4327 
4328    nir_def *ring = nir_load_ring_attr_amd(b);
4329    nir_def *off = nir_load_ring_attr_offset_amd(b);
4330    nir_def *zero = nir_imm_int(b, 0);
4331 
4332    u_foreach_bit64 (slot, outputs_mask) {
4333       if (s->vs_output_param_offset[slot] > AC_EXP_PARAM_OFFSET_31)
4334          continue;
4335 
4336       nir_def *soffset = nir_iadd_imm(b, off, s->vs_output_param_offset[slot] * 16 * 32);
4337       nir_def *store_val = nir_undef(b, 4, 32);
4338       unsigned store_val_components = 0;
4339       for (unsigned c = 0; c < 4; ++c) {
4340          if (s->outputs[slot][c]) {
4341             store_val = nir_vector_insert_imm(b, store_val, s->outputs[slot][c], c);
4342             store_val_components = c + 1;
4343          }
4344       }
4345 
4346       store_val = nir_trim_vector(b, store_val, store_val_components);
4347       nir_store_buffer_amd(b, store_val, ring, zero, soffset, idx,
4348                            .memory_modes = nir_var_shader_out,
4349                            .access = ACCESS_COHERENT | ACCESS_IS_SWIZZLED_AMD);
4350    }
4351 }
4352 
4353 static nir_def *
ms_prim_exp_arg_ch1(nir_builder * b,nir_def * invocation_index,nir_def * num_vtx,lower_ngg_ms_state * s)4354 ms_prim_exp_arg_ch1(nir_builder *b, nir_def *invocation_index, nir_def *num_vtx, lower_ngg_ms_state *s)
4355 {
4356    /* Primitive connectivity data: describes which vertices the primitive uses. */
4357    nir_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
4358    nir_def *indices_loaded = NULL;
4359    nir_def *cull_flag = NULL;
4360 
4361    if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES)) {
4362       nir_def *indices[3] = {0};
4363       for (unsigned c = 0; c < s->vertices_per_prim; ++c)
4364          indices[c] = nir_load_var(b, s->out_variables[VARYING_SLOT_PRIMITIVE_INDICES * 4 + c]);
4365       indices_loaded = nir_vec(b, indices, s->vertices_per_prim);
4366    } else {
4367       indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->layout.lds.indices_addr);
4368       indices_loaded = nir_u2u32(b, indices_loaded);
4369    }
4370 
4371    if (s->uses_cull_flags) {
4372       nir_def *loaded_cull_flag = NULL;
4373       if (s->layout.var.prm_attr.mask & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE))
4374          loaded_cull_flag = nir_load_var(b, s->out_variables[VARYING_SLOT_CULL_PRIMITIVE * 4]);
4375       else
4376          loaded_cull_flag = nir_u2u32(b, nir_load_shared(b, 1, 8, prim_idx_addr, .base = s->layout.lds.cull_flags_addr));
4377 
4378       cull_flag = nir_i2b(b, loaded_cull_flag);
4379    }
4380 
4381    nir_def *indices[3];
4382    nir_def *max_vtx_idx = nir_iadd_imm(b, num_vtx, -1u);
4383 
4384    for (unsigned i = 0; i < s->vertices_per_prim; ++i) {
4385       indices[i] = nir_channel(b, indices_loaded, i);
4386       indices[i] = nir_umin(b, indices[i], max_vtx_idx);
4387    }
4388 
4389    return emit_pack_ngg_prim_exp_arg(b, s->vertices_per_prim, indices, cull_flag, s->gfx_level);
4390 }
4391 
4392 static nir_def *
ms_prim_exp_arg_ch2(nir_builder * b,uint64_t outputs_mask,lower_ngg_ms_state * s)4393 ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s)
4394 {
4395    nir_def *prim_exp_arg_ch2 = NULL;
4396 
4397    if (outputs_mask) {
4398       /* When layer, viewport etc. are per-primitive, they need to be encoded in
4399        * the primitive export instruction's second channel. The encoding is:
4400        *
4401        * --- GFX10.3 ---
4402        * bits 31..30: VRS rate Y
4403        * bits 29..28: VRS rate X
4404        * bits 23..20: viewport
4405        * bits 19..17: layer
4406        *
4407        * --- GFX11 ---
4408        * bits 31..28: VRS rate enum
4409        * bits 23..20: viewport
4410        * bits 12..00: layer
4411        */
4412       prim_exp_arg_ch2 = nir_imm_int(b, 0);
4413 
4414       if (outputs_mask & VARYING_BIT_LAYER) {
4415          nir_def *layer =
4416             nir_ishl_imm(b, s->outputs[VARYING_SLOT_LAYER][0], s->gfx_level >= GFX11 ? 0 : 17);
4417          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, layer);
4418       }
4419 
4420       if (outputs_mask & VARYING_BIT_VIEWPORT) {
4421          nir_def *view = nir_ishl_imm(b, s->outputs[VARYING_SLOT_VIEWPORT][0], 20);
4422          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, view);
4423       }
4424 
4425       if (outputs_mask & VARYING_BIT_PRIMITIVE_SHADING_RATE) {
4426          nir_def *rate = s->outputs[VARYING_SLOT_PRIMITIVE_SHADING_RATE][0];
4427          prim_exp_arg_ch2 = nir_ior(b, prim_exp_arg_ch2, rate);
4428       }
4429    }
4430 
4431    return prim_exp_arg_ch2;
4432 }
4433 
4434 static void
ms_prim_gen_query(nir_builder * b,nir_def * invocation_index,nir_def * num_prm,lower_ngg_ms_state * s)4435 ms_prim_gen_query(nir_builder *b,
4436                   nir_def *invocation_index,
4437                   nir_def *num_prm,
4438                   lower_ngg_ms_state *s)
4439 {
4440    if (!s->has_query)
4441       return;
4442 
4443    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4444    {
4445       nir_if *if_shader_query = nir_push_if(b, nir_load_prim_gen_query_enabled_amd(b));
4446       {
4447          nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
4448       }
4449       nir_pop_if(b, if_shader_query);
4450    }
4451    nir_pop_if(b, if_invocation_index_zero);
4452 }
4453 
4454 static void
ms_invocation_query(nir_builder * b,nir_def * invocation_index,lower_ngg_ms_state * s)4455 ms_invocation_query(nir_builder *b,
4456                     nir_def *invocation_index,
4457                     lower_ngg_ms_state *s)
4458 {
4459    if (!s->has_query)
4460       return;
4461 
4462    nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4463    {
4464       nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
4465       {
4466          nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, s->api_workgroup_size));
4467       }
4468       nir_pop_if(b, if_pipeline_query);
4469    }
4470    nir_pop_if(b, if_invocation_index_zero);
4471 }
4472 
4473 static void
emit_ms_vertex(nir_builder * b,nir_def * index,nir_def * row,bool exports,bool parameters,uint64_t per_vertex_outputs,lower_ngg_ms_state * s)4474 emit_ms_vertex(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
4475                uint64_t per_vertex_outputs, lower_ngg_ms_state *s)
4476 {
4477    ms_emit_arrayed_outputs(b, index, per_vertex_outputs, s);
4478 
4479    if (exports) {
4480       ac_nir_export_position(b, s->gfx_level, s->clipdist_enable_mask,
4481                              !s->has_param_exports, false, true,
4482                              s->per_vertex_outputs | VARYING_BIT_POS, s->outputs, row);
4483    }
4484 
4485    if (parameters) {
4486       /* Export generic attributes on GFX10.3
4487        * (On GFX11 they are already stored in the attribute ring.)
4488        */
4489       if (s->has_param_exports && s->gfx_level == GFX10_3) {
4490          ac_nir_export_parameters(b, s->vs_output_param_offset, per_vertex_outputs, 0, s->outputs,
4491                                   NULL, NULL);
4492       }
4493 
4494       /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
4495       if (s->gfx_level >= GFX11 && (per_vertex_outputs & MS_VERT_ARG_EXP_MASK))
4496          ms_emit_attribute_ring_output_stores(b, per_vertex_outputs & MS_VERT_ARG_EXP_MASK, index, s);
4497    }
4498 }
4499 
4500 static void
emit_ms_primitive(nir_builder * b,nir_def * index,nir_def * row,bool exports,bool parameters,uint64_t per_primitive_outputs,lower_ngg_ms_state * s)4501 emit_ms_primitive(nir_builder *b, nir_def *index, nir_def *row, bool exports, bool parameters,
4502                   uint64_t per_primitive_outputs, lower_ngg_ms_state *s)
4503 {
4504    ms_emit_arrayed_outputs(b, index, per_primitive_outputs, s);
4505 
4506    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
4507    if (s->insert_layer_output)
4508       s->outputs[VARYING_SLOT_LAYER][0] = nir_load_view_index(b);
4509 
4510    if (exports) {
4511       const uint64_t outputs_mask = per_primitive_outputs & MS_PRIM_ARG_EXP_MASK;
4512       nir_def *num_vtx = nir_load_var(b, s->vertex_count_var);
4513       nir_def *prim_exp_arg_ch1 = ms_prim_exp_arg_ch1(b, index, num_vtx, s);
4514       nir_def *prim_exp_arg_ch2 = ms_prim_exp_arg_ch2(b, outputs_mask, s);
4515 
4516       nir_def *prim_exp_arg = prim_exp_arg_ch2 ?
4517          nir_vec2(b, prim_exp_arg_ch1, prim_exp_arg_ch2) : prim_exp_arg_ch1;
4518 
4519       ac_nir_export_primitive(b, prim_exp_arg, row);
4520    }
4521 
4522    if (parameters) {
4523       /* Export generic attributes on GFX10.3
4524        * (On GFX11 they are already stored in the attribute ring.)
4525        */
4526       if (s->has_param_exports && s->gfx_level == GFX10_3) {
4527          ac_nir_export_parameters(b, s->vs_output_param_offset, per_primitive_outputs, 0,
4528                                   s->outputs, NULL, NULL);
4529       }
4530 
4531       /* GFX11+: also store special outputs to the attribute ring so PS can load them. */
4532       if (s->gfx_level >= GFX11)
4533          ms_emit_attribute_ring_output_stores(b, per_primitive_outputs & MS_PRIM_ARG_EXP_MASK, index, s);
4534    }
4535 }
4536 
4537 static void
emit_ms_outputs(nir_builder * b,nir_def * invocation_index,nir_def * row_start,nir_def * count,bool exports,bool parameters,uint64_t mask,void (* cb)(nir_builder *,nir_def *,nir_def *,bool,bool,uint64_t,lower_ngg_ms_state *),lower_ngg_ms_state * s)4538 emit_ms_outputs(nir_builder *b, nir_def *invocation_index, nir_def *row_start,
4539                 nir_def *count, bool exports, bool parameters, uint64_t mask,
4540                 void (*cb)(nir_builder *, nir_def *, nir_def *, bool, bool,
4541                            uint64_t, lower_ngg_ms_state *),
4542                 lower_ngg_ms_state *s)
4543 {
4544    if (cb == &emit_ms_primitive ? s->prim_multirow_export : s->vert_multirow_export) {
4545       assert(s->hw_workgroup_size % s->wave_size == 0);
4546       const unsigned num_waves = s->hw_workgroup_size / s->wave_size;
4547 
4548       nir_loop *row_loop = nir_push_loop(b);
4549       {
4550          nir_block *preheader = nir_cf_node_as_block(nir_cf_node_prev(&row_loop->cf_node));
4551 
4552          nir_phi_instr *index = nir_phi_instr_create(b->shader);
4553          nir_phi_instr *row = nir_phi_instr_create(b->shader);
4554          nir_def_init(&index->instr, &index->def, 1, 32);
4555          nir_def_init(&row->instr, &row->def, 1, 32);
4556 
4557          nir_phi_instr_add_src(index, preheader, invocation_index);
4558          nir_phi_instr_add_src(row, preheader, row_start);
4559 
4560          nir_if *if_break = nir_push_if(b, nir_uge(b, &index->def, count));
4561          {
4562             nir_jump(b, nir_jump_break);
4563          }
4564          nir_pop_if(b, if_break);
4565 
4566          cb(b, &index->def, &row->def, exports, parameters, mask, s);
4567 
4568          nir_block *body = nir_cursor_current_block(b->cursor);
4569          nir_phi_instr_add_src(index, body,
4570                                nir_iadd_imm(b, &index->def, s->hw_workgroup_size));
4571          nir_phi_instr_add_src(row, body,
4572                                nir_iadd_imm(b, &row->def, num_waves));
4573 
4574          nir_instr_insert_before_cf_list(&row_loop->body, &row->instr);
4575          nir_instr_insert_before_cf_list(&row_loop->body, &index->instr);
4576       }
4577       nir_pop_loop(b, row_loop);
4578    } else {
4579       nir_def *has_output = nir_ilt(b, invocation_index, count);
4580       nir_if *if_has_output = nir_push_if(b, has_output);
4581       {
4582          cb(b, invocation_index, row_start, exports, parameters, mask, s);
4583       }
4584       nir_pop_if(b, if_has_output);
4585    }
4586 }
4587 
4588 static void
emit_ms_finale(nir_builder * b,lower_ngg_ms_state * s)4589 emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
4590 {
4591    /* We assume there is always a single end block in the shader. */
4592    nir_block *last_block = nir_impl_last_block(b->impl);
4593    b->cursor = nir_after_block(last_block);
4594 
4595    nir_barrier(b, .execution_scope=SCOPE_WORKGROUP, .memory_scope=SCOPE_WORKGROUP,
4596                          .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_shader_out|nir_var_mem_shared);
4597 
4598    nir_def *num_prm;
4599    nir_def *num_vtx;
4600 
4601    set_ms_final_output_counts(b, s, &num_prm, &num_vtx);
4602 
4603    nir_def *invocation_index = nir_load_local_invocation_index(b);
4604 
4605    ms_prim_gen_query(b, invocation_index, num_prm, s);
4606 
4607    nir_def *row_start = NULL;
4608    if (s->fast_launch_2)
4609       row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : nir_load_subgroup_id(b);
4610 
4611    /* Load vertex/primitive attributes from shared memory and
4612     * emit store_output intrinsics for them.
4613     *
4614     * Contrary to the semantics of the API mesh shader, these are now
4615     * compliant with NGG HW semantics, meaning that these store the
4616     * current thread's vertex attributes in a way the HW can export.
4617     */
4618 
4619    uint64_t per_vertex_outputs =
4620       s->per_vertex_outputs & ~s->layout.attr_ring.vtx_attr.mask;
4621    uint64_t per_primitive_outputs =
4622       s->per_primitive_outputs & ~s->layout.attr_ring.prm_attr.mask & ~SPECIAL_MS_OUT_MASK;
4623 
4624    /* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
4625    if (s->insert_layer_output) {
4626       b->shader->info.outputs_written |= VARYING_BIT_LAYER;
4627       b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
4628       per_primitive_outputs |= VARYING_BIT_LAYER;
4629    }
4630 
4631    const bool has_special_param_exports =
4632       (per_vertex_outputs & MS_VERT_ARG_EXP_MASK) ||
4633       (per_primitive_outputs & MS_PRIM_ARG_EXP_MASK);
4634 
4635    const bool wait_attr_ring = must_wait_attr_ring(s->gfx_level, has_special_param_exports);
4636 
4637    /* Export vertices. */
4638    if ((per_vertex_outputs & ~VARYING_BIT_POS) || !wait_attr_ring) {
4639       emit_ms_outputs(b, invocation_index, row_start, num_vtx, !wait_attr_ring, true,
4640                       per_vertex_outputs, &emit_ms_vertex, s);
4641    }
4642 
4643    /* Export primitives. */
4644    if (per_primitive_outputs || !wait_attr_ring) {
4645       emit_ms_outputs(b, invocation_index, row_start, num_prm, !wait_attr_ring, true,
4646                       per_primitive_outputs, &emit_ms_primitive, s);
4647    }
4648 
4649    /* When we need to wait for attribute ring stores, we emit both position and primitive
4650     * export instructions after a barrier to make sure both per-vertex and per-primitive
4651     * attribute ring stores are finished before the GPU starts rasterization.
4652     */
4653    if (wait_attr_ring) {
4654       /* Wait for attribute stores to finish. */
4655       nir_barrier(b, .execution_scope = SCOPE_SUBGROUP,
4656                      .memory_scope = SCOPE_DEVICE,
4657                      .memory_semantics = NIR_MEMORY_RELEASE,
4658                      .memory_modes = nir_var_shader_out);
4659 
4660       /* Position/primitive export only */
4661       emit_ms_outputs(b, invocation_index, row_start, num_vtx, true, false,
4662                       per_vertex_outputs, &emit_ms_vertex, s);
4663       emit_ms_outputs(b, invocation_index, row_start, num_prm, true, false,
4664                       per_primitive_outputs, &emit_ms_primitive, s);
4665    }
4666 }
4667 
4668 static void
handle_smaller_ms_api_workgroup(nir_builder * b,lower_ngg_ms_state * s)4669 handle_smaller_ms_api_workgroup(nir_builder *b,
4670                                 lower_ngg_ms_state *s)
4671 {
4672    if (s->api_workgroup_size >= s->hw_workgroup_size)
4673       return;
4674 
4675    /* Handle barriers manually when the API workgroup
4676     * size is less than the HW workgroup size.
4677     *
4678     * The problem is that the real workgroup launched on NGG HW
4679     * will be larger than the size specified by the API, and the
4680     * extra waves need to keep up with barriers in the API waves.
4681     *
4682     * There are 2 different cases:
4683     * 1. The whole API workgroup fits in a single wave.
4684     *    We can shrink the barriers to subgroup scope and
4685     *    don't need to insert any extra ones.
4686     * 2. The API workgroup occupies multiple waves, but not
4687     *    all. In this case, we emit code that consumes every
4688     *    barrier on the extra waves.
4689     */
4690    assert(s->hw_workgroup_size % s->wave_size == 0);
4691    bool scan_barriers = ALIGN(s->api_workgroup_size, s->wave_size) < s->hw_workgroup_size;
4692    bool can_shrink_barriers = s->api_workgroup_size <= s->wave_size;
4693    bool need_additional_barriers = scan_barriers && !can_shrink_barriers;
4694 
4695    unsigned api_waves_in_flight_addr = s->layout.lds.workgroup_info_addr + lds_ms_num_api_waves;
4696    unsigned num_api_waves = DIV_ROUND_UP(s->api_workgroup_size, s->wave_size);
4697 
4698    /* Scan the shader for workgroup barriers. */
4699    if (scan_barriers) {
4700       bool has_any_workgroup_barriers = false;
4701 
4702       nir_foreach_block(block, b->impl) {
4703          nir_foreach_instr_safe(instr, block) {
4704             if (instr->type != nir_instr_type_intrinsic)
4705                continue;
4706 
4707             nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
4708             bool is_workgroup_barrier =
4709                intrin->intrinsic == nir_intrinsic_barrier &&
4710                nir_intrinsic_execution_scope(intrin) == SCOPE_WORKGROUP;
4711 
4712             if (!is_workgroup_barrier)
4713                continue;
4714 
4715             if (can_shrink_barriers) {
4716                /* Every API invocation runs in the first wave.
4717                 * In this case, we can change the barriers to subgroup scope
4718                 * and avoid adding additional barriers.
4719                 */
4720                nir_intrinsic_set_memory_scope(intrin, SCOPE_SUBGROUP);
4721                nir_intrinsic_set_execution_scope(intrin, SCOPE_SUBGROUP);
4722             } else {
4723                has_any_workgroup_barriers = true;
4724             }
4725          }
4726       }
4727 
4728       need_additional_barriers &= has_any_workgroup_barriers;
4729    }
4730 
4731    /* Extract the full control flow of the shader. */
4732    nir_cf_list extracted;
4733    nir_cf_extract(&extracted, nir_before_impl(b->impl),
4734                   nir_after_cf_list(&b->impl->body));
4735    b->cursor = nir_before_impl(b->impl);
4736 
4737    /* Wrap the shader in an if to ensure that only the necessary amount of lanes run it. */
4738    nir_def *invocation_index = nir_load_local_invocation_index(b);
4739    nir_def *zero = nir_imm_int(b, 0);
4740 
4741    if (need_additional_barriers) {
4742       /* First invocation stores 0 to number of API waves in flight. */
4743       nir_if *if_first_in_workgroup = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
4744       {
4745          nir_store_shared(b, nir_imm_int(b, num_api_waves), zero, .base = api_waves_in_flight_addr);
4746       }
4747       nir_pop_if(b, if_first_in_workgroup);
4748 
4749       nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4750                             .memory_scope = SCOPE_WORKGROUP,
4751                             .memory_semantics = NIR_MEMORY_ACQ_REL,
4752                             .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4753    }
4754 
4755    nir_def *has_api_ms_invocation = nir_ult_imm(b, invocation_index, s->api_workgroup_size);
4756    nir_if *if_has_api_ms_invocation = nir_push_if(b, has_api_ms_invocation);
4757    {
4758       nir_cf_reinsert(&extracted, b->cursor);
4759       b->cursor = nir_after_cf_list(&if_has_api_ms_invocation->then_list);
4760 
4761       if (need_additional_barriers) {
4762          /* One invocation in each API wave decrements the number of API waves in flight. */
4763          nir_if *if_elected_again = nir_push_if(b, nir_elect(b, 1));
4764          {
4765             nir_shared_atomic(b, 32, zero, nir_imm_int(b, -1u),
4766                               .base = api_waves_in_flight_addr,
4767                               .atomic_op = nir_atomic_op_iadd);
4768          }
4769          nir_pop_if(b, if_elected_again);
4770 
4771          nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4772                                .memory_scope = SCOPE_WORKGROUP,
4773                                .memory_semantics = NIR_MEMORY_ACQ_REL,
4774                                .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4775       }
4776 
4777       ms_invocation_query(b, invocation_index, s);
4778    }
4779    nir_pop_if(b, if_has_api_ms_invocation);
4780 
4781    if (need_additional_barriers) {
4782       /* Make sure that waves that don't run any API invocations execute
4783        * the same amount of barriers as those that do.
4784        *
4785        * We do this by executing a barrier until the number of API waves
4786        * in flight becomes zero.
4787        */
4788       nir_def *has_api_ms_ballot = nir_ballot(b, 1, s->wave_size, has_api_ms_invocation);
4789       nir_def *wave_has_no_api_ms = nir_ieq_imm(b, has_api_ms_ballot, 0);
4790       nir_if *if_wave_has_no_api_ms = nir_push_if(b, wave_has_no_api_ms);
4791       {
4792          nir_if *if_elected = nir_push_if(b, nir_elect(b, 1));
4793          {
4794             nir_loop *loop = nir_push_loop(b);
4795             {
4796                nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
4797                                      .memory_scope = SCOPE_WORKGROUP,
4798                                      .memory_semantics = NIR_MEMORY_ACQ_REL,
4799                                      .memory_modes = nir_var_shader_out | nir_var_mem_shared);
4800 
4801                nir_def *loaded = nir_load_shared(b, 1, 32, zero, .base = api_waves_in_flight_addr);
4802                nir_if *if_break = nir_push_if(b, nir_ieq_imm(b, loaded, 0));
4803                {
4804                   nir_jump(b, nir_jump_break);
4805                }
4806                nir_pop_if(b, if_break);
4807             }
4808             nir_pop_loop(b, loop);
4809          }
4810          nir_pop_if(b, if_elected);
4811       }
4812       nir_pop_if(b, if_wave_has_no_api_ms);
4813    }
4814 }
4815 
4816 static void
ms_move_output(ms_out_part * from,ms_out_part * to)4817 ms_move_output(ms_out_part *from, ms_out_part *to)
4818 {
4819    uint64_t loc = util_logbase2_64(from->mask);
4820    uint64_t bit = BITFIELD64_BIT(loc);
4821    from->mask ^= bit;
4822    to->mask |= bit;
4823 }
4824 
4825 static void
ms_calculate_arrayed_output_layout(ms_out_mem_layout * l,unsigned max_vertices,unsigned max_primitives)4826 ms_calculate_arrayed_output_layout(ms_out_mem_layout *l,
4827                                    unsigned max_vertices,
4828                                    unsigned max_primitives)
4829 {
4830    uint32_t lds_vtx_attr_size = util_bitcount64(l->lds.vtx_attr.mask) * max_vertices * 16;
4831    uint32_t lds_prm_attr_size = util_bitcount64(l->lds.prm_attr.mask) * max_primitives * 16;
4832    l->lds.prm_attr.addr = ALIGN(l->lds.vtx_attr.addr + lds_vtx_attr_size, 16);
4833    l->lds.total_size = l->lds.prm_attr.addr + lds_prm_attr_size;
4834 
4835    uint32_t scratch_ring_vtx_attr_size =
4836       util_bitcount64(l->scratch_ring.vtx_attr.mask) * max_vertices * 16;
4837    l->scratch_ring.prm_attr.addr =
4838       ALIGN(l->scratch_ring.vtx_attr.addr + scratch_ring_vtx_attr_size, 16);
4839 }
4840 
4841 static ms_out_mem_layout
ms_calculate_output_layout(enum amd_gfx_level gfx_level,unsigned api_shared_size,uint64_t per_vertex_output_mask,uint64_t per_primitive_output_mask,uint64_t cross_invocation_output_access,unsigned max_vertices,unsigned max_primitives,unsigned vertices_per_prim)4842 ms_calculate_output_layout(enum amd_gfx_level gfx_level, unsigned api_shared_size,
4843                            uint64_t per_vertex_output_mask, uint64_t per_primitive_output_mask,
4844                            uint64_t cross_invocation_output_access, unsigned max_vertices,
4845                            unsigned max_primitives, unsigned vertices_per_prim)
4846 {
4847    /* These outputs always need export instructions and can't use the attributes ring. */
4848    const uint64_t always_export_mask =
4849       VARYING_BIT_POS | VARYING_BIT_CULL_DIST0 | VARYING_BIT_CULL_DIST1 | VARYING_BIT_CLIP_DIST0 |
4850       VARYING_BIT_CLIP_DIST1 | VARYING_BIT_PSIZ | VARYING_BIT_VIEWPORT |
4851       VARYING_BIT_PRIMITIVE_SHADING_RATE | VARYING_BIT_LAYER |
4852       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_COUNT) |
4853       BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES) | BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
4854 
4855    const bool use_attr_ring = gfx_level >= GFX11;
4856    const uint64_t attr_ring_per_vertex_output_mask =
4857       use_attr_ring ? per_vertex_output_mask & ~always_export_mask : 0;
4858    const uint64_t attr_ring_per_primitive_output_mask =
4859       use_attr_ring ? per_primitive_output_mask & ~always_export_mask : 0;
4860 
4861    const uint64_t lds_per_vertex_output_mask =
4862       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & cross_invocation_output_access &
4863       ~SPECIAL_MS_OUT_MASK;
4864    const uint64_t lds_per_primitive_output_mask =
4865       per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
4866       cross_invocation_output_access & ~SPECIAL_MS_OUT_MASK;
4867 
4868    const bool cross_invocation_indices =
4869       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
4870    const bool cross_invocation_cull_primitive =
4871       cross_invocation_output_access & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
4872 
4873    /* Shared memory used by the API shader. */
4874    ms_out_mem_layout l = { .lds = { .total_size = api_shared_size } };
4875 
4876    /* GFX11+: use attribute ring for all generic attributes. */
4877    l.attr_ring.vtx_attr.mask = attr_ring_per_vertex_output_mask;
4878    l.attr_ring.prm_attr.mask = attr_ring_per_primitive_output_mask;
4879 
4880    /* Outputs without cross-invocation access can be stored in variables. */
4881    l.var.vtx_attr.mask =
4882       per_vertex_output_mask & ~attr_ring_per_vertex_output_mask & ~cross_invocation_output_access;
4883    l.var.prm_attr.mask = per_primitive_output_mask & ~attr_ring_per_primitive_output_mask &
4884                          ~cross_invocation_output_access;
4885 
4886    /* Workgroup information, see ms_workgroup_* for the layout. */
4887    l.lds.workgroup_info_addr = ALIGN(l.lds.total_size, 16);
4888    l.lds.total_size = l.lds.workgroup_info_addr + 16;
4889 
4890    /* Per-vertex and per-primitive output attributes.
4891     * Outputs without cross-invocation access are not included here.
4892     * First, try to put all outputs into LDS (shared memory).
4893     * If they don't fit, try to move them to VRAM one by one.
4894     */
4895    l.lds.vtx_attr.addr = ALIGN(l.lds.total_size, 16);
4896    l.lds.vtx_attr.mask = lds_per_vertex_output_mask;
4897    l.lds.prm_attr.mask = lds_per_primitive_output_mask;
4898    ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
4899 
4900    /* NGG shaders can only address up to 32K LDS memory.
4901     * The spec requires us to allow the application to use at least up to 28K
4902     * shared memory. Additionally, we reserve 2K for driver internal use
4903     * (eg. primitive indices and such, see below).
4904     *
4905     * Move the outputs that do not fit LDS, to VRAM.
4906     * Start with per-primitive attributes, because those are grouped at the end.
4907     */
4908    const unsigned usable_lds_kbytes =
4909       (cross_invocation_cull_primitive || cross_invocation_indices) ? 30 : 31;
4910    while (l.lds.total_size >= usable_lds_kbytes * 1024) {
4911       if (l.lds.prm_attr.mask)
4912          ms_move_output(&l.lds.prm_attr, &l.scratch_ring.prm_attr);
4913       else if (l.lds.vtx_attr.mask)
4914          ms_move_output(&l.lds.vtx_attr, &l.scratch_ring.vtx_attr);
4915       else
4916          unreachable("API shader uses too much shared memory.");
4917 
4918       ms_calculate_arrayed_output_layout(&l, max_vertices, max_primitives);
4919    }
4920 
4921    if (cross_invocation_indices) {
4922       /* Indices: flat array of 8-bit vertex indices for each primitive. */
4923       l.lds.indices_addr = ALIGN(l.lds.total_size, 16);
4924       l.lds.total_size = l.lds.indices_addr + max_primitives * vertices_per_prim;
4925    }
4926 
4927    if (cross_invocation_cull_primitive) {
4928       /* Cull flags: array of 8-bit cull flags for each primitive, 1=cull, 0=keep. */
4929       l.lds.cull_flags_addr = ALIGN(l.lds.total_size, 16);
4930       l.lds.total_size = l.lds.cull_flags_addr + max_primitives;
4931    }
4932 
4933    /* NGG is only allowed to address up to 32K of LDS. */
4934    assert(l.lds.total_size <= 32 * 1024);
4935    return l;
4936 }
4937 
4938 void
ac_nir_lower_ngg_ms(nir_shader * shader,enum amd_gfx_level gfx_level,uint32_t clipdist_enable_mask,const uint8_t * vs_output_param_offset,bool has_param_exports,bool * out_needs_scratch_ring,unsigned wave_size,unsigned hw_workgroup_size,bool multiview,bool has_query,bool fast_launch_2)4939 ac_nir_lower_ngg_ms(nir_shader *shader,
4940                     enum amd_gfx_level gfx_level,
4941                     uint32_t clipdist_enable_mask,
4942                     const uint8_t *vs_output_param_offset,
4943                     bool has_param_exports,
4944                     bool *out_needs_scratch_ring,
4945                     unsigned wave_size,
4946                     unsigned hw_workgroup_size,
4947                     bool multiview,
4948                     bool has_query,
4949                     bool fast_launch_2)
4950 {
4951    unsigned vertices_per_prim =
4952       mesa_vertices_per_prim(shader->info.mesh.primitive_type);
4953 
4954    uint64_t per_vertex_outputs =
4955       shader->info.outputs_written & ~shader->info.per_primitive_outputs & ~SPECIAL_MS_OUT_MASK;
4956    uint64_t per_primitive_outputs =
4957       shader->info.per_primitive_outputs & shader->info.outputs_written;
4958 
4959    /* Whether the shader uses CullPrimitiveEXT */
4960    bool uses_cull = shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_CULL_PRIMITIVE);
4961    /* Can't handle indirect register addressing, pretend as if they were cross-invocation. */
4962    uint64_t cross_invocation_access = shader->info.mesh.ms_cross_invocation_output_access |
4963                                       shader->info.outputs_accessed_indirectly;
4964 
4965    unsigned max_vertices = shader->info.mesh.max_vertices_out;
4966    unsigned max_primitives = shader->info.mesh.max_primitives_out;
4967 
4968    ms_out_mem_layout layout = ms_calculate_output_layout(
4969       gfx_level, shader->info.shared_size, per_vertex_outputs, per_primitive_outputs,
4970       cross_invocation_access, max_vertices, max_primitives, vertices_per_prim);
4971 
4972    shader->info.shared_size = layout.lds.total_size;
4973    *out_needs_scratch_ring = layout.scratch_ring.vtx_attr.mask || layout.scratch_ring.prm_attr.mask;
4974 
4975    /* The workgroup size that is specified by the API shader may be different
4976     * from the size of the workgroup that actually runs on the HW, due to the
4977     * limitations of NGG: max 0/1 vertex and 0/1 primitive per lane is allowed.
4978     *
4979     * Therefore, we must make sure that when the API workgroup size is smaller,
4980     * we don't run the API shader on more HW invocations than is necessary.
4981     */
4982    unsigned api_workgroup_size = shader->info.workgroup_size[0] *
4983                                  shader->info.workgroup_size[1] *
4984                                  shader->info.workgroup_size[2];
4985 
4986    lower_ngg_ms_state state = {
4987       .layout = layout,
4988       .wave_size = wave_size,
4989       .per_vertex_outputs = per_vertex_outputs,
4990       .per_primitive_outputs = per_primitive_outputs,
4991       .vertices_per_prim = vertices_per_prim,
4992       .api_workgroup_size = api_workgroup_size,
4993       .hw_workgroup_size = hw_workgroup_size,
4994       .insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
4995       .uses_cull_flags = uses_cull,
4996       .gfx_level = gfx_level,
4997       .fast_launch_2 = fast_launch_2,
4998       .vert_multirow_export = fast_launch_2 && max_vertices > hw_workgroup_size,
4999       .prim_multirow_export = fast_launch_2 && max_primitives > hw_workgroup_size,
5000       .clipdist_enable_mask = clipdist_enable_mask,
5001       .vs_output_param_offset = vs_output_param_offset,
5002       .has_param_exports = has_param_exports,
5003       .has_query = has_query,
5004    };
5005 
5006    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
5007    assert(impl);
5008 
5009    state.vertex_count_var =
5010       nir_local_variable_create(impl, glsl_uint_type(), "vertex_count_var");
5011    state.primitive_count_var =
5012       nir_local_variable_create(impl, glsl_uint_type(), "primitive_count_var");
5013 
5014    nir_builder builder = nir_builder_at(nir_before_impl(impl));
5015    nir_builder *b = &builder; /* This is to avoid the & */
5016 
5017    handle_smaller_ms_api_workgroup(b, &state);
5018    if (!fast_launch_2)
5019       ms_emit_legacy_workgroup_index(b, &state);
5020    ms_create_same_invocation_vars(b, &state);
5021    nir_metadata_preserve(impl, nir_metadata_none);
5022 
5023    lower_ms_intrinsics(shader, &state);
5024 
5025    emit_ms_finale(b, &state);
5026    nir_metadata_preserve(impl, nir_metadata_none);
5027 
5028    /* Cleanup */
5029    nir_lower_vars_to_ssa(shader);
5030    nir_remove_dead_variables(shader, nir_var_function_temp, NULL);
5031    nir_lower_alu_to_scalar(shader, NULL, NULL);
5032    nir_lower_phis_to_scalar(shader, true);
5033 
5034    /* Optimize load_local_invocation_index. When the API workgroup is smaller than the HW workgroup,
5035     * local_invocation_id isn't initialized for all lanes and we can't perform this optimization for
5036     * all load_local_invocation_index.
5037     */
5038    if (fast_launch_2 && api_workgroup_size == hw_workgroup_size &&
5039        ((shader->info.workgroup_size[0] == 1) + (shader->info.workgroup_size[1] == 1) +
5040         (shader->info.workgroup_size[2] == 1)) == 2) {
5041       nir_lower_compute_system_values_options csv_options = {
5042          .lower_local_invocation_index = true,
5043       };
5044       nir_lower_compute_system_values(shader, &csv_options);
5045    }
5046 
5047    nir_validate_shader(shader, "after emitting NGG MS");
5048 }
5049