/*
 * Copyright 2024 Valve Corporation
 * Copyright 2024 Alyssa Rosenzweig
 * Copyright 2022-2023 Collabora Ltd. and Red Hat Inc.
 * SPDX-License-Identifier: MIT
 */
#include "pipe/p_defines.h"
#include "vulkan/vulkan_core.h"
#include "agx_nir_passes.h"
#include "agx_pack.h"
#include "hk_cmd_buffer.h"
#include "hk_descriptor_set.h"
#include "hk_descriptor_set_layout.h"
#include "hk_shader.h"

#include "nir.h"
#include "nir_builder.h"
#include "nir_builder_opcodes.h"
#include "nir_deref.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
#include "shader_enums.h"
#include "vk_pipeline.h"

struct lower_descriptors_ctx {
   const struct hk_descriptor_set_layout *set_layouts[HK_MAX_SETS];

   bool clamp_desc_array_bounds;
   nir_address_format ubo_addr_format;
   nir_address_format ssbo_addr_format;
};

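/*
 * Look up the layout of a (set, binding) pair from the set layouts recorded
 * in the lowering context.
 */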
static const struct hk_descriptor_set_binding_layout *
get_binding_layout(uint32_t set, uint32_t binding,
                   const struct lower_descriptors_ctx *ctx)
{
   assert(set < HK_MAX_SETS);
   assert(ctx->set_layouts[set] != NULL);

   const struct hk_descriptor_set_layout *set_layout = ctx->set_layouts[set];

   assert(binding < set_layout->binding_count);
   return &set_layout->binding[binding];
}

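/*
 * Load read-only global memory that is safe to speculate, so later passes are
 * free to hoist the load.
 */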
static nir_def *
load_speculatable(nir_builder *b, unsigned num_components, unsigned bit_size,
                  nir_def *addr, unsigned align)
{
   return nir_build_load_global_constant(b, num_components, bit_size, addr,
                                         .align_mul = align,
                                         .access = ACCESS_CAN_SPECULATE);
}

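/*
 * Load from the root descriptor table. Its address lives in a preamble
 * uniform (HK_ROOT_UNIFORM); the offset selects a field within the table.
 */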
static nir_def *
load_root(nir_builder *b, unsigned num_components, unsigned bit_size,
          nir_def *offset, unsigned align)
{
   nir_def *root = nir_load_preamble(b, 1, 64, .base = HK_ROOT_UNIFORM);

   /* We've bound the address of the root descriptor, index in. */
   nir_def *addr = nir_iadd(b, root, nir_u2u64(b, offset));

   return load_speculatable(b, num_components, bit_size, addr, align);
}

static bool
lower_load_constant(nir_builder *b, nir_intrinsic_instr *load,
                    const struct lower_descriptors_ctx *ctx)
{
   assert(load->intrinsic == nir_intrinsic_load_constant);
   unreachable("todo: stick an address in the root descriptor or something");

   uint32_t base = nir_intrinsic_base(load);
   uint32_t range = nir_intrinsic_range(load);

   b->cursor = nir_before_instr(&load->instr);

   nir_def *offset = nir_iadd_imm(b, load->src[0].ssa, base);
   nir_def *data = nir_load_ubo(
      b, load->def.num_components, load->def.bit_size, nir_imm_int(b, 0),
      offset, .align_mul = nir_intrinsic_align_mul(load),
      .align_offset = nir_intrinsic_align_offset(load), .range_base = base,
      .range = range);

   nir_def_rewrite_uses(&load->def, data);

   return true;
}

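/* Load the GPU address of a bound descriptor set from the root table. */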
static nir_def *
load_descriptor_set_addr(nir_builder *b, uint32_t set,
                         UNUSED const struct lower_descriptors_ctx *ctx)
{
   uint32_t set_addr_offset =
      hk_root_descriptor_offset(sets) + set * sizeof(uint64_t);

   return load_root(b, 1, 64, nir_imm_int(b, set_addr_offset), 8);
}

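/*
 * Return the index of the first dynamic buffer belonging to a set. This is a
 * compile-time constant when every earlier set layout is known; otherwise it
 * is read from the root descriptor table.
 */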
static nir_def *
load_dynamic_buffer_start(nir_builder *b, uint32_t set,
                          const struct lower_descriptors_ctx *ctx)
{
   int dynamic_buffer_start_imm = 0;
   for (uint32_t s = 0; s < set; s++) {
      if (ctx->set_layouts[s] == NULL) {
         dynamic_buffer_start_imm = -1;
         break;
      }

      dynamic_buffer_start_imm += ctx->set_layouts[s]->dynamic_buffer_count;
   }

   if (dynamic_buffer_start_imm >= 0) {
      return nir_imm_int(b, dynamic_buffer_start_imm);
   } else {
      uint32_t root_offset =
         hk_root_descriptor_offset(set_dynamic_buffer_start) + set;

      return nir_u2u32(b, load_root(b, 1, 8, nir_imm_int(b, root_offset), 1));
   }
}

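/*
 * Load descriptor data for (set, binding, index) at a byte offset. Dynamic
 * buffers are read from the root table, inline uniform blocks are turned into
 * a bounded address into the set, and everything else is read out of the
 * descriptor set itself.
 */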
static nir_def *
load_descriptor(nir_builder *b, unsigned num_components, unsigned bit_size,
                uint32_t set, uint32_t binding, nir_def *index,
                unsigned offset_B, const struct lower_descriptors_ctx *ctx)
{
   const struct hk_descriptor_set_binding_layout *binding_layout =
      get_binding_layout(set, binding, ctx);

   if (ctx->clamp_desc_array_bounds)
      index =
         nir_umin(b, index, nir_imm_int(b, binding_layout->array_size - 1));

   switch (binding_layout->type) {
   case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC:
   case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: {
      /* Get the index in the root descriptor table dynamic_buffers array. */
      nir_def *dynamic_buffer_start = load_dynamic_buffer_start(b, set, ctx);

      index = nir_iadd(b, index,
                       nir_iadd_imm(b, dynamic_buffer_start,
                                    binding_layout->dynamic_buffer_index));

      nir_def *root_desc_offset = nir_iadd_imm(
         b, nir_imul_imm(b, index, sizeof(struct hk_buffer_address)),
         hk_root_descriptor_offset(dynamic_buffers));

      assert(num_components == 4 && bit_size == 32);
      nir_def *desc = load_root(b, 4, 32, root_desc_offset, 16);

      /* We know a priori that the .w component (offset) is zero */
      return nir_vector_insert_imm(b, desc, nir_imm_int(b, 0), 3);
   }

   case VK_DESCRIPTOR_TYPE_INLINE_UNIFORM_BLOCK: {
      nir_def *base_addr = nir_iadd_imm(
         b, load_descriptor_set_addr(b, set, ctx), binding_layout->offset);

      assert(binding_layout->stride == 1);
      const uint32_t binding_size = binding_layout->array_size;

      /* Convert it to nir_address_format_64bit_bounded_global */
      assert(num_components == 4 && bit_size == 32);
      return nir_vec4(b, nir_unpack_64_2x32_split_x(b, base_addr),
                      nir_unpack_64_2x32_split_y(b, base_addr),
                      nir_imm_int(b, binding_size), nir_imm_int(b, 0));
   }

   default: {
      assert(binding_layout->stride > 0);
      nir_def *desc_ubo_offset =
         nir_iadd_imm(b, nir_imul_imm(b, index, binding_layout->stride),
                      binding_layout->offset + offset_B);

      unsigned desc_align_mul = (1 << (ffs(binding_layout->stride) - 1));
      desc_align_mul = MIN2(desc_align_mul, 16);
      unsigned desc_align_offset = binding_layout->offset + offset_B;
      desc_align_offset %= desc_align_mul;

      nir_def *desc;
      nir_def *set_addr = load_descriptor_set_addr(b, set, ctx);
      desc = nir_load_global_constant_offset(
         b, num_components, bit_size, set_addr, desc_ubo_offset,
         .align_mul = desc_align_mul, .align_offset = desc_align_offset,
         .access = ACCESS_CAN_SPECULATE);

      if (binding_layout->type == VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER ||
          binding_layout->type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER) {
         /* We know a priori that the .w component (offset) is zero */
         assert(num_components == 4 && bit_size == 32);
         desc = nir_vector_insert_imm(b, desc, nir_imm_int(b, 0), 3);
      }
      return desc;
   }
   }
}

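/*
 * Check whether a descriptor chain bottoms out at vulkan_resource_index,
 * walking through any vulkan_resource_reindex instructions along the way.
 */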
static bool
is_idx_intrin(nir_intrinsic_instr *intrin)
{
   while (intrin->intrinsic == nir_intrinsic_vulkan_resource_reindex) {
      intrin = nir_src_as_intrinsic(intrin->src[0]);
      if (intrin == NULL)
         return false;
   }

   return intrin->intrinsic == nir_intrinsic_vulkan_resource_index;
}

static nir_def *
load_descriptor_for_idx_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
                               const struct lower_descriptors_ctx *ctx)
{
   nir_def *index = nir_imm_int(b, 0);

   while (intrin->intrinsic == nir_intrinsic_vulkan_resource_reindex) {
      index = nir_iadd(b, index, intrin->src[1].ssa);
      intrin = nir_src_as_intrinsic(intrin->src[0]);
   }

   assert(intrin->intrinsic == nir_intrinsic_vulkan_resource_index);
   uint32_t set = nir_intrinsic_desc_set(intrin);
   uint32_t binding = nir_intrinsic_binding(intrin);
   index = nir_iadd(b, index, intrin->src[0].ssa);

   return load_descriptor(b, 4, 32, set, binding, index, 0, ctx);
}

static bool
try_lower_load_vulkan_descriptor(nir_builder *b, nir_intrinsic_instr *intrin,
                                 const struct lower_descriptors_ctx *ctx)
{
   ASSERTED const VkDescriptorType desc_type = nir_intrinsic_desc_type(intrin);
   b->cursor = nir_before_instr(&intrin->instr);

   nir_intrinsic_instr *idx_intrin = nir_src_as_intrinsic(intrin->src[0]);
   if (idx_intrin == NULL || !is_idx_intrin(idx_intrin)) {
      assert(desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER ||
             desc_type == VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC);
      return false;
   }

   nir_def *desc = load_descriptor_for_idx_intrin(b, idx_intrin, ctx);

   nir_def_rewrite_uses(&intrin->def, desc);

   return true;
}

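/*
 * Replace a system value intrinsic with a load of the corresponding root
 * descriptor table member.
 */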
static bool
_lower_sysval_to_root_table(nir_builder *b, nir_intrinsic_instr *intrin,
                            uint32_t root_table_offset)
{
   b->cursor = nir_instr_remove(&intrin->instr);
   assert((root_table_offset & 3) == 0 && "aligned");

   nir_def *val = load_root(b, intrin->def.num_components, intrin->def.bit_size,
                            nir_imm_int(b, root_table_offset), 4);

   nir_def_rewrite_uses(&intrin->def, val);

   return true;
}

#define lower_sysval_to_root_table(b, intrin, member)                          \
   _lower_sysval_to_root_table(b, intrin, hk_root_descriptor_offset(member))

static bool
lower_load_push_constant(nir_builder *b, nir_intrinsic_instr *load,
                         const struct lower_descriptors_ctx *ctx)
{
   const uint32_t push_region_offset = hk_root_descriptor_offset(push);
   const uint32_t base = nir_intrinsic_base(load);

   b->cursor = nir_before_instr(&load->instr);

   nir_def *offset =
      nir_iadd_imm(b, load->src[0].ssa, push_region_offset + base);

   nir_def *val = load_root(b, load->def.num_components, load->def.bit_size,
                            offset, load->def.bit_size / 8);

   nir_def_rewrite_uses(&load->def, val);

   return true;
}

static void
get_resource_deref_binding(nir_builder *b, nir_deref_instr *deref,
                           uint32_t *set, uint32_t *binding, nir_def **index)
{
   if (deref->deref_type == nir_deref_type_array) {
      *index = deref->arr.index.ssa;
      deref = nir_deref_instr_parent(deref);
   } else {
      *index = nir_imm_int(b, 0);
   }

   assert(deref->deref_type == nir_deref_type_var);
   nir_variable *var = deref->var;

   *set = var->data.descriptor_set;
   *binding = var->data.binding;
}

static nir_def *
load_resource_deref_desc(nir_builder *b, unsigned num_components,
                         unsigned bit_size, nir_deref_instr *deref,
                         unsigned offset_B,
                         const struct lower_descriptors_ctx *ctx)
{
   uint32_t set, binding;
   nir_def *index;
   get_resource_deref_binding(b, deref, &set, &binding, &index);
   return load_descriptor(b, num_components, bit_size, set, binding, index,
                          offset_B, ctx);
}

/*
 * Returns an AGX bindless handle to access an indexed image within the global
 * image heap.
 */
static nir_def *
image_heap_handle(nir_builder *b, nir_def *offset)
{
   return nir_vec2(b, nir_imm_int(b, HK_IMAGE_HEAP_UNIFORM), offset);
}

static bool
lower_image_intrin(nir_builder *b, nir_intrinsic_instr *intr,
                   const struct lower_descriptors_ctx *ctx)
{
   b->cursor = nir_before_instr(&intr->instr);
   nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);

   /* Reads and queries use the texture descriptor; writes and atomics use the
    * PBE descriptor.
    */
   unsigned offs;
   if (intr->intrinsic != nir_intrinsic_image_deref_load &&
       intr->intrinsic != nir_intrinsic_image_deref_size &&
       intr->intrinsic != nir_intrinsic_image_deref_samples) {

      offs = offsetof(struct hk_storage_image_descriptor, pbe_offset);
   } else {
      offs = offsetof(struct hk_storage_image_descriptor, tex_offset);
   }

   nir_def *offset = load_resource_deref_desc(b, 1, 32, deref, offs, ctx);
   nir_rewrite_image_intrinsic(intr, image_heap_handle(b, offset), true);

   return true;
}

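/* Map a Gallium pipeline statistic to the corresponding Vulkan flag bit. */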
static VkQueryPipelineStatisticFlagBits
translate_pipeline_stat_bit(enum pipe_statistics_query_index pipe)
{
   switch (pipe) {
   case PIPE_STAT_QUERY_IA_VERTICES:
      return VK_QUERY_PIPELINE_STATISTIC_INPUT_ASSEMBLY_VERTICES_BIT;
   case PIPE_STAT_QUERY_IA_PRIMITIVES:
      return VK_QUERY_PIPELINE_STATISTIC_INPUT_ASSEMBLY_PRIMITIVES_BIT;
   case PIPE_STAT_QUERY_VS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_VERTEX_SHADER_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_GS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_GEOMETRY_SHADER_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_GS_PRIMITIVES:
      return VK_QUERY_PIPELINE_STATISTIC_GEOMETRY_SHADER_PRIMITIVES_BIT;
   case PIPE_STAT_QUERY_C_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_CLIPPING_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_C_PRIMITIVES:
      return VK_QUERY_PIPELINE_STATISTIC_CLIPPING_PRIMITIVES_BIT;
   case PIPE_STAT_QUERY_PS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_FRAGMENT_SHADER_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_HS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_TESSELLATION_CONTROL_SHADER_PATCHES_BIT;
   case PIPE_STAT_QUERY_DS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_TESSELLATION_EVALUATION_SHADER_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_CS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_COMPUTE_SHADER_INVOCATIONS_BIT;
   case PIPE_STAT_QUERY_TS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_TASK_SHADER_INVOCATIONS_BIT_EXT;
   case PIPE_STAT_QUERY_MS_INVOCATIONS:
      return VK_QUERY_PIPELINE_STATISTIC_MESH_SHADER_INVOCATIONS_BIT_EXT;
   }

   unreachable("invalid statistic");
}

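/*
 * Lower draw-time system values (UVS indices, sample masks, draw parameters,
 * pipeline statistics addresses, ...) to loads from the root descriptor table
 * or from the preamble uniforms starting at vs_uniform_base.
 */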
static bool
lower_uvs_index(nir_builder *b, nir_intrinsic_instr *intrin, void *data)
{
   unsigned *vs_uniform_base = data;

   switch (intrin->intrinsic) {
   case nir_intrinsic_load_uvs_index_agx: {
      gl_varying_slot slot = nir_intrinsic_io_semantics(intrin).location;
      unsigned offset = hk_root_descriptor_offset(draw.uvs_index[slot]);
      b->cursor = nir_instr_remove(&intrin->instr);

      nir_def *val = load_root(b, 1, 8, nir_imm_int(b, offset), 1);
      nir_def_rewrite_uses(&intrin->def, nir_u2u16(b, val));
      return true;
   }

   case nir_intrinsic_load_shader_part_tests_zs_agx:
      return lower_sysval_to_root_table(b, intrin, draw.no_epilog_discard);

   case nir_intrinsic_load_api_sample_mask_agx:
      return lower_sysval_to_root_table(b, intrin, draw.api_sample_mask);

   case nir_intrinsic_load_sample_positions_agx:
      return lower_sysval_to_root_table(b, intrin, draw.ppp_multisamplectl);

   case nir_intrinsic_load_depth_never_agx:
      return lower_sysval_to_root_table(b, intrin, draw.force_never_in_shader);

   case nir_intrinsic_load_geometry_param_buffer_agx:
      return lower_sysval_to_root_table(b, intrin, draw.geometry_params);

   case nir_intrinsic_load_vs_output_buffer_agx:
      return lower_sysval_to_root_table(b, intrin, draw.vertex_output_buffer);

   case nir_intrinsic_load_vs_outputs_agx:
      return lower_sysval_to_root_table(b, intrin, draw.vertex_outputs);

   case nir_intrinsic_load_tess_param_buffer_agx:
      return lower_sysval_to_root_table(b, intrin, draw.tess_params);

   case nir_intrinsic_load_is_first_fan_agx: {
      unsigned offset = hk_root_descriptor_offset(draw.provoking);
      b->cursor = nir_instr_remove(&intrin->instr);
      nir_def *val = load_root(b, 1, 16, nir_imm_int(b, offset), 2);
      nir_def_rewrite_uses(&intrin->def, nir_ieq_imm(b, val, 1));
      return true;
   }

   case nir_intrinsic_load_provoking_last: {
      unsigned offset = hk_root_descriptor_offset(draw.provoking);
      b->cursor = nir_instr_remove(&intrin->instr);
      nir_def *val = load_root(b, 1, 16, nir_imm_int(b, offset), 2);
      nir_def_rewrite_uses(&intrin->def, nir_b2b32(b, nir_ieq_imm(b, val, 2)));
      return true;
   }

   case nir_intrinsic_load_base_vertex:
   case nir_intrinsic_load_first_vertex:
   case nir_intrinsic_load_base_instance:
   case nir_intrinsic_load_draw_id:
   case nir_intrinsic_load_input_assembly_buffer_agx: {
      b->cursor = nir_instr_remove(&intrin->instr);

      unsigned base = *vs_uniform_base;
      unsigned size = 32;

      if (intrin->intrinsic == nir_intrinsic_load_base_instance) {
         base += 2;
      } else if (intrin->intrinsic == nir_intrinsic_load_draw_id) {
         base += 4;
         size = 16;
      } else if (intrin->intrinsic ==
                 nir_intrinsic_load_input_assembly_buffer_agx) {
         base += 8;
         size = 64;
      }

      nir_def *val = nir_load_preamble(b, 1, size, .base = base);
      nir_def_rewrite_uses(&intrin->def,
                           nir_u2uN(b, val, intrin->def.bit_size));
      return true;
   }

   case nir_intrinsic_load_stat_query_address_agx: {
      b->cursor = nir_instr_remove(&intrin->instr);

      unsigned off1 = hk_root_descriptor_offset(draw.pipeline_stats);
      unsigned off2 = hk_root_descriptor_offset(draw.pipeline_stats_flags);

      nir_def *base = load_root(b, 1, 64, nir_imm_int(b, off1), 8);
      nir_def *flags = load_root(b, 1, 16, nir_imm_int(b, off2), 2);

      unsigned query = nir_intrinsic_base(intrin);
      VkQueryPipelineStatisticFlagBits bit = translate_pipeline_stat_bit(query);

      /* Prefix sum to find the compacted offset */
      nir_def *idx = nir_bit_count(b, nir_iand_imm(b, flags, bit - 1));
      nir_def *addr = nir_iadd(
         b, base, nir_imul_imm(b, nir_u2u64(b, idx), sizeof(uint64_t)));

      /* The above returns garbage if the query isn't actually enabled, handle
       * that case.
       *
       * TODO: Optimize case where we *know* the query is present?
       */
      nir_def *present = nir_ine_imm(b, nir_iand_imm(b, flags, bit), 0);
      addr = nir_bcsel(b, present, addr, nir_imm_int64(b, 0));

      nir_def_rewrite_uses(&intrin->def, addr);
      return true;
   }

   default:
      return false;
   }
}

bool
hk_lower_uvs_index(nir_shader *s, unsigned vs_uniform_base)
{
   return nir_shader_intrinsics_pass(
      s, lower_uvs_index, nir_metadata_control_flow, &vs_uniform_base);
}

static bool
try_lower_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
                 const struct lower_descriptors_ctx *ctx)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_load_constant:
      return lower_load_constant(b, intrin, ctx);

   case nir_intrinsic_load_vulkan_descriptor:
      return try_lower_load_vulkan_descriptor(b, intrin, ctx);

   case nir_intrinsic_load_workgroup_size:
      unreachable("Should have been lowered by nir_lower_cs_intrinsics()");

   case nir_intrinsic_load_base_workgroup_id:
      return lower_sysval_to_root_table(b, intrin, cs.base_group);

   case nir_intrinsic_load_push_constant:
      return lower_load_push_constant(b, intrin, ctx);

   case nir_intrinsic_load_view_index:
      return lower_sysval_to_root_table(b, intrin, draw.view_index);

   case nir_intrinsic_image_deref_load:
   case nir_intrinsic_image_deref_sparse_load:
   case nir_intrinsic_image_deref_store:
   case nir_intrinsic_image_deref_atomic:
   case nir_intrinsic_image_deref_atomic_swap:
   case nir_intrinsic_image_deref_size:
   case nir_intrinsic_image_deref_samples:
   case nir_intrinsic_image_deref_store_block_agx:
      return lower_image_intrin(b, intrin, ctx);

   case nir_intrinsic_load_num_workgroups: {
      b->cursor = nir_instr_remove(&intrin->instr);

      unsigned offset = hk_root_descriptor_offset(cs.group_count_addr);
      nir_def *ptr = load_root(b, 1, 64, nir_imm_int(b, offset), 4);
      nir_def *val = load_speculatable(b, 3, 32, ptr, 4);

      nir_def_rewrite_uses(&intrin->def, val);
      return true;
   }

   default:
      return false;
   }
}

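/*
 * Lower texture instructions: texture/sampler derefs become heap handles and
 * sampler indices read from the hk_sampled_image_descriptor, and the AGX LOD
 * bias and custom border color queries are answered from the descriptor too.
 */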
static bool
lower_tex(nir_builder *b, nir_tex_instr *tex,
          const struct lower_descriptors_ctx *ctx)
{
   b->cursor = nir_before_instr(&tex->instr);

   nir_def *texture = nir_steal_tex_src(tex, nir_tex_src_texture_deref);
   nir_def *sampler = nir_steal_tex_src(tex, nir_tex_src_sampler_deref);
   if (!texture) {
      assert(!sampler);
      return false;
   }

   nir_def *plane_ssa = nir_steal_tex_src(tex, nir_tex_src_plane);
   const uint32_t plane =
      plane_ssa ? nir_src_as_uint(nir_src_for_ssa(plane_ssa)) : 0;
   const uint64_t plane_offset_B =
      plane * sizeof(struct hk_sampled_image_descriptor);

   /* LOD bias is passed in the descriptor set, rather than embedded into
    * the sampler descriptor. There's no spot in the hardware descriptor,
    * plus this saves on precious sampler heap spots.
    */
   if (tex->op == nir_texop_lod_bias_agx) {
      unsigned offs =
         offsetof(struct hk_sampled_image_descriptor, lod_bias_fp16);

      nir_def *bias = load_resource_deref_desc(
         b, 1, 16, nir_src_as_deref(nir_src_for_ssa(sampler)),
         plane_offset_B + offs, ctx);

      nir_def_replace(&tex->def, bias);
      return true;
   }

   if (tex->op == nir_texop_has_custom_border_color_agx) {
      unsigned offs = offsetof(struct hk_sampled_image_descriptor, has_border);

      nir_def *res = load_resource_deref_desc(
         b, 1, 16, nir_src_as_deref(nir_src_for_ssa(sampler)),
         plane_offset_B + offs, ctx);

      nir_def_replace(&tex->def, nir_ine_imm(b, res, 0));
      return true;
   }

   if (tex->op == nir_texop_custom_border_color_agx) {
      unsigned offs = offsetof(struct hk_sampled_image_descriptor, border);

      nir_def *border = load_resource_deref_desc(
         b, 4, 32, nir_src_as_deref(nir_src_for_ssa(sampler)),
         plane_offset_B + offs, ctx);

      nir_alu_type T = nir_alu_type_get_base_type(tex->dest_type);
      border = nir_convert_to_bit_size(b, border, T, tex->def.bit_size);

      nir_def_replace(&tex->def, border);
      return true;
   }

   {
      unsigned offs =
         offsetof(struct hk_sampled_image_descriptor, image_offset);

      nir_def *offset = load_resource_deref_desc(
         b, 1, 32, nir_src_as_deref(nir_src_for_ssa(texture)),
         plane_offset_B + offs, ctx);

      nir_def *handle = image_heap_handle(b, offset);
      nir_tex_instr_add_src(tex, nir_tex_src_texture_handle, handle);
   }

   if (sampler != NULL) {
      unsigned offs =
         offsetof(struct hk_sampled_image_descriptor, sampler_index);

      if (tex->backend_flags & AGX_TEXTURE_FLAG_CLAMP_TO_0) {
         offs =
            offsetof(struct hk_sampled_image_descriptor, clamp_0_sampler_index);
      }

      nir_def *index = load_resource_deref_desc(
         b, 1, 16, nir_src_as_deref(nir_src_for_ssa(sampler)),
         plane_offset_B + offs, ctx);

      nir_tex_instr_add_src(tex, nir_tex_src_sampler_handle, index);
   }

   return true;
}

static bool
try_lower_descriptors_instr(nir_builder *b, nir_instr *instr, void *_data)
{
   const struct lower_descriptors_ctx *ctx = _data;

   switch (instr->type) {
   case nir_instr_type_tex:
      return lower_tex(b, nir_instr_as_tex(instr), ctx);
   case nir_instr_type_intrinsic:
      return try_lower_intrin(b, nir_instr_as_intrinsic(instr), ctx);
   default:
      return false;
   }
}

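/*
 * Lower vulkan_resource_index for SSBOs to a (base, size, offset) vec4,
 * tucking the binding stride into the top 8 bits of the 64-bit binding
 * address for the reindex path below.
 */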
static bool
lower_ssbo_resource_index(nir_builder *b, nir_intrinsic_instr *intrin,
                          const struct lower_descriptors_ctx *ctx)
{
   const VkDescriptorType desc_type = nir_intrinsic_desc_type(intrin);
   if (desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER &&
       desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC)
      return false;

   b->cursor = nir_instr_remove(&intrin->instr);

   uint32_t set = nir_intrinsic_desc_set(intrin);
   uint32_t binding = nir_intrinsic_binding(intrin);
   nir_def *index = intrin->src[0].ssa;

   const struct hk_descriptor_set_binding_layout *binding_layout =
      get_binding_layout(set, binding, ctx);

   nir_def *binding_addr;
   uint8_t binding_stride;
   switch (binding_layout->type) {
   case VK_DESCRIPTOR_TYPE_MUTABLE_EXT:
   case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER: {
      nir_def *set_addr = load_descriptor_set_addr(b, set, ctx);
      binding_addr = nir_iadd_imm(b, set_addr, binding_layout->offset);
      binding_stride = binding_layout->stride;
      break;
   }

   case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC: {
      const uint32_t root_desc_addr_offset =
         hk_root_descriptor_offset(root_desc_addr);

      nir_def *root_desc_addr =
         load_root(b, 1, 64, nir_imm_int(b, root_desc_addr_offset), 8);

      nir_def *dynamic_buffer_start =
         nir_iadd_imm(b, load_dynamic_buffer_start(b, set, ctx),
                      binding_layout->dynamic_buffer_index);

      nir_def *dynamic_binding_offset =
         nir_iadd_imm(b,
                      nir_imul_imm(b, dynamic_buffer_start,
                                   sizeof(struct hk_buffer_address)),
                      hk_root_descriptor_offset(dynamic_buffers));

      binding_addr =
         nir_iadd(b, root_desc_addr, nir_u2u64(b, dynamic_binding_offset));
      binding_stride = sizeof(struct hk_buffer_address);
      break;
   }

   default:
      unreachable("Not an SSBO descriptor");
   }

   /* Tuck the stride in the top 8 bits of the binding address */
   binding_addr = nir_ior_imm(b, binding_addr, (uint64_t)binding_stride << 56);

   const uint32_t binding_size = binding_layout->array_size * binding_stride;
   nir_def *offset_in_binding = nir_imul_imm(b, index, binding_stride);

   nir_def *addr = nir_vec4(b, nir_unpack_64_2x32_split_x(b, binding_addr),
                            nir_unpack_64_2x32_split_y(b, binding_addr),
                            nir_imm_int(b, binding_size), offset_in_binding);

   nir_def_rewrite_uses(&intrin->def, addr);

   return true;
}

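/* Reindex by scaling with the stride recovered from the binding address. */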
static bool
lower_ssbo_resource_reindex(nir_builder *b, nir_intrinsic_instr *intrin,
                            const struct lower_descriptors_ctx *ctx)
{
   const VkDescriptorType desc_type = nir_intrinsic_desc_type(intrin);
   if (desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER &&
       desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC)
      return false;

   b->cursor = nir_instr_remove(&intrin->instr);

   nir_def *addr = intrin->src[0].ssa;
   nir_def *index = intrin->src[1].ssa;

   nir_def *addr_high32 = nir_channel(b, addr, 1);
   nir_def *stride = nir_ushr_imm(b, addr_high32, 24);
   nir_def *offset = nir_imul(b, index, stride);

   addr = nir_build_addr_iadd(b, addr, ctx->ssbo_addr_format, nir_var_mem_ssbo,
                              offset);
   nir_def_rewrite_uses(&intrin->def, addr);

   return true;
}

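/*
 * Load the hk_buffer_address descriptor for an SSBO, masking off the stride
 * packed in the top byte of the base address.
 */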
static bool
lower_load_ssbo_descriptor(nir_builder *b, nir_intrinsic_instr *intrin,
                           const struct lower_descriptors_ctx *ctx)
{
   const VkDescriptorType desc_type = nir_intrinsic_desc_type(intrin);
   if (desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER &&
       desc_type != VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC)
      return false;

   b->cursor = nir_instr_remove(&intrin->instr);

   nir_def *addr = intrin->src[0].ssa;

   nir_def *desc;
   switch (ctx->ssbo_addr_format) {
   case nir_address_format_64bit_global_32bit_offset: {
      nir_def *base = nir_pack_64_2x32(b, nir_trim_vector(b, addr, 2));
      nir_def *offset = nir_channel(b, addr, 3);
      /* Mask off the binding stride */
      base = nir_iand_imm(b, base, BITFIELD64_MASK(56));
      desc = nir_load_global_constant_offset(b, 4, 32, base, offset,
                                             .align_mul = 16, .align_offset = 0,
                                             .access = ACCESS_CAN_SPECULATE);
      break;
   }

   case nir_address_format_64bit_bounded_global: {
      nir_def *base = nir_pack_64_2x32(b, nir_trim_vector(b, addr, 2));
      nir_def *size = nir_channel(b, addr, 2);
      nir_def *offset = nir_channel(b, addr, 3);
      /* Mask off the binding stride */
      base = nir_iand_imm(b, base, BITFIELD64_MASK(56));
      desc = nir_load_global_constant_bounded(
         b, 4, 32, base, offset, size, .align_mul = 16, .align_offset = 0,
         .access = ACCESS_CAN_SPECULATE);
      break;
   }

   default:
      unreachable("Unknown address mode");
   }

   nir_def_rewrite_uses(&intrin->def, desc);

   return true;
}

static bool
lower_ssbo_descriptor(nir_builder *b, nir_intrinsic_instr *intr, void *_data)
{
   const struct lower_descriptors_ctx *ctx = _data;

   switch (intr->intrinsic) {
   case nir_intrinsic_vulkan_resource_index:
      return lower_ssbo_resource_index(b, intr, ctx);
   case nir_intrinsic_vulkan_resource_reindex:
      return lower_ssbo_resource_reindex(b, intr, ctx);
   case nir_intrinsic_load_vulkan_descriptor:
      return lower_load_ssbo_descriptor(b, intr, ctx);
   default:
      return false;
   }
}

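/*
 * Entry point: lower all Vulkan descriptor access in the shader according to
 * the bound set layouts and the pipeline robustness state.
 */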
bool
hk_nir_lower_descriptors(nir_shader *nir,
                         const struct vk_pipeline_robustness_state *rs,
                         uint32_t set_layout_count,
                         struct vk_descriptor_set_layout *const *set_layouts)
{
   struct lower_descriptors_ctx ctx = {
      .clamp_desc_array_bounds =
         rs->storage_buffers !=
            VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT ||

         rs->uniform_buffers !=
            VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT ||

         rs->images != VK_PIPELINE_ROBUSTNESS_IMAGE_BEHAVIOR_DISABLED_EXT,

      .ssbo_addr_format = hk_buffer_addr_format(rs->storage_buffers),
      .ubo_addr_format = hk_buffer_addr_format(rs->uniform_buffers),
   };

   assert(set_layout_count <= HK_MAX_SETS);
   for (uint32_t s = 0; s < set_layout_count; s++) {
      if (set_layouts[s] != NULL)
         ctx.set_layouts[s] = vk_to_hk_descriptor_set_layout(set_layouts[s]);
   }

   /* First lower everything but complex SSBOs, then lower complex SSBOs.
    *
    * TODO: See if we can unify this, not sure if the fast path matters on
    * Apple. This is inherited from NVK.
    */
   bool pass_lower_descriptors = nir_shader_instructions_pass(
      nir, try_lower_descriptors_instr, nir_metadata_control_flow, &ctx);

   bool pass_lower_ssbo = nir_shader_intrinsics_pass(
      nir, lower_ssbo_descriptor, nir_metadata_control_flow, &ctx);

   return pass_lower_descriptors || pass_lower_ssbo;
}