xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/nir/radv_nir_lower_io.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2016 Red Hat.
3  * Copyright © 2016 Bas Nieuwenhuizen
4  * Copyright © 2023 Valve Corporation
5  *
6  * SPDX-License-Identifier: MIT
7  */
8 
9 #include "ac_nir.h"
10 #include "nir.h"
11 #include "nir_builder.h"
12 #include "radv_device.h"
13 #include "radv_nir.h"
14 #include "radv_physical_device.h"
15 #include "radv_shader.h"
16 
/* Size callback for nir_lower_io: number of vec4 attribute slots occupied by a
 * variable of the given GLSL type. The `bindless` flag is not consulted.
 */
static int
type_size_vec4(const struct glsl_type *type, bool bindless)
{
   const bool is_gl_vertex_input = false;
   int slots = glsl_count_attribute_slots(type, is_gl_vertex_input);
   return slots;
}
22 
23 void
radv_nir_lower_io_to_scalar_early(nir_shader * nir,nir_variable_mode mask)24 radv_nir_lower_io_to_scalar_early(nir_shader *nir, nir_variable_mode mask)
25 {
26    bool progress = false;
27 
28    NIR_PASS(progress, nir, nir_lower_io_to_scalar_early, mask);
29    if (progress) {
30       /* Optimize the new vector code and then remove dead vars */
31       NIR_PASS(_, nir, nir_copy_prop);
32       NIR_PASS(_, nir, nir_opt_shrink_vectors, true);
33 
34       if (mask & nir_var_shader_out) {
35          /* Optimize swizzled movs of load_const for nir_link_opt_varyings's constant propagation. */
36          NIR_PASS(_, nir, nir_opt_constant_folding);
37 
38          /* For nir_link_opt_varyings's duplicate input opt */
39          NIR_PASS(_, nir, nir_opt_cse);
40       }
41 
42       /* Run copy-propagation to help remove dead output variables (some shaders have useless copies
43        * to/from an output), so compaction later will be more effective.
44        *
45        * This will have been done earlier but it might not have worked because the outputs were
46        * vector.
47        */
48       if (nir->info.stage == MESA_SHADER_TESS_CTRL)
49          NIR_PASS(_, nir, nir_opt_copy_prop_vars);
50 
51       NIR_PASS(_, nir, nir_opt_dce);
52       NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_in | nir_var_shader_out, NULL);
53    }
54 }
55 
56 void
radv_nir_lower_io(struct radv_device * device,nir_shader * nir)57 radv_nir_lower_io(struct radv_device *device, nir_shader *nir)
58 {
59    const struct radv_physical_device *pdev = radv_device_physical(device);
60 
61    if (nir->info.stage == MESA_SHADER_VERTEX) {
62       NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in, type_size_vec4, 0);
63       NIR_PASS(_, nir, nir_lower_io, nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);
64    } else {
65       NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4,
66                nir_lower_io_lower_64bit_to_32);
67    }
68 
69    /* This pass needs actual constants */
70    NIR_PASS(_, nir, nir_opt_constant_folding);
71 
72    NIR_PASS(_, nir, nir_io_add_const_offset_to_base, nir_var_shader_in | nir_var_shader_out);
73 
74    if (nir->xfb_info) {
75       NIR_PASS(_, nir, nir_io_add_intrinsic_xfb_info);
76 
77       if (pdev->use_ngg_streamout) {
78          /* The total number of shader outputs is required for computing the pervertex LDS size for
79           * VS/TES when lowering NGG streamout.
80           */
81          nir_assign_io_var_locations(nir, nir_var_shader_out, &nir->num_outputs, nir->info.stage);
82       }
83    }
84 
85    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
86       /* Recompute FS input intrinsic bases to make sure that there are no gaps
87        * between the FS input slots.
88        */
89       nir_recompute_io_bases(nir, nir_var_shader_in);
90    }
91 
92    NIR_PASS_V(nir, nir_opt_dce);
93    NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_shader_in | nir_var_shader_out, NULL);
94 }
95 
/* Fixed driver-location layout used for IO of stages whose inputs/outputs are
 * not linked: POS, CLIP_DIST0, CLIP_DIST1 and PSIZ come first, followed by the
 * generic varyings VAR0..VAR31 starting at RADV_IO_SLOT_VAR0.
 */
