/*
 * Copyright © Microsoft Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"

#include "dxil_spirv_nir.h"
#include "dxil_nir.h"
#include "vulkan/vulkan_core.h"

const uint32_t descriptor_size = sizeof(struct dxil_spirv_bindless_entry);

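/* Size callback for nir_build_deref_offset(): count one unit per array
 * element (arrays-of-arrays included), so deref offsets come back as flat
 * element indices rather than byte offsets. */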
static void
type_size_align_1(const struct glsl_type *type, unsigned *size, unsigned *align)
{
   if (glsl_type_is_array(type))
      *size = glsl_get_aoa_size(type);
   else
      *size = 1;
   *align = *size;
}

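/* Load num_comps 32-bit values at the given byte offset from the read-only
 * mapping SSBO declared at (set 0, binding buf_idx). */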
static nir_def *
load_vulkan_ssbo(nir_builder *b, unsigned buf_idx,
                 nir_def *offset, unsigned num_comps)
{
   nir_def *res_index =
      nir_vulkan_resource_index(b, 2, 32,
                                nir_imm_int(b, 0),
                                .desc_set = 0,
                                .binding = buf_idx,
                                .desc_type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
   nir_def *descriptor =
      nir_load_vulkan_descriptor(b, 2, 32, res_index,
                                 .desc_type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
   return nir_load_ssbo(b, num_comps, 32,
                        nir_channel(b, descriptor, 0),
                        offset,
                        .align_mul = num_comps * 4,
                        .align_offset = 0,
                        .access = ACCESS_NON_WRITEABLE | ACCESS_CAN_REORDER);
}

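/* Convert a texture/sampler deref into a bindless handle: compute the entry
 * index in the set's mapping SSBO (remapped base binding plus the deref's
 * array offset) and load the 32-bit handle stored there. Sampler handles sit
 * 4 bytes into each descriptor entry. Returns NULL when the remap callback
 * marks the binding as not lowered (descriptor_set == ~0). */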
static nir_def *
lower_deref_to_index(nir_builder *b, nir_deref_instr *deref, bool is_sampler_handle,
                     struct dxil_spirv_nir_lower_bindless_options *options)
{
   nir_variable *var = nir_deref_instr_get_variable(deref);
   if (!var)
      return NULL;

   struct dxil_spirv_binding_remapping remap = {
      .descriptor_set = var->data.descriptor_set,
      .binding = var->data.binding,
      .is_sampler = is_sampler_handle
   };
   options->remap_binding(&remap, options->callback_context);
   if (remap.descriptor_set == ~0)
      return NULL;

   nir_def *index_in_ubo =
      nir_iadd_imm(b,
                   nir_build_deref_offset(b, deref, type_size_align_1),
                   remap.binding);
   nir_def *offset = nir_imul_imm(b, index_in_ubo, descriptor_size);
   if (is_sampler_handle)
      offset = nir_iadd_imm(b, offset, 4);
   return load_vulkan_ssbo(b,
                           var->data.descriptor_set,
                           offset,
                           1);
}

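/* Replace a vulkan_resource_index intrinsic with a vec2 loaded from the
 * set's mapping SSBO; the old intrinsic is left for DCE once its uses are
 * rewritten. */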
static bool
lower_vulkan_resource_index(nir_builder *b, nir_intrinsic_instr *intr,
                            struct dxil_spirv_nir_lower_bindless_options *options)
{
   struct dxil_spirv_binding_remapping remap = {
      .descriptor_set = nir_intrinsic_desc_set(intr),
      .binding = nir_intrinsic_binding(intr)
   };
   if (remap.descriptor_set >= options->num_descriptor_sets)
      return false;

   options->remap_binding(&remap, options->callback_context);
   b->cursor = nir_before_instr(&intr->instr);
   nir_def *index = intr->src[0].ssa;
   nir_def *index_in_ubo = nir_iadd_imm(b, index, remap.binding);
   nir_def *res_idx =
      load_vulkan_ssbo(b, remap.descriptor_set, nir_imul_imm(b, index_in_ubo, descriptor_size), 2);

   nir_def_rewrite_uses(&intr->def, res_idx);
   return true;
}

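/* Rewrite one tex source from a deref (`old`) to a bindless handle (`new`).
 * Returns false if the source isn't present or the deref can't be lowered. */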
static bool
lower_bindless_tex_src(nir_builder *b, nir_tex_instr *tex,
                       nir_tex_src_type old, nir_tex_src_type new,
                       bool is_sampler_handle,
                       struct dxil_spirv_nir_lower_bindless_options *options)
{
   int index = nir_tex_instr_src_index(tex, old);
   if (index == -1)
      return false;

   b->cursor = nir_before_instr(&tex->instr);
   nir_deref_instr *deref = nir_src_as_deref(tex->src[index].src);
   nir_def *handle = lower_deref_to_index(b, deref, is_sampler_handle, options);
   if (!handle)
      return false;

   nir_src_rewrite(&tex->src[index].src, handle);
   tex->src[index].src_type = new;
   return true;
}

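/* Lower the texture deref and, when present, the sampler deref of a tex
 * instruction to bindless handles. */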
static bool
lower_bindless_tex(nir_builder *b, nir_tex_instr *tex, struct dxil_spirv_nir_lower_bindless_options *options)
{
   bool texture = lower_bindless_tex_src(b, tex, nir_tex_src_texture_deref, nir_tex_src_texture_handle, false, options);
   bool sampler = lower_bindless_tex_src(b, tex, nir_tex_src_sampler_deref, nir_tex_src_sampler_handle, true, options);
   return texture || sampler;
}

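/* Lower an image_deref_* intrinsic to its bindless form by swapping the
 * deref source for a loaded handle. */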
static bool
lower_bindless_image_intr(nir_builder *b, nir_intrinsic_instr *intr, struct dxil_spirv_nir_lower_bindless_options *options)
{
   b->cursor = nir_before_instr(&intr->instr);
   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   nir_def *handle = lower_deref_to_index(b, deref, false, options);
   if (!handle)
      return false;

   nir_rewrite_image_intrinsic(intr, handle, true);
   return true;
}

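/* Per-instruction callback for nir_shader_instructions_pass(): dispatch to
 * the tex, image, or resource-index lowering above. */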
static bool
lower_bindless_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct dxil_spirv_nir_lower_bindless_options *options = data;

   if (instr->type == nir_instr_type_tex)
      return lower_bindless_tex(b, nir_instr_as_tex(instr), options);
   if (instr->type != nir_instr_type_intrinsic)
      return false;
   nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
   switch (intr->intrinsic) {
   case nir_intrinsic_image_deref_load:
   case nir_intrinsic_image_deref_store:
   case nir_intrinsic_image_deref_size:
   case nir_intrinsic_image_deref_atomic:
   case nir_intrinsic_image_deref_atomic_swap:
      return lower_bindless_image_intr(b, intr, options);
   case nir_intrinsic_vulkan_resource_index:
      return lower_vulkan_resource_index(b, intr, options);
   default:
      return false;
   }
}

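/* Declare the hidden, read-only "bindless_data" SSBO that backs one
 * descriptor set's mapping entries. */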
static nir_variable *
add_bindless_data_var(nir_shader *nir, unsigned binding)
{
   const struct glsl_type *array_type =
      glsl_array_type(glsl_uint_type(), 0, sizeof(unsigned));
   const struct glsl_struct_field field = {array_type, "arr"};
   nir_variable *var = nir_variable_create(
      nir, nir_var_mem_ssbo,
      glsl_struct_type(&field, 1, "bindless_data", false), "bindless_data");
   var->data.binding = binding;
   var->data.how_declared = nir_var_hidden;
   var->data.read_only = 1;
   var->data.access = ACCESS_NON_WRITEABLE;
   return var;
}

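/* Dead-variable callback: keep variables in sets that weren't lowered, and
 * keep sampler variables whose remap resolved to descriptor_set == ~0;
 * everything else lowered to bindless can be removed. */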
static bool
can_remove_var(nir_variable *var, void *data)
{
   struct dxil_spirv_nir_lower_bindless_options *options = data;
   if (var->data.descriptor_set >= options->num_descriptor_sets)
      return false;
   if (!glsl_type_is_sampler(glsl_without_array(var->type)))
      return true;
   struct dxil_spirv_binding_remapping remap = {
      .descriptor_set = var->data.descriptor_set,
      .binding = var->data.binding,
      .is_sampler = true,
   };
   options->remap_binding(&remap, options->callback_context);
   if (remap.descriptor_set == ~0)
      return false;
   return true;
}

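/* Entry point: rewrite descriptor accesses to the bindless scheme and
 * declare one mapping SSBO per descriptor set that's still referenced
 * (plus the dynamic-buffer set, if any). */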
bool
dxil_spirv_nir_lower_bindless(nir_shader *nir, struct dxil_spirv_nir_lower_bindless_options *options)
{
   /* While we still have derefs for images, use that to propagate type info back to image vars,
    * and then forward to the intrinsics that reference them. */
   bool ret = dxil_nir_guess_image_formats(nir);

   ret |= nir_shader_instructions_pass(nir, lower_bindless_instr,
                                       nir_metadata_control_flow |
                                       nir_metadata_loop_analysis,
                                       options);
   ret |= nir_remove_dead_derefs(nir);

   unsigned descriptor_sets = 0;
   const nir_variable_mode modes = nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_image | nir_var_uniform;
   nir_foreach_variable_with_modes(var, nir, modes) {
      if (var->data.descriptor_set < options->num_descriptor_sets)
         descriptor_sets |= (1 << var->data.descriptor_set);
   }

   if (options->dynamic_buffer_binding != ~0)
      descriptor_sets |= (1 << options->dynamic_buffer_binding);

   nir_remove_dead_variables_options dead_var_options = {
      .can_remove_var = can_remove_var,
      .can_remove_var_data = options
   };
   ret |= nir_remove_dead_variables(nir, modes, &dead_var_options);

   if (!descriptor_sets)
      return ret;

   while (descriptor_sets) {
      int index = u_bit_scan(&descriptor_sets);
      add_bindless_data_var(nir, index);
   }
   return true;
}

/* Given a global deref chain that starts as a pointer value and ends with a load/store/atomic,
 * create a new SSBO deref chain. The new chain starts with a load_vulkan_descriptor, then casts
 * the resulting vec2 to an SSBO deref. */
static bool
lower_buffer_device_address(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   switch (intr->intrinsic) {
   case nir_intrinsic_load_deref:
   case nir_intrinsic_store_deref:
   case nir_intrinsic_deref_atomic:
   case nir_intrinsic_deref_atomic_swap:
      break;
   default:
      assert(intr->intrinsic != nir_intrinsic_copy_deref);
      return false;
   }
   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
   if (!nir_deref_mode_is(deref, nir_var_mem_global))
      return false;

   nir_deref_path path;
   nir_deref_path_init(&path, deref, NULL);

   nir_deref_instr *old_head = path.path[0];
   assert(old_head->deref_type == nir_deref_type_cast &&
          old_head->parent.ssa->bit_size == 64 &&
          old_head->parent.ssa->num_components == 1);
   b->cursor = nir_after_instr(&old_head->instr);
   nir_def *pointer = old_head->parent.ssa;
   nir_def *offset = nir_unpack_64_2x32_split_x(b, pointer);
   nir_def *index = nir_iand_imm(b, nir_unpack_64_2x32_split_y(b, pointer), 0xffffff);

   nir_def *descriptor = nir_load_vulkan_descriptor(b, 2, 32, nir_vec2(b, index, offset),
                                                    .desc_type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
   nir_deref_instr *head = nir_build_deref_cast_with_alignment(b, descriptor, nir_var_mem_ssbo, old_head->type,
                                                               old_head->cast.ptr_stride,
                                                               old_head->cast.align_mul,
                                                               old_head->cast.align_offset);

   for (int i = 1; path.path[i]; ++i) {
      nir_deref_instr *old = path.path[i];
      b->cursor = nir_after_instr(&old->instr);
      head = nir_build_deref_follower(b, head, old);
   }

   nir_src_rewrite(&intr->src[0], &head->def);

   nir_deref_path_finish(&path);
   return true;
}

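/* Entry point: lower every global-memory (buffer device address) access to
 * an SSBO access via the deref rewrite above. */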
bool
dxil_spirv_nir_lower_buffer_device_address(nir_shader *nir)
{
   return nir_shader_intrinsics_pass(nir, lower_buffer_device_address,
                                     nir_metadata_control_flow |
                                     nir_metadata_loop_analysis,
                                     NULL);
}