/*
 * Copyright © 2022 Valve Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Timur Kristóf
 *
 */

#include "util/u_math.h"
#include "nir.h"
#include "nir_builder.h"

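/* State for lowering the NV_mesh_shader TASK_COUNT output. */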
typedef struct {
   uint32_t task_count_shared_addr;
} lower_task_nv_state;

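/* State for the common task shader intrinsic lowering. */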
typedef struct {
   /* If true, lower all task_payload I/O to use shared memory. */
   bool payload_in_shared;
   /* Shared memory address where task_payload will be located. */
   uint32_t payload_shared_addr;
   uint32_t payload_offset_in_bytes;
} lower_task_state;

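/* Rewrite loads and stores of the TASK_COUNT output into loads and stores
 * of the shared memory slot reserved for it.
 */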
static bool
lower_nv_task_output(nir_builder *b,
                     nir_instr *instr,
                     void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_nv_state *s = (lower_task_nv_state *)state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_output: {
      b->cursor = nir_after_instr(instr);
      nir_def *load =
         nir_load_shared(b, 1, 32, nir_imm_int(b, 0),
                         .base = s->task_count_shared_addr);
      nir_def_rewrite_uses(&intrin->def, load);
      nir_instr_remove(instr);
      return true;
   }

   case nir_intrinsic_store_output: {
      b->cursor = nir_after_instr(instr);
      nir_def *store_val = intrin->src[0].ssa;
      nir_store_shared(b, store_val, nir_imm_int(b, 0),
                       .base = s->task_count_shared_addr);
      nir_instr_remove(instr);
      return true;
   }

   default:
      return false;
   }
}

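/* Zero-initialize the shared task count at the start of the shader and
 * emit launch_mesh_workgroups based on its final value at the end.
 */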
static void
append_launch_mesh_workgroups_to_nv_task(nir_builder *b,
                                         lower_task_nv_state *s)
{
   /* At the beginning of the shader, write 0 to the task count.
    * This ensures that 0 mesh workgroups are launched when the
    * shader doesn't write the TASK_COUNT output.
    */
   b->cursor = nir_before_impl(b->impl);
   nir_def *zero = nir_imm_int(b, 0);
   nir_store_shared(b, zero, zero, .base = s->task_count_shared_addr);

   nir_barrier(b,
               .execution_scope = SCOPE_WORKGROUP,
               .memory_scope = SCOPE_WORKGROUP,
               .memory_semantics = NIR_MEMORY_RELEASE,
               .memory_modes = nir_var_mem_shared);

   /* At the end of the shader, read the task count from shared memory
    * and emit launch_mesh_workgroups.
    */
   b->cursor = nir_after_cf_list(&b->impl->body);

   nir_barrier(b,
               .execution_scope = SCOPE_WORKGROUP,
               .memory_scope = SCOPE_WORKGROUP,
               .memory_semantics = NIR_MEMORY_ACQUIRE,
               .memory_modes = nir_var_mem_shared);

   nir_def *task_count =
      nir_load_shared(b, 1, 32, zero, .base = s->task_count_shared_addr);

   /* NV_mesh_shader doesn't offer a way to choose which task_payload
    * variables should be passed to mesh shaders, so we just pass all of them.
    */
   uint32_t range = b->shader->info.task_payload_size;

   nir_def *one = nir_imm_int(b, 1);
   nir_def *dispatch_3d = nir_vec3(b, task_count, one, one);
   nir_launch_mesh_workgroups(b, dispatch_3d, .base = 0, .range = range);
}

/**
 * For NV_mesh_shader:
 * Task shaders only have 1 output, TASK_COUNT, which is a 32-bit
 * unsigned int that contains the 1-dimensional mesh dispatch size.
 * This output should behave like a shared variable.
 *
 * We lower this output to a shared variable and then we emit
 * the new launch_mesh_workgroups intrinsic at the end of the shader.
 */
static void
nir_lower_nv_task_count(nir_shader *shader)
{
   lower_task_nv_state state = {
      .task_count_shared_addr = ALIGN(shader->info.shared_size, 4),
   };

   shader->info.shared_size += 4;
   nir_shader_instructions_pass(shader, lower_nv_task_output,
                                nir_metadata_none, &state);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder = nir_builder_create(impl);

   append_launch_mesh_workgroups_to_nv_task(&builder, &state);
   nir_metadata_preserve(impl, nir_metadata_none);
}

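/* Map a task_payload intrinsic to the corresponding shared memory intrinsic. */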
static nir_intrinsic_op
shared_opcode_for_task_payload(nir_intrinsic_op task_payload_op)
{
   switch (task_payload_op) {
   case nir_intrinsic_task_payload_atomic:
      return nir_intrinsic_shared_atomic;
   case nir_intrinsic_task_payload_atomic_swap:
      return nir_intrinsic_shared_atomic_swap;
   case nir_intrinsic_load_task_payload:
      return nir_intrinsic_load_shared;
   case nir_intrinsic_store_task_payload:
      return nir_intrinsic_store_shared;
   default:
      unreachable("Invalid task payload atomic");
   }
}

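/* Rewrite a task_payload intrinsic into the equivalent shared memory
 * intrinsic, adding the shared address reserved for the payload to its base.
 */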
static bool
lower_task_payload_to_shared(nir_builder *b,
                             nir_intrinsic_instr *intrin,
                             lower_task_state *s)
{
   /* This assumes that shared and task_payload intrinsics
    * have the same number of sources and same indices.
    */
   unsigned base = nir_intrinsic_base(intrin);
   nir_atomic_op atom_op = nir_intrinsic_has_atomic_op(intrin) ? nir_intrinsic_atomic_op(intrin) : 0;

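   /* Intrinsic index slots are looked up per opcode, so read the indices
    * above and re-apply them after switching the opcode.
    */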
   intrin->intrinsic = shared_opcode_for_task_payload(intrin->intrinsic);
   nir_intrinsic_set_base(intrin, base + s->payload_shared_addr);

   if (nir_intrinsic_has_atomic_op(intrin))
      nir_intrinsic_set_atomic_op(intrin, atom_op);

   return true;
}

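/* Copy num_components dwords from shared memory to task payload memory.
 * The same per-invocation address is used on both sides, only the bases
 * differ.
 */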
static void
copy_shared_to_payload(nir_builder *b,
                       unsigned num_components,
                       nir_def *addr,
                       unsigned shared_base,
                       unsigned off)
{
   /* Read from shared memory. */
   nir_def *copy = nir_load_shared(b, num_components, 32, addr,
                                   .align_mul = 16,
                                   .base = shared_base + off);

   /* Write to task payload memory. */
   nir_store_task_payload(b, copy, addr, .base = off);
}

static void
emit_shared_to_payload_copy(nir_builder *b,
                            uint32_t payload_addr,
                            uint32_t payload_size,
                            lower_task_state *s)
{
   /* Copy from shared memory to task payload using as much parallelism
    * as possible. This is achieved by splitting the work into at most
    * 3 phases:
    * 1) copy the maximum number of vec4s using all invocations in the workgroup
    * 2) copy the maximum number of vec4s using some invocations
    * 3) copy the remaining dwords (< 4) using only the first invocation
    */
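   /* For example, with a 128-invocation workgroup and a 2072-byte payload:
    * phase 1 copies one vec4 per invocation (2048 bytes), phase 2 copies
    * one more vec4 using only the first invocation, and phase 3 copies
    * the last 2 dwords (8 bytes).
    */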
   const unsigned invocations = b->shader->info.workgroup_size[0] *
                                b->shader->info.workgroup_size[1] *
                                b->shader->info.workgroup_size[2];
   const unsigned vec4size = 16;
   const unsigned whole_wg_vec4_copies = payload_size / vec4size;
   const unsigned vec4_copies_per_invocation = whole_wg_vec4_copies / invocations;
   const unsigned remaining_vec4_copies = whole_wg_vec4_copies % invocations;
   const unsigned remaining_dwords =
      DIV_ROUND_UP(payload_size - vec4size * vec4_copies_per_invocation * invocations - vec4size * remaining_vec4_copies,
                   4);
   const unsigned base_shared_addr = s->payload_shared_addr + payload_addr;

   nir_def *invocation_index = nir_load_local_invocation_index(b);
   nir_def *addr = nir_imul_imm(b, invocation_index, vec4size);

   /* Wait for all previous shared stores to finish.
    * This is necessary because we placed the payload in shared memory.
    */
   nir_barrier(b, .execution_scope = SCOPE_WORKGROUP,
               .memory_scope = SCOPE_WORKGROUP,
               .memory_semantics = NIR_MEMORY_ACQ_REL,
               .memory_modes = nir_var_mem_shared);

   /* payload_size is the size of the user-accessible payload, but on some
    * hardware (e.g. Intel) the payload has a private header, which we have
    * to skip over (payload_offset_in_bytes).
    */
   unsigned off = s->payload_offset_in_bytes;

   /* Technically dword-alignment is not necessary for correctness
    * of the code below, but even if the backend implements unaligned
    * loads/stores, they will very likely be slower.
    */
   assert(off % 4 == 0);

   /* Copy full vec4s using all invocations in the workgroup. */
   for (unsigned i = 0; i < vec4_copies_per_invocation; ++i) {
      copy_shared_to_payload(b, vec4size / 4, addr, base_shared_addr, off);
      off += vec4size * invocations;
   }

   /* Copy full vec4s using only the invocations needed to not overflow. */
   if (remaining_vec4_copies > 0) {
      assert(remaining_vec4_copies < invocations);

      nir_def *cmp = nir_ilt_imm(b, invocation_index, remaining_vec4_copies);
      nir_if *if_stmt = nir_push_if(b, cmp);
      {
         copy_shared_to_payload(b, vec4size / 4, addr, base_shared_addr, off);
      }
      nir_pop_if(b, if_stmt);
      off += vec4size * remaining_vec4_copies;
   }

   /* Copy the last few dwords not forming a full vec4. */
   if (remaining_dwords > 0) {
      assert(remaining_dwords < 4);
      nir_def *cmp = nir_ieq_imm(b, invocation_index, 0);
      nir_if *if_stmt = nir_push_if(b, cmp);
      {
         copy_shared_to_payload(b, remaining_dwords, addr, base_shared_addr, off);
      }
      nir_pop_if(b, if_stmt);
      off += remaining_dwords * 4;
   }

   assert(s->payload_offset_in_bytes + ALIGN(payload_size, 4) == off);
}

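/* Lower launch_mesh_workgroups: copy the payload out of shared memory when
 * needed, then delete everything after the intrinsic and return, making it
 * the last thing the shader executes.
 */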
static bool
lower_task_launch_mesh_workgroups(nir_builder *b,
                                  nir_intrinsic_instr *intrin,
                                  lower_task_state *s)
{
   if (s->payload_in_shared) {
      /* Copy the payload from shared memory.
       * Because launch_mesh_workgroups may only occur in
       * workgroup-uniform control flow, here we assume that
       * all invocations in the workgroup are active and therefore
       * they can all participate in the copy.
       *
       * TODO: Skip the copy when the mesh dispatch size is (0, 0, 0).
       *       This is problematic because the dispatch size can be divergent,
       *       and may differ across subgroups.
       */

      uint32_t payload_addr = nir_intrinsic_base(intrin);
      uint32_t payload_size = nir_intrinsic_range(intrin);

      b->cursor = nir_before_instr(&intrin->instr);
      emit_shared_to_payload_copy(b, payload_addr, payload_size, s);
   }

   /* The launch_mesh_workgroups intrinsic is a terminating instruction,
    * so let's delete everything after it.
    */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_block *current_block = nir_cursor_current_block(b->cursor);

   /* Delete the following instructions in the current block. */
   nir_foreach_instr_reverse_safe(instr, current_block) {
      if (instr == &intrin->instr)
         break;
      nir_instr_remove(instr);
   }

   /* Delete the following CF at the same level. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_cf_list extracted;
   nir_cf_node *end_node = &current_block->cf_node;
   while (!nir_cf_node_is_last(end_node))
      end_node = nir_cf_node_next(end_node);
   nir_cf_extract(&extracted, b->cursor, nir_after_cf_node(end_node));
   nir_cf_delete(&extracted);

   /* Terminate the task shader. */
   b->cursor = nir_after_instr(&intrin->instr);
   nir_jump(b, nir_jump_return);

   return true;
}

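/* Per-instruction callback that dispatches to the lowerings above. */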
static bool
lower_task_intrin(nir_builder *b,
                  nir_instr *instr,
                  void *state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   lower_task_state *s = (lower_task_state *)state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
   case nir_intrinsic_task_payload_atomic:
   case nir_intrinsic_task_payload_atomic_swap:
   case nir_intrinsic_store_task_payload:
   case nir_intrinsic_load_task_payload:
      if (s->payload_in_shared)
         return lower_task_payload_to_shared(b, intrin, s);
      return false;
   case nir_intrinsic_launch_mesh_workgroups:
      return lower_task_launch_mesh_workgroups(b, intrin, s);
   default:
      return false;
   }
}

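/* Check whether any task_payload access needs the shared memory fallback:
 * payload atomics (if requested) or sub-dword loads/stores (if requested).
 */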
static bool
requires_payload_in_shared(nir_shader *shader, bool atomics, bool small_types)
{
   nir_foreach_function_impl(impl, shader) {
      nir_foreach_block(block, impl) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
            switch (intrin->intrinsic) {
            case nir_intrinsic_task_payload_atomic:
            case nir_intrinsic_task_payload_atomic_swap:
               if (atomics)
                  return true;
               break;
            case nir_intrinsic_load_task_payload:
               if (small_types && intrin->def.bit_size < 32)
                  return true;
               break;
            case nir_intrinsic_store_task_payload:
               if (small_types && nir_src_bit_size(intrin->src[0]) < 32)
                  return true;
               break;
            default:
               break;
            }
         }
      }
   }

   return false;
}

static bool
nir_lower_task_intrins(nir_shader *shader, lower_task_state *state)
{
   return nir_shader_instructions_pass(shader, lower_task_intrin,
                                       nir_metadata_none, state);
}

/**
 * Common Task Shader lowering to make the job of the backends easier.
 *
 * - Lowers the NV_mesh_shader TASK_COUNT output to launch_mesh_workgroups.
 * - Removes all code after launch_mesh_workgroups, enforcing the
 *   fact that it's a terminating instruction.
 * - Ensures that task shaders always have at least one
 *   launch_mesh_workgroups instruction, so the backend doesn't
 *   need to implement a special case when the shader doesn't have it.
 * - Optionally, implements task_payload using shared memory when
 *   task_payload atomics or small (sub-dword) types are used.
 *   This is useful when the backend can't otherwise support the same
 *   atomic and small-type features for task_payload as it can for
 *   shared memory.
 *   If this is used, the backend only has to implement the basic
 *   load/store operations for task_payload.
 *
 * Note: this pass operates on lowered explicit I/O intrinsics, so
 * it should be called after nir_lower_io + nir_lower_explicit_io.
 */
bool
nir_lower_task_shader(nir_shader *shader,
                      nir_lower_task_shader_options options)
{
   if (shader->info.stage != MESA_SHADER_TASK)
      return false;

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder builder = nir_builder_create(impl);

   if (shader->info.outputs_written & BITFIELD64_BIT(VARYING_SLOT_TASK_COUNT)) {
      /* NV_mesh_shader:
       * If the shader writes TASK_COUNT, lower that to emit
       * the new launch_mesh_workgroups intrinsic instead.
       */
      NIR_PASS_V(shader, nir_lower_nv_task_count);
   } else {
      /* To make sure that task shaders always have a code path that
       * executes a launch_mesh_workgroups, let's add one at the end.
       * If the shader already had a launch_mesh_workgroups by any chance,
       * the one added here will be removed.
       */
      nir_block *last_block = nir_impl_last_block(impl);
      builder.cursor = nir_after_block_before_jump(last_block);
      nir_launch_mesh_workgroups(&builder, nir_imm_zero(&builder, 3, 32));
   }

   bool atomics = options.payload_to_shared_for_atomics;
   bool small_types = options.payload_to_shared_for_small_types;
   bool payload_in_shared = (atomics || small_types) &&
                            requires_payload_in_shared(shader, atomics, small_types);

   lower_task_state state = {
      .payload_shared_addr = ALIGN(shader->info.shared_size, 16),
      .payload_in_shared = payload_in_shared,
      .payload_offset_in_bytes = options.payload_offset_in_bytes,
   };

   if (payload_in_shared)
      shader->info.shared_size =
         state.payload_shared_addr + shader->info.task_payload_size;

   NIR_PASS(_, shader, nir_lower_task_intrins, &state);

   /* Delete all code that potentially can't be reached due to
    * launch_mesh_workgroups being a terminating instruction.
    */
   NIR_PASS(_, shader, nir_lower_returns);

   bool progress;
   do {
      progress = false;
      NIR_PASS(progress, shader, nir_opt_dead_cf);
      NIR_PASS(progress, shader, nir_opt_dce);
   } while (progress);
   return true;
}