enum {
   RADV_IO_SLOT_POS = 0,
   RADV_IO_SLOT_CLIP_DIST0 = 1,
   RADV_IO_SLOT_CLIP_DIST1 = 2,
   RADV_IO_SLOT_PSIZ = 3,
   RADV_IO_SLOT_VAR0 = 4, /* VAR0..VAR31 occupy this slot and the 31 after it */
};
104 
105 unsigned
radv_map_io_driver_location(unsigned semantic)106 radv_map_io_driver_location(unsigned semantic)
107 {
108    if ((semantic >= VARYING_SLOT_PATCH0 && semantic < VARYING_SLOT_TESS_MAX) ||
109        semantic == VARYING_SLOT_TESS_LEVEL_INNER || semantic == VARYING_SLOT_TESS_LEVEL_OUTER)
110       return ac_shader_io_get_unique_index_patch(semantic);
111 
112    switch (semantic) {
113    case VARYING_SLOT_POS:
114       return RADV_IO_SLOT_POS;
115    case VARYING_SLOT_CLIP_DIST0:
116       return RADV_IO_SLOT_CLIP_DIST0;
117    case VARYING_SLOT_CLIP_DIST1:
118       return RADV_IO_SLOT_CLIP_DIST1;
119    case VARYING_SLOT_PSIZ:
120       return RADV_IO_SLOT_PSIZ;
121    default:
122       assert(semantic >= VARYING_SLOT_VAR0 && semantic <= VARYING_SLOT_VAR31);
123       return RADV_IO_SLOT_VAR0 + (semantic - VARYING_SLOT_VAR0);
124    }
125 }
126 
127 bool
radv_nir_lower_io_to_mem(struct radv_device * device,struct radv_shader_stage * stage)128 radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *stage)
129 {
130    const struct radv_physical_device *pdev = radv_device_physical(device);
131    const struct radv_shader_info *info = &stage->info;
132    ac_nir_map_io_driver_location map_input = info->inputs_linked ? NULL : radv_map_io_driver_location;
133    ac_nir_map_io_driver_location map_output = info->outputs_linked ? NULL : radv_map_io_driver_location;
134    nir_shader *nir = stage->nir;
135 
136    if (nir->info.stage == MESA_SHADER_VERTEX) {
137       if (info->vs.as_ls) {
138          NIR_PASS_V(nir, ac_nir_lower_ls_outputs_to_mem, map_output, info->vs.tcs_in_out_eq,
139                     info->vs.hs_inputs_read, info->vs.tcs_temp_only_input_mask);
140          return true;
141       } else if (info->vs.as_es) {
142          NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, map_output, pdev->info.gfx_level, info->esgs_itemsize, info->gs_inputs_read);
143          return true;
144       }
145    } else if (nir->info.stage == MESA_SHADER_TESS_CTRL) {
146       NIR_PASS_V(nir, ac_nir_lower_hs_inputs_to_mem, map_input, info->vs.tcs_in_out_eq, info->vs.tcs_temp_only_input_mask);
147       NIR_PASS_V(nir, ac_nir_lower_hs_outputs_to_mem, map_output, pdev->info.gfx_level,
148                  info->tcs.tes_inputs_read, info->tcs.tes_patch_inputs_read, info->wave_size, false);
149 
150       return true;
151    } else if (nir->info.stage == MESA_SHADER_TESS_EVAL) {
152       NIR_PASS_V(nir, ac_nir_lower_tes_inputs_to_mem, map_input);
153 
154       if (info->tes.as_es) {
155          NIR_PASS_V(nir, ac_nir_lower_es_outputs_to_mem, map_output, pdev->info.gfx_level, info->esgs_itemsize, info->gs_inputs_read);
156       }
157 
158       return true;
159    } else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
160       NIR_PASS_V(nir, ac_nir_lower_gs_inputs_to_mem, map_input, pdev->info.gfx_level, false);
161       return true;
162    } else if (nir->info.stage == MESA_SHADER_TASK) {
163       ac_nir_lower_task_outputs_to_mem(nir, AC_TASK_PAYLOAD_ENTRY_BYTES, pdev->task_info.num_entries,
164                                        info->cs.has_query);
165       return true;
166    } else if (nir->info.stage == MESA_SHADER_MESH) {
167       ac_nir_lower_mesh_inputs_to_mem(nir, AC_TASK_PAYLOAD_ENTRY_BYTES, pdev->task_info.num_entries);
168       return true;
169    }
170 
171    return false;
172 }
173 
174 static bool
radv_nir_lower_draw_id_to_zero_callback(struct nir_builder * b,nir_intrinsic_instr * intrin,UNUSED void * state)175 radv_nir_lower_draw_id_to_zero_callback(struct nir_builder *b, nir_intrinsic_instr *intrin, UNUSED void *state)
176 {
177    if (intrin->intrinsic != nir_intrinsic_load_draw_id)
178       return false;
179 
180    nir_def *replacement = nir_imm_zero(b, intrin->def.num_components, intrin->def.bit_size);
181    nir_def_replace(&intrin->def, replacement);
182    nir_instr_free(&intrin->instr);
183 
184    return true;
185 }
186 
187 bool
radv_nir_lower_draw_id_to_zero(nir_shader * shader)188 radv_nir_lower_draw_id_to_zero(nir_shader *shader)
189 {
190    return nir_shader_intrinsics_pass(shader, radv_nir_lower_draw_id_to_zero_callback, nir_metadata_control_flow, NULL);
191 }
192