/*
 * Copyright © 2022 Valve Corporation
 *
 * SPDX-License-Identifier: MIT
 */

#include "ac_nir.h"
#include "nir_builder.h"
#include "amdgfxregs.h"
#include "util/u_math.h"

/*
 * These NIR passes are used to lower NIR cross-stage I/O intrinsics
 * between task and mesh shader stages into the memory accesses
 * that actually happen on the HW.
 *
 */

typedef struct {
   unsigned payload_entry_bytes;
   unsigned draw_entry_bytes;
   unsigned num_entries;

   /* True if the lowering needs to insert a shader query. */
   bool has_query;
} lower_tsms_io_state;

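/* Returns the flattened 1D index of the current workgroup within the task
 * dispatch, computed from the 3D workgroup ID and the dispatch grid size as
 * x + y * grid_x + z * grid_x * grid_y.
 */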
static nir_def *
task_workgroup_index(nir_builder *b,
                     lower_tsms_io_state *s)
{
   nir_def *id = nir_load_workgroup_id(b);

   nir_def *x = nir_channel(b, id, 0);
   nir_def *y = nir_channel(b, id, 1);
   nir_def *z = nir_channel(b, id, 2);

   nir_def *grid_size = nir_load_num_workgroups(b);
   nir_def *grid_size_x = nir_channel(b, grid_size, 0);
   nir_def *grid_size_y = nir_channel(b, grid_size, 1);

   return nir_iadd(b, nir_imul(b, nir_imul(b, grid_size_x, grid_size_y), z),
                      nir_iadd(b, nir_imul(b, grid_size_x, y), x));
}

static nir_def *
task_ring_entry_index(nir_builder *b,
                      lower_tsms_io_state *s)
{
   /* Task shader ring_entry shader argument:
    *
    * - It's a copy of write_ptr[31:0] from the task control buffer.
    * - The same value (which is the initial value at dispatch)
    *   seems to be copied to all workgroups in the same dispatch,
    *   therefore a workgroup index needs to be added.
    * - write_ptr must be initialized to num_entries so ring_entry needs
    *   AND with num_entries - 1 to get the correct meaning.
    *   Note that num_entries must be a power of two.
    */
   nir_def *ring_entry = nir_load_task_ring_entry_amd(b);
   nir_def *idx = nir_iadd_nuw(b, ring_entry, task_workgroup_index(b, s));
   return nir_iand_imm(b, idx, s->num_entries - 1);
}

static nir_def *
task_draw_ready_bit(nir_builder *b,
                    lower_tsms_io_state *s)
{
   /* Value of the ready bit is 1 for odd and 0 for even passes through the draw ring.
    *
    * The ring_entry is a copy of the write_ptr. We use that to determine whether
    * the current pass through the draw ring is odd or even, so we can write the
    * correct value to the draw ready bit.
    *
    * This tells the firmware that it can now start launching mesh shader workgroups.
    * The encoding of the last dword of the draw ring entry is:
    * - bit 0: Draw ready bit.
    *          Its meaning flips on every pass through the entry.
    * - bit 1: Packet end bit.
    *          The firmware uses this to mark the entry after the last one
    *          used by the current task dispatch.
    * - bits [2:31]: unused.
    *
    * Task shaders MUST write the draw ready bit to the draw ring
    * before they finish. The firmware waits for the shader to write
    * this bit before it reads the mesh dispatch size to launch the
    * mesh shader workgroups.
    *
    * If the task shader doesn't write this bit, the HW hangs.
    */

   nir_def *ring_entry = nir_load_task_ring_entry_amd(b);
   nir_def *workgroup_index = task_workgroup_index(b, s);

   nir_def *idx = nir_iadd_nuw(b, ring_entry, workgroup_index);
   return nir_u2u8(b, nir_ubfe_imm(b, idx, util_bitcount(s->num_entries - 1), 1));
}

static nir_def *
mesh_ring_entry_index(nir_builder *b,
                      lower_tsms_io_state *s)
{
   /* Mesh shader ring_entry shader argument:
    *
    * - It's a copy of the read_ptr[31:0] from the task control buffer.
    * - All workgroups in the same task->mesh dispatch get the same value,
    *   which is fine because they need to read the same entry.
    * - read_ptr must be initialized to num_entries so ring_entry needs
    *   AND with num_entries - 1 to get the correct meaning.
    *   Note that num_entries must be a power of two.
    */
   return nir_iand_imm(b, nir_load_task_ring_entry_amd(b), s->num_entries - 1);
}

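/* Stores a value to the current workgroup's entry in the task draw ring,
 * at the given constant byte offset within that entry.
 */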
static void
task_write_draw_ring(nir_builder *b,
                     nir_def *store_val,
                     unsigned const_off,
                     lower_tsms_io_state *s)
{
   nir_def *ptr = task_ring_entry_index(b, s);
   nir_def *ring = nir_load_ring_task_draw_amd(b);
   nir_def *scalar_off = nir_imul_imm(b, ptr, s->draw_entry_bytes);
   nir_def *vector_off = nir_imm_int(b, 0);
   nir_def *zero = nir_imm_int(b, 0);

   nir_store_buffer_amd(b, store_val, ring, vector_off, scalar_off, zero,
                        .base = const_off, .memory_modes = nir_var_shader_out,
                        .access = ACCESS_COHERENT | ACCESS_CP_GE_COHERENT_AMD);
}

static bool
filter_task_intrinsics(const nir_instr *instr,
                       UNUSED const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_launch_mesh_workgroups ||
          intrin->intrinsic == nir_intrinsic_store_task_payload ||
          intrin->intrinsic == nir_intrinsic_load_task_payload;
}

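/* When pipeline statistics queries are enabled, accumulates the number of
 * task shader invocations launched by this workgroup.
 */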
static void
task_invocation_query(nir_builder *b, lower_tsms_io_state *s)
{
   if (!s->has_query)
      return;

   const unsigned invocations = b->shader->info.workgroup_size[0] *
                                b->shader->info.workgroup_size[1] *
                                b->shader->info.workgroup_size[2];

   nir_if *if_pipeline_query = nir_push_if(b, nir_load_pipeline_stat_query_enabled_amd(b));
   {
      nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, invocations));
   }
   nir_pop_if(b, if_pipeline_query);
}

static nir_def *
lower_task_launch_mesh_workgroups(nir_builder *b,
                                  nir_intrinsic_instr *intrin,
                                  lower_tsms_io_state *s)
{
   /* This intrinsic must always be in uniform control flow,
    * so we assume that all invocations are active here.
    */

   /* Wait for all necessary stores to finish.
    * Device memory scope is necessary because we need to ensure there is
    * always a waitcnt_vscnt instruction in order to avoid a race condition
    * between payload stores and their loads after mesh shaders launch.
    */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
                  .memory_scope = SCOPE_DEVICE,
                  .memory_semantics = NIR_MEMORY_ACQ_REL,
                  .memory_modes = nir_var_mem_task_payload | nir_var_shader_out |
                                  nir_var_mem_ssbo | nir_var_mem_global);

   /* On the first invocation, write the full draw ring entry. */
   nir_def *invocation_index = nir_load_local_invocation_index(b);
   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, invocation_index, 0));
   {
      nir_def *dimensions = intrin->src[0].ssa;
      nir_def *x = nir_channel(b, dimensions, 0);
      nir_def *y = nir_channel(b, dimensions, 1);
      nir_def *z = nir_channel(b, dimensions, 2);

      /* When either Y or Z is 0, also set X to 0.
       * Not necessary, but it speeds up the job of the CP.
       */
      x = nir_bcsel(b, nir_ieq_imm(b, nir_ior(b, y, z), 0), nir_imm_int(b, 0), x);

      /* Dispatch dimensions of mesh shader workgroups. */
      task_write_draw_ring(b, nir_vec3(b, x, y, z), 0, s);
      /* Prevent the two stores from being reordered. */
      nir_scoped_memory_barrier(b, SCOPE_INVOCATION, NIR_MEMORY_RELEASE, nir_var_shader_out);
      /* Ready bit, only write the low 8 bits. */
      task_write_draw_ring(b, task_draw_ready_bit(b, s), 12, s);

      task_invocation_query(b, s);
   }
   nir_pop_if(b, if_invocation_index_zero);

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

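/* Lowers store_task_payload to a buffer store into the current workgroup's
 * entry of the task payload ring.
 */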
static nir_def *
lower_task_payload_store(nir_builder *b,
                         nir_intrinsic_instr *intrin,
                         lower_tsms_io_state *s)
{
   unsigned write_mask = nir_intrinsic_write_mask(intrin);
   unsigned base = nir_intrinsic_base(intrin);

   nir_def *store_val = intrin->src[0].ssa;
   nir_def *addr = intrin->src[1].ssa;
   nir_def *ring = nir_load_ring_task_payload_amd(b);
   nir_def *ptr = task_ring_entry_index(b, s);
   nir_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);
   nir_def *zero = nir_imm_int(b, 0);

   nir_store_buffer_amd(b, store_val, ring, addr, ring_off, zero, .base = base,
                        .write_mask = write_mask,
                        .memory_modes = nir_var_mem_task_payload,
                        .access = ACCESS_COHERENT);

   return NIR_LOWER_INSTR_PROGRESS_REPLACE;
}

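/* Lowers load_task_payload to a buffer load from the task payload ring.
 * Task shaders read back the entry they write (indexed by the write pointer),
 * while mesh shaders read the entry selected by the read pointer.
 */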
static nir_def *
lower_taskmesh_payload_load(nir_builder *b,
                            nir_intrinsic_instr *intrin,
                            lower_tsms_io_state *s)
{
   unsigned base = nir_intrinsic_base(intrin);
   unsigned num_components = intrin->def.num_components;
   unsigned bit_size = intrin->def.bit_size;

   nir_def *ptr =
      b->shader->info.stage == MESA_SHADER_TASK ?
      task_ring_entry_index(b, s) :
      mesh_ring_entry_index(b, s);

   nir_def *addr = intrin->src[0].ssa;
   nir_def *ring = nir_load_ring_task_payload_amd(b);
   nir_def *ring_off = nir_imul_imm(b, ptr, s->payload_entry_bytes);
   nir_def *zero = nir_imm_int(b, 0);

   return nir_load_buffer_amd(b, num_components, bit_size, ring, addr, ring_off, zero, .base = base,
                              .memory_modes = nir_var_mem_task_payload,
                              .access = ACCESS_COHERENT);
}

static nir_def *
lower_task_intrinsics(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   assert(instr->type == nir_instr_type_intrinsic);
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   lower_tsms_io_state *s = (lower_tsms_io_state *)state;

   switch (intrin->intrinsic) {
      case nir_intrinsic_store_task_payload:
         return lower_task_payload_store(b, intrin, s);
      case nir_intrinsic_load_task_payload:
         return lower_taskmesh_payload_load(b, intrin, s);
      case nir_intrinsic_launch_mesh_workgroups:
         return lower_task_launch_mesh_workgroups(b, intrin, s);
      default:
         unreachable("unsupported task shader intrinsic");
   }
}

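/* Lowers task shader payload and draw outputs to stores into the task payload
 * ring and task draw ring. Each draw ring entry is 16 bytes: the 12-byte mesh
 * dispatch size followed by the dword containing the ready and packet end bits.
 */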
void
ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
                                 unsigned task_payload_entry_bytes,
                                 unsigned task_num_entries,
                                 bool has_query)
{
   assert(util_is_power_of_two_nonzero(task_num_entries));

   nir_lower_task_shader_options lower_ts_opt = {
      .payload_to_shared_for_atomics = true,
   };
   nir_lower_task_shader(shader, lower_ts_opt);

   lower_tsms_io_state state = {
      .draw_entry_bytes = 16,
      .payload_entry_bytes = task_payload_entry_bytes,
      .num_entries = task_num_entries,
      .has_query = has_query,
   };

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_shader_lower_instructions(shader,
                                 filter_task_intrinsics,
                                 lower_task_intrinsics,
                                 &state);

   nir_metadata_preserve(impl, nir_metadata_none);
   nir_validate_shader(shader, "after lowering task shader outputs to memory stores");
}

static bool
filter_mesh_input_load(const nir_instr *instr,
                       UNUSED const void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   return intrin->intrinsic == nir_intrinsic_load_task_payload;
}

static nir_def *
lower_mesh_intrinsics(nir_builder *b,
                      nir_instr *instr,
                      void *state)
{
   assert(instr->type == nir_instr_type_intrinsic);
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   lower_tsms_io_state *s = (lower_tsms_io_state *)state;

   if (intrin->intrinsic == nir_intrinsic_load_task_payload)
      return lower_taskmesh_payload_load(b, intrin, s);
   else
      unreachable("unsupported mesh shader intrinsic");
}

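/* Lowers load_task_payload in mesh shaders to buffer loads from the
 * task payload ring entry written by the corresponding task workgroup.
 */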
void
ac_nir_lower_mesh_inputs_to_mem(nir_shader *shader,
                                unsigned task_payload_entry_bytes,
                                unsigned task_num_entries)
{
   assert(util_is_power_of_two_nonzero(task_num_entries));

   lower_tsms_io_state state = {
      .draw_entry_bytes = 16,
      .payload_entry_bytes = task_payload_entry_bytes,
      .num_entries = task_num_entries,
   };

   nir_shader_lower_instructions(shader,
                                 filter_mesh_input_load,
                                 lower_mesh_intrinsics,
                                 &state);
}