xref: /aosp_15_r20/external/mesa3d/src/asahi/vulkan/hk_shader.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2024 Valve Corporation
3  * Copyright 2024 Alyssa Rosenzweig
4  * Copyright 2022-2023 Collabora Ltd. and Red Hat Inc.
5  * SPDX-License-Identifier: MIT
6  */
7 #include "hk_shader.h"
8 
9 #include "agx_helpers.h"
10 #include "agx_nir_lower_gs.h"
11 #include "glsl_types.h"
12 #include "nir.h"
13 #include "nir_builder.h"
14 
15 #include "agx_bo.h"
16 #include "hk_cmd_buffer.h"
17 #include "hk_descriptor_set_layout.h"
18 #include "hk_device.h"
19 #include "hk_physical_device.h"
20 #include "hk_sampler.h"
21 #include "hk_shader.h"
22 
23 #include "nir_builder_opcodes.h"
24 #include "nir_builtin_builder.h"
25 #include "nir_intrinsics.h"
26 #include "nir_intrinsics_indices.h"
27 #include "nir_xfb_info.h"
28 #include "shader_enums.h"
29 #include "vk_nir_convert_ycbcr.h"
30 #include "vk_pipeline.h"
31 #include "vk_pipeline_layout.h"
32 #include "vk_shader_module.h"
33 #include "vk_ycbcr_conversion.h"
34 
35 #include "asahi/compiler/agx_compile.h"
36 #include "asahi/lib/agx_linker.h"
37 #include "asahi/lib/agx_nir_passes.h"
38 #include "asahi/lib/agx_tilebuffer.h"
39 #include "asahi/lib/agx_uvs.h"
40 #include "compiler/spirv/nir_spirv.h"
41 
42 #include "util/blob.h"
43 #include "util/hash_table.h"
44 #include "util/macros.h"
45 #include "util/mesa-sha1.h"
46 #include "util/simple_mtx.h"
47 #include "util/u_debug.h"
48 #include "vulkan/vulkan_core.h"
49 
50 struct hk_fs_key {
51    bool zs_self_dep;
52 
53    /** True if sample shading is forced on via an API knob such as
54     * VkPipelineMultisampleStateCreateInfo::minSampleShading
55     */
56    bool force_sample_shading;
57 
58    uint8_t pad[2];
59 };
60 static_assert(sizeof(struct hk_fs_key) == 4, "packed");
61 
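/* Size/align callback for nir_lower_vars_to_explicit_types on shared memory:
 * vectors and scalars are packed with component alignment, and booleans take
 * 32 bits per component.
 */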
62 static void
63 shared_var_info(const struct glsl_type *type, unsigned *size, unsigned *align)
64 {
65    assert(glsl_type_is_vector_or_scalar(type));
66 
67    uint32_t comp_size =
68       glsl_type_is_boolean(type) ? 4 : glsl_get_bit_size(type) / 8;
69    unsigned length = glsl_get_vector_elements(type);
70    *size = comp_size * length, *align = comp_size;
71 }
72 
73 uint64_t
74 hk_physical_device_compiler_flags(const struct hk_physical_device *pdev)
75 {
76    /* TODO compiler flags */
77    return 0;
78 }
79 
80 const nir_shader_compiler_options *
81 hk_get_nir_options(struct vk_physical_device *vk_pdev, gl_shader_stage stage,
82                    UNUSED const struct vk_pipeline_robustness_state *rs)
83 {
84    return &agx_nir_options;
85 }
86 
87 static struct spirv_to_nir_options
88 hk_get_spirv_options(struct vk_physical_device *vk_pdev,
89                      UNUSED gl_shader_stage stage,
90                      const struct vk_pipeline_robustness_state *rs)
91 {
92    return (struct spirv_to_nir_options){
93       .ssbo_addr_format = hk_buffer_addr_format(rs->storage_buffers),
94       .phys_ssbo_addr_format = nir_address_format_64bit_global,
95       .ubo_addr_format = hk_buffer_addr_format(rs->uniform_buffers),
96       .shared_addr_format = nir_address_format_32bit_offset,
97       .min_ssbo_alignment = HK_MIN_SSBO_ALIGNMENT,
98       .min_ubo_alignment = HK_MIN_UBO_ALIGNMENT,
99    };
100 }
101 
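/* Rewrite nir_jump_halt in the fragment entrypoint to nir_jump_return so the
 * subsequent nir_lower_returns pass can structure it away.
 */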
102 static bool
103 lower_halt_to_return(nir_builder *b, nir_instr *instr, UNUSED void *_data)
104 {
105    if (instr->type != nir_instr_type_jump)
106       return false;
107 
108    nir_jump_instr *jump = nir_instr_as_jump(instr);
109    if (jump->type != nir_jump_halt)
110       return false;
111 
112    assert(b->impl == nir_shader_get_entrypoint(b->shader));
113    jump->type = nir_jump_return;
114    return true;
115 }
116 
117 void
118 hk_preprocess_nir_internal(struct vk_physical_device *vk_pdev, nir_shader *nir)
119 {
120    /* Must lower terminate/halt before nir_lower_io_to_temporaries */
121    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
122       NIR_PASS(_, nir, nir_lower_terminate_to_demote);
123       NIR_PASS(_, nir, nir_shader_instructions_pass, lower_halt_to_return,
124                nir_metadata_all, NULL);
125       NIR_PASS(_, nir, nir_lower_returns);
126    }
127 
128    /* Unroll loops before lowering indirects via nir_lower_io_to_temporaries */
129    UNUSED bool progress = false;
130    NIR_PASS(_, nir, nir_lower_global_vars_to_local);
131 
132    do {
133       progress = false;
134       NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
135       NIR_PASS(progress, nir, nir_copy_prop);
136       NIR_PASS(progress, nir, nir_opt_dce);
137       NIR_PASS(progress, nir, nir_opt_constant_folding);
138       NIR_PASS(progress, nir, nir_opt_loop);
139       NIR_PASS(progress, nir, nir_opt_loop_unroll);
140    } while (progress);
141 
142    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
143       struct nir_lower_sysvals_to_varyings_options sysvals_opts = {
144          .point_coord = true,
145       };
146 
147       nir_lower_sysvals_to_varyings(nir, &sysvals_opts);
148    }
149 
150    NIR_PASS(_, nir, nir_lower_system_values);
151 
152    /* Gather info before preprocess_nir but after some general lowering, so
153     * inputs_read and system_values_read are accurately set.
154     */
155    nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
156 
157    NIR_PASS_V(nir, nir_lower_io_to_temporaries, nir_shader_get_entrypoint(nir),
158               true, false);
159 
160    NIR_PASS(_, nir, nir_lower_global_vars_to_local);
161 
162    NIR_PASS(_, nir, nir_split_var_copies);
163    NIR_PASS(_, nir, nir_split_struct_vars, nir_var_function_temp);
164 
165    /* Optimize but allow copies because we haven't lowered them yet */
166    agx_preprocess_nir(nir, NULL);
167 
168    NIR_PASS(_, nir, nir_lower_load_const_to_scalar);
169    NIR_PASS(_, nir, nir_lower_var_copies);
170 }
171 
172 static void
173 hk_preprocess_nir(struct vk_physical_device *vk_pdev, nir_shader *nir)
174 {
175    hk_preprocess_nir_internal(vk_pdev, nir);
176    nir_lower_compute_system_values_options csv_options = {
177       .has_base_workgroup_id = true,
178    };
179    NIR_PASS(_, nir, nir_lower_compute_system_values, &csv_options);
180 }
181 
182 static void
183 hk_populate_fs_key(struct hk_fs_key *key,
184                    const struct vk_graphics_pipeline_state *state)
185 {
186    memset(key, 0, sizeof(*key));
187 
188    if (state == NULL)
189       return;
190 
191    if (state->pipeline_flags &
192        VK_PIPELINE_CREATE_2_DEPTH_STENCIL_ATTACHMENT_FEEDBACK_LOOP_BIT_EXT)
193       key->zs_self_dep = true;
194 
195    /* We force per-sample interpolation whenever sampleShadingEnable is set
196     * regardless of minSampleShading or rasterizationSamples.
197     *
198     * When sampleShadingEnable is set, few guarantees are made about the
199     * location of interpolation of the inputs.  The only real guarantees are
200     * that the inputs are interpolated within the pixel and that you get at
201     * least `rasterizationSamples * minSampleShading` unique positions.
202     * Importantly, it does not require that when `rasterizationSamples *
203     * minSampleShading <= 1.0` that those positions are at the fragment
204     * center.  Therefore, it's valid to just always do per-sample all the time.
205     *
206     * The one caveat here is that we have to be careful about gl_SampleMaskIn.
207     * When `hk_fs_key::force_sample_shading = true` we also turn any reads of
208     * gl_SampleMaskIn into `1 << gl_SampleID` because the hardware sample mask
209     * is actually per-fragment, not per-pass.  We handle this by smashing
210     * minSampleShading to 1.0 whenever gl_SampleMaskIn is read.
211     */
212    const struct vk_multisample_state *ms = state->ms;
213    if (ms != NULL && ms->sample_shading_enable)
214       key->force_sample_shading = true;
215 }
216 
217 static void
218 hk_hash_graphics_state(struct vk_physical_device *device,
219                        const struct vk_graphics_pipeline_state *state,
220                        VkShaderStageFlags stages, blake3_hash blake3_out)
221 {
222    struct mesa_blake3 blake3_ctx;
223    _mesa_blake3_init(&blake3_ctx);
224    if (stages & VK_SHADER_STAGE_FRAGMENT_BIT) {
225       struct hk_fs_key key;
226       hk_populate_fs_key(&key, state);
227       _mesa_blake3_update(&blake3_ctx, &key, sizeof(key));
228 
229       const bool is_multiview = state->rp->view_mask != 0;
230       _mesa_blake3_update(&blake3_ctx, &is_multiview, sizeof(is_multiview));
231    }
232    _mesa_blake3_final(&blake3_ctx, blake3_out);
233 }
234 
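/* Lower load_global_constant_offset/_bounded intrinsics to plain global
 * constant loads at base + offset. The bounded form also gets a bounds check;
 * see the soft-fault comment below for how out-of-bounds reads return zero.
 */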
235 static bool
236 lower_load_global_constant_offset_instr(nir_builder *b,
237                                         nir_intrinsic_instr *intrin, void *data)
238 {
239    if (intrin->intrinsic != nir_intrinsic_load_global_constant_offset &&
240        intrin->intrinsic != nir_intrinsic_load_global_constant_bounded)
241       return false;
242 
243    b->cursor = nir_before_instr(&intrin->instr);
244    bool *has_soft_fault = data;
245 
246    nir_def *base_addr = intrin->src[0].ssa;
247    nir_def *offset = intrin->src[1].ssa;
248 
249    nir_def *zero = NULL;
250    nir_def *in_bounds = NULL;
251    if (intrin->intrinsic == nir_intrinsic_load_global_constant_bounded) {
252       nir_def *bound = intrin->src[2].ssa;
253 
254       unsigned bit_size = intrin->def.bit_size;
255       assert(bit_size >= 8 && bit_size % 8 == 0);
256       unsigned byte_size = bit_size / 8;
257 
258       zero = nir_imm_zero(b, intrin->num_components, bit_size);
259 
260       unsigned load_size = byte_size * intrin->num_components;
261 
262       nir_def *sat_offset =
263          nir_umin(b, offset, nir_imm_int(b, UINT32_MAX - (load_size - 1)));
264       in_bounds = nir_ilt(b, nir_iadd_imm(b, sat_offset, load_size - 1), bound);
265 
266       /* If we do not have soft fault, we branch on the bounds check. This is
267        * slow; fortunately, we always have soft fault for release drivers.
268        *
269        * With soft fault, we speculatively load and smash to zero at the end.
270        */
271       if (!(*has_soft_fault))
272          nir_push_if(b, in_bounds);
273    }
274 
275    nir_def *val = nir_build_load_global_constant(
276       b, intrin->def.num_components, intrin->def.bit_size,
277       nir_iadd(b, base_addr, nir_u2u64(b, offset)),
278       .align_mul = nir_intrinsic_align_mul(intrin),
279       .align_offset = nir_intrinsic_align_offset(intrin),
280       .access = nir_intrinsic_access(intrin));
281 
282    if (intrin->intrinsic == nir_intrinsic_load_global_constant_bounded) {
283       if (*has_soft_fault) {
284          val = nir_bcsel(b, in_bounds, val, zero);
285       } else {
286          nir_pop_if(b, NULL);
287          val = nir_if_phi(b, val, zero);
288       }
289    }
290 
291    nir_def_replace(&intrin->def, val);
292    return true;
293 }
294 
295 struct lower_ycbcr_state {
296    uint32_t set_layout_count;
297    struct vk_descriptor_set_layout *const *set_layouts;
298 };
299 
300 static const struct vk_ycbcr_conversion_state *
301 lookup_ycbcr_conversion(const void *_state, uint32_t set, uint32_t binding,
302                         uint32_t array_index)
303 {
304    const struct lower_ycbcr_state *state = _state;
305    assert(set < state->set_layout_count);
306    assert(state->set_layouts[set] != NULL);
307    const struct hk_descriptor_set_layout *set_layout =
308       vk_to_hk_descriptor_set_layout(state->set_layouts[set]);
309    assert(binding < set_layout->binding_count);
310 
311    const struct hk_descriptor_set_binding_layout *bind_layout =
312       &set_layout->binding[binding];
313 
314    if (bind_layout->immutable_samplers == NULL)
315       return NULL;
316 
317    array_index = MIN2(array_index, bind_layout->array_size - 1);
318 
319    const struct hk_sampler *sampler =
320       bind_layout->immutable_samplers[array_index];
321 
322    return sampler && sampler->vk.ycbcr_conversion
323              ? &sampler->vk.ycbcr_conversion->state
324              : NULL;
325 }
326 
327 static inline bool
328 nir_has_image_var(nir_shader *nir)
329 {
330    nir_foreach_image_variable(_, nir)
331       return true;
332 
333    return false;
334 }
335 
336 static int
337 glsl_type_size(const struct glsl_type *type, bool bindless)
338 {
339    return glsl_count_attribute_slots(type, false);
340 }
341 
342 /*
343  * This is the world's worst multiview implementation. We simply duplicate each
344  * draw on the CPU side, changing a uniform in between, and then plumb the view
345  * index into the layer ID here. Whatever, it works.
346  *
347  * The "proper" implementation on AGX would use vertex amplification, but a
348  * MacBook is not a VR headset.
349  */
350 static void
351 hk_lower_multiview(nir_shader *nir)
352 {
353    /* If there's an existing layer ID write, ignore it. This avoids validation
354     * splat with vk_meta.
355     */
356    nir_variable *existing = nir_find_variable_with_location(
357       nir, nir_var_shader_out, VARYING_SLOT_LAYER);
358 
359    if (existing) {
360       existing->data.mode = nir_var_shader_temp;
361       existing->data.location = 0;
362       nir_fixup_deref_modes(nir);
363    }
364 
365    /* Now write the view index as the layer */
366    nir_builder b =
367       nir_builder_at(nir_after_impl(nir_shader_get_entrypoint(nir)));
368 
369    nir_variable *layer =
370       nir_variable_create(nir, nir_var_shader_out, glsl_uint_type(), NULL);
371 
372    layer->data.location = VARYING_SLOT_LAYER;
373 
374    nir_store_var(&b, layer, nir_load_view_index(&b), nir_component_mask(1));
375    b.shader->info.outputs_written |= VARYING_BIT_LAYER;
376 }
377 
378 /*
379  * KHR_maintenance5 requires that points rasterize with a default point size of
380  * 1.0, while our hardware requires an explicit point size write for this.
381  * Since topology may be dynamic, we insert an unconditional write if necessary.
382  */
383 static bool
384 hk_nir_insert_psiz_write(nir_shader *nir)
385 {
386    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
387 
388    if (nir->info.outputs_written & VARYING_BIT_PSIZ) {
389       nir_metadata_preserve(impl, nir_metadata_all);
390       return false;
391    }
392 
393    nir_builder b = nir_builder_at(nir_after_impl(impl));
394 
395    nir_store_output(&b, nir_imm_float(&b, 1.0), nir_imm_int(&b, 0),
396                     .write_mask = nir_component_mask(1),
397                     .io_semantics.location = VARYING_SLOT_PSIZ,
398                     .io_semantics.num_slots = 1, .src_type = nir_type_float32);
399 
400    nir->info.outputs_written |= VARYING_BIT_PSIZ;
401    nir_metadata_preserve(b.impl, nir_metadata_control_flow);
402    return true;
403 }
404 
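/* Helpers for the custom border colour lowering below: query the sampler's
 * custom border colour and whether one is in use at all.
 */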
405 static nir_def *
406 query_custom_border(nir_builder *b, nir_tex_instr *tex)
407 {
408    return nir_build_texture_query(b, tex, nir_texop_custom_border_color_agx, 4,
409                                   tex->dest_type, false, false);
410 }
411 
412 static nir_def *
413 has_custom_border(nir_builder *b, nir_tex_instr *tex)
414 {
415    return nir_build_texture_query(b, tex, nir_texop_has_custom_border_color_agx,
416                                   1, nir_type_bool1, false, false);
417 }
418 
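/* Lower custom border colours by sampling twice, once with the border clamped
 * to 1 (the original) and once clamped to 0, then combining the results with
 * the queried border colour (lerp for floats, component-wise select for
 * integers and gathers). The second sample and the combine are predicated on
 * has_custom_border, so the common case only pays for the query.
 */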
419 static bool
420 lower(nir_builder *b, nir_instr *instr, UNUSED void *_data)
421 {
422    if (instr->type != nir_instr_type_tex)
423       return false;
424 
425    nir_tex_instr *tex = nir_instr_as_tex(instr);
426    if (!nir_tex_instr_need_sampler(tex) || nir_tex_instr_is_query(tex))
427       return false;
428 
429    /* XXX: this is a really weird edge case, is this even well-defined? */
430    if (tex->is_shadow)
431       return false;
432 
433    b->cursor = nir_after_instr(&tex->instr);
434    nir_def *has_custom = has_custom_border(b, tex);
435 
436    nir_instr *orig = nir_instr_clone(b->shader, &tex->instr);
437    nir_builder_instr_insert(b, orig);
438    nir_def *clamp_to_1 = &nir_instr_as_tex(orig)->def;
439 
440    nir_push_if(b, has_custom);
441    nir_def *replaced = NULL;
442    {
443       /* Sample again, this time with clamp-to-0 instead of clamp-to-1 */
444       nir_instr *clone_instr = nir_instr_clone(b->shader, &tex->instr);
445       nir_builder_instr_insert(b, clone_instr);
446 
447       nir_tex_instr *tex_0 = nir_instr_as_tex(clone_instr);
448       nir_def *clamp_to_0 = &tex_0->def;
449 
450       tex_0->backend_flags |= AGX_TEXTURE_FLAG_CLAMP_TO_0;
451 
452       /* Grab the border colour */
453       nir_def *border = query_custom_border(b, tex_0);
454 
455       if (tex->op == nir_texop_tg4) {
456          border = nir_replicate(b, nir_channel(b, border, tex->component), 4);
457       }
458 
459       /* Combine together with the border */
460       if (nir_alu_type_get_base_type(tex->dest_type) == nir_type_float &&
461           tex->op != nir_texop_tg4) {
462 
463          /* For floats, lerp together:
464           *
465           * For border texels:  (1 * border) + (0 * border      ) = border
466           * For regular texels: (x * border) + (x * (1 - border)) = x.
467           *
468           * Linear filtering is linear (duh), so lerping is compatible.
469           */
470          replaced = nir_flrp(b, clamp_to_0, clamp_to_1, border);
471       } else {
472          /* For integers, just select componentwise since there is no linear
473           * filtering. Gathers also use this path since they are unfiltered in
474           * each component.
475           */
476          replaced = nir_bcsel(b, nir_ieq(b, clamp_to_0, clamp_to_1), clamp_to_0,
477                               border);
478       }
479    }
480    nir_pop_if(b, NULL);
481 
482    /* Put it together with a phi */
483    nir_def *phi = nir_if_phi(b, replaced, clamp_to_1);
484    nir_def_replace(&tex->def, phi);
485    return true;
486 }
487 
488 static bool
489 agx_nir_lower_custom_border(nir_shader *nir)
490 {
491    return nir_shader_instructions_pass(nir, lower, nir_metadata_none, NULL);
492 }
493 
494 /*
495  * In Vulkan, the VIEWPORT should read 0 in the fragment shader if it is not
496  * written by the vertex shader, but in our implementation, the varying would
497  * otherwise be undefined. This small pass predicates VIEWPORT reads based on
498  * whether the hardware vertex shader writes the VIEWPORT (nonzero UVS index).
499  */
500 static bool
501 lower_viewport_fs(nir_builder *b, nir_intrinsic_instr *intr, UNUSED void *data)
502 {
503    if (intr->intrinsic != nir_intrinsic_load_input)
504       return false;
505 
506    nir_io_semantics sem = nir_intrinsic_io_semantics(intr);
507    if (sem.location != VARYING_SLOT_VIEWPORT)
508       return false;
509 
510    b->cursor = nir_after_instr(&intr->instr);
511    nir_def *orig = &intr->def;
512 
513    nir_def *uvs = nir_load_uvs_index_agx(b, .io_semantics = sem);
514    nir_def *def = nir_bcsel(b, nir_ine_imm(b, uvs, 0), orig, nir_imm_int(b, 0));
515 
516    nir_def_rewrite_uses_after(orig, def, def->parent_instr);
517    return true;
518 }
519 
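/* nir_lower_input_attachments turns subpass loads into texture ops, so rewrite
 * any leftover SUBPASS/SUBPASS_MS sampler dims to plain 2D/MS.
 */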
520 static bool
521 lower_subpass_dim(nir_builder *b, nir_instr *instr, UNUSED void *_data)
522 {
523    if (instr->type != nir_instr_type_tex)
524       return false;
525 
526    nir_tex_instr *tex = nir_instr_as_tex(instr);
527    if (tex->sampler_dim == GLSL_SAMPLER_DIM_SUBPASS)
528       tex->sampler_dim = GLSL_SAMPLER_DIM_2D;
529    else if (tex->sampler_dim == GLSL_SAMPLER_DIM_SUBPASS_MS)
530       tex->sampler_dim = GLSL_SAMPLER_DIM_MS;
531    else
532       return false;
533 
534    return true;
535 }
536 
537 void
538 hk_lower_nir(struct hk_device *dev, nir_shader *nir,
539              const struct vk_pipeline_robustness_state *rs, bool is_multiview,
540              uint32_t set_layout_count,
541              struct vk_descriptor_set_layout *const *set_layouts)
542 {
543    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
544       NIR_PASS(_, nir, nir_lower_input_attachments,
545                &(nir_input_attachment_options){
546                   .use_fragcoord_sysval = true,
547                   .use_layer_id_sysval = true,
548                   .use_view_id_for_layer = is_multiview,
549                });
550 
551       NIR_PASS(_, nir, nir_shader_instructions_pass, lower_subpass_dim,
552                nir_metadata_all, NULL);
553       NIR_PASS(_, nir, nir_lower_wpos_center);
554    }
555 
556    /* XXX: should be last geometry stage, how do I get to that? */
557    if (nir->info.stage == MESA_SHADER_VERTEX) {
558       if (is_multiview)
559          hk_lower_multiview(nir);
560    }
561 
562    if (nir->info.stage == MESA_SHADER_TESS_EVAL) {
563       NIR_PASS(_, nir, nir_lower_patch_vertices,
564                nir->info.tess.tcs_vertices_out, NULL);
565    }
566 
567    const struct lower_ycbcr_state ycbcr_state = {
568       .set_layout_count = set_layout_count,
569       .set_layouts = set_layouts,
570    };
571    NIR_PASS(_, nir, nir_vk_lower_ycbcr_tex, lookup_ycbcr_conversion,
572             &ycbcr_state);
573 
574    /* Lower push constants before lower_descriptors */
575    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_push_const,
576             nir_address_format_32bit_offset);
577 
578    // NIR_PASS(_, nir, nir_opt_large_constants, NULL, 32);
579 
580    /* Images accessed through the texture or PBE hardware are robust, so we
581     * don't set lower_image. (There are some sticky details around txf but
582     * they're handled by agx_nir_lower_texture). However, image atomics are
583     * software so require robustness lowering.
584     */
585    nir_lower_robust_access_options robustness = {
586       .lower_image_atomic = true,
587    };
588 
589    NIR_PASS(_, nir, nir_lower_robust_access, &robustness);
590 
591    /* We must run the early texture lowering before hk_nir_lower_descriptors,
592     * since agx_nir_lower_texture_early creates lod_bias_agx instructions.
593     */
594    NIR_PASS(_, nir, agx_nir_lower_texture_early, true /* support_lod_bias */);
595    NIR_PASS(_, nir, agx_nir_lower_custom_border);
596 
597    NIR_PASS(_, nir, hk_nir_lower_descriptors, rs, set_layout_count,
598             set_layouts);
599    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_global,
600             nir_address_format_64bit_global);
601    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_ssbo,
602             hk_buffer_addr_format(rs->storage_buffers));
603    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_ubo,
604             hk_buffer_addr_format(rs->uniform_buffers));
605 
606    bool soft_fault = agx_has_soft_fault(&dev->dev);
607    NIR_PASS(_, nir, nir_shader_intrinsics_pass,
608             lower_load_global_constant_offset_instr, nir_metadata_none,
609             &soft_fault);
610 
611    if (!nir->info.shared_memory_explicit_layout) {
612       /* There may be garbage in shared_size, but it's the job of
613        * nir_lower_vars_to_explicit_types to allocate it. We have to reset to
614        * avoid overallocation.
615        */
616       nir->info.shared_size = 0;
617 
618       NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,
619                shared_var_info);
620    }
621    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_shared,
622             nir_address_format_32bit_offset);
623 
624    if (nir->info.zero_initialize_shared_memory && nir->info.shared_size > 0) {
625       /* Align everything up to 16B so we can write whole vec4s. */
626       nir->info.shared_size = align(nir->info.shared_size, 16);
627       NIR_PASS(_, nir, nir_zero_initialize_shared_memory, nir->info.shared_size,
628                16);
629 
630       /* We need to call lower_compute_system_values again because
631        * nir_zero_initialize_shared_memory generates load_invocation_id which
632        * has to be lowered to load_invocation_index.
633        */
634       NIR_PASS(_, nir, nir_lower_compute_system_values, NULL);
635    }
636 
637    /* TODO: we can do indirect VS output */
638    nir_variable_mode lower_indirect_modes = 0;
639    if (nir->info.stage == MESA_SHADER_FRAGMENT)
640       lower_indirect_modes |= nir_var_shader_out;
641    else if (nir->info.stage == MESA_SHADER_VERTEX)
642       lower_indirect_modes |= nir_var_shader_in | nir_var_shader_out;
643 
644    NIR_PASS(_, nir, nir_lower_indirect_derefs, lower_indirect_modes,
645             UINT32_MAX);
646 
647    NIR_PASS(_, nir, nir_lower_io, nir_var_shader_in | nir_var_shader_out,
648             glsl_type_size, nir_lower_io_lower_64bit_to_32);
649 
650    if (nir->info.stage == MESA_SHADER_FRAGMENT) {
651       NIR_PASS(_, nir, nir_shader_intrinsics_pass, lower_viewport_fs,
652                nir_metadata_control_flow, NULL);
653    }
654 
655    NIR_PASS(_, nir, agx_nir_lower_texture);
656    NIR_PASS(_, nir, agx_nir_lower_multisampled_image_store);
657 
658    agx_preprocess_nir(nir, dev->dev.libagx);
659    NIR_PASS(_, nir, nir_opt_conditional_discard);
660    NIR_PASS(_, nir, nir_opt_if,
661             nir_opt_if_optimize_phi_true_false | nir_opt_if_avoid_64bit_phis);
662 }
663 
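/* Upload the preamble (if any) to its own executable BO and pre-pack the
 * hardware control words (FRAGMENT_FACE_2 for fragment, COUNTS for all stages)
 * describing this shader.
 */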
664 static void
665 hk_upload_shader(struct hk_device *dev, struct hk_shader *shader)
666 {
667    if (shader->b.info.has_preamble) {
668       unsigned offs = shader->b.info.preamble_offset;
669       assert(offs < shader->b.binary_size);
670 
671       size_t size = shader->b.binary_size - offs;
672       assert(size > 0);
673 
674       shader->bo = agx_bo_create(&dev->dev, size, 0,
675                                  AGX_BO_EXEC | AGX_BO_LOW_VA, "Preamble");
676       memcpy(shader->bo->map, shader->b.binary + offs, size);
677       shader->preamble_addr = shader->bo->va->addr;
678    }
679 
680    if (!shader->linked.ht) {
681       /* If we only have a single variant, link now. */
682       shader->only_linked = hk_fast_link(dev, false, shader, NULL, NULL, 0);
683    }
684 
685    if (shader->info.stage == MESA_SHADER_FRAGMENT) {
686       agx_pack(&shader->frag_face, FRAGMENT_FACE_2, cfg) {
687          cfg.conservative_depth =
688             agx_translate_depth_layout(shader->b.info.depth_layout);
689       }
690    }
691 
692    agx_pack(&shader->counts, COUNTS, cfg) {
693       cfg.uniform_register_count = shader->b.info.push_count;
694       cfg.preshader_register_count = shader->b.info.nr_preamble_gprs;
695       cfg.sampler_state_register_count = agx_translate_sampler_state_count(
696          shader->b.info.uses_txf ? 1 : 0, false);
697    }
698 }
699 
700 DERIVE_HASH_TABLE(hk_fast_link_key_vs);
701 DERIVE_HASH_TABLE(hk_fast_link_key_fs);
702 
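/* Vertex and fragment shaders are fast-linked against prologs/epilogs at draw
 * time, so they need a table of linked variants; every other stage has exactly
 * one linked variant, created in hk_upload_shader.
 */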
703 static VkResult
704 hk_init_link_ht(struct hk_shader *shader, gl_shader_stage sw_stage)
705 {
706    simple_mtx_init(&shader->linked.lock, mtx_plain);
707 
708    bool multiple_variants =
709       sw_stage == MESA_SHADER_VERTEX || sw_stage == MESA_SHADER_FRAGMENT;
710 
711    if (!multiple_variants)
712       return VK_SUCCESS;
713 
714    if (sw_stage == MESA_SHADER_VERTEX)
715       shader->linked.ht = hk_fast_link_key_vs_table_create(NULL);
716    else
717       shader->linked.ht = hk_fast_link_key_fs_table_create(NULL);
718 
719    return (shader->linked.ht == NULL) ? VK_ERROR_OUT_OF_HOST_MEMORY
720                                       : VK_SUCCESS;
721 }
722 
723 static VkResult
724 hk_compile_nir(struct hk_device *dev, const VkAllocationCallbacks *pAllocator,
725                nir_shader *nir, VkShaderCreateFlagsEXT shader_flags,
726                const struct vk_pipeline_robustness_state *rs,
727                const struct hk_fs_key *fs_key, struct hk_shader *shader,
728                gl_shader_stage sw_stage, bool hw, nir_xfb_info *xfb_info)
729 {
730    unsigned vs_uniform_base = 0;
731 
732    /* For now, only shader objects are supported */
733    if (sw_stage == MESA_SHADER_VERTEX) {
734       vs_uniform_base =
735          6 * DIV_ROUND_UP(
736                 BITSET_LAST_BIT(shader->info.vs.attrib_components_read), 4);
737    } else if (sw_stage == MESA_SHADER_FRAGMENT) {
738       shader->info.fs.interp = agx_gather_interp_info(nir);
739       shader->info.fs.writes_memory = nir->info.writes_memory;
740 
741       /* Discards must be lowered before the MSAA lowering so discards are handled correctly */
742       NIR_PASS(_, nir, agx_nir_lower_discard_zs_emit);
743       NIR_PASS(_, nir, agx_nir_lower_fs_output_to_epilog,
744                &shader->info.fs.epilog_key);
745       NIR_PASS(_, nir, agx_nir_lower_sample_mask);
746 
747       if (nir->info.fs.uses_sample_shading) {
748          /* Ensure the sample ID is preserved in register */
749          nir_builder b =
750             nir_builder_at(nir_after_impl(nir_shader_get_entrypoint(nir)));
751          nir_export_agx(&b, nir_load_exported_agx(&b, 1, 16, .base = 1),
752                         .base = 1);
753 
754          NIR_PASS(_, nir, agx_nir_lower_to_per_sample);
755       }
756 
757       NIR_PASS(_, nir, agx_nir_lower_fs_active_samples_to_register);
758       NIR_PASS(_, nir, agx_nir_lower_interpolation);
759    } else if (sw_stage == MESA_SHADER_TESS_EVAL) {
760       shader->info.ts.ccw = nir->info.tess.ccw;
761       shader->info.ts.point_mode = nir->info.tess.point_mode;
762       shader->info.ts.spacing = nir->info.tess.spacing;
763       shader->info.ts.mode = nir->info.tess._primitive_mode;
764 
765       if (nir->info.tess.point_mode) {
766          shader->info.ts.out_prim = MESA_PRIM_POINTS;
767       } else if (nir->info.tess._primitive_mode == TESS_PRIMITIVE_ISOLINES) {
768          shader->info.ts.out_prim = MESA_PRIM_LINES;
769       } else {
770          shader->info.ts.out_prim = MESA_PRIM_TRIANGLES;
771       }
772 
773       /* This destroys info so it needs to happen after the gather */
774       NIR_PASS(_, nir, agx_nir_lower_tes, dev->dev.libagx, hw);
775    } else if (sw_stage == MESA_SHADER_TESS_CTRL) {
776       shader->info.tcs.output_patch_size = nir->info.tess.tcs_vertices_out;
777       shader->info.tcs.per_vertex_outputs = agx_tcs_per_vertex_outputs(nir);
778       shader->info.tcs.nr_patch_outputs =
779          util_last_bit(nir->info.patch_outputs_written);
780       shader->info.tcs.output_stride = agx_tcs_output_stride(nir);
781    }
782 
783    uint64_t outputs = nir->info.outputs_written;
784    if (!hw &&
785        (sw_stage == MESA_SHADER_VERTEX || sw_stage == MESA_SHADER_TESS_EVAL)) {
786       nir->info.stage = MESA_SHADER_COMPUTE;
787       memset(&nir->info.cs, 0, sizeof(nir->info.cs));
788       nir->xfb_info = NULL;
789    }
790 
791    /* XXX: rename */
792    NIR_PASS(_, nir, hk_lower_uvs_index, vs_uniform_base);
793 
794 #if 0
795    /* TODO */
796    nir_variable_mode robust2_modes = 0;
797    if (rs->uniform_buffers == VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_ROBUST_BUFFER_ACCESS_2_EXT)
798       robust2_modes |= nir_var_mem_ubo;
799    if (rs->storage_buffers == VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_ROBUST_BUFFER_ACCESS_2_EXT)
800       robust2_modes |= nir_var_mem_ssbo;
801 #endif
802 
803    struct agx_shader_key backend_key = {
804       .dev = agx_gather_device_key(&dev->dev),
805       .reserved_preamble = 128 /* TODO */,
806       .libagx = dev->dev.libagx,
807       .no_stop = nir->info.stage == MESA_SHADER_FRAGMENT,
808       .has_scratch = true,
809    };
810 
811    /* For now, sample shading is always dynamic. Indicate that. */
812    if (nir->info.stage == MESA_SHADER_FRAGMENT &&
813        nir->info.fs.uses_sample_shading)
814       backend_key.fs.inside_sample_loop = true;
815 
816    agx_compile_shader_nir(nir, &backend_key, NULL, &shader->b);
817 
818    shader->code_ptr = shader->b.binary;
819    shader->code_size = shader->b.binary_size;
820 
821    shader->info.stage = sw_stage;
822    shader->info.clip_distance_array_size = nir->info.clip_distance_array_size;
823    shader->info.cull_distance_array_size = nir->info.cull_distance_array_size;
824    shader->b.info.outputs = outputs;
825 
826    if (sw_stage == MESA_SHADER_COMPUTE) {
827       for (unsigned i = 0; i < 3; ++i)
828          shader->info.cs.local_size[i] = nir->info.workgroup_size[i];
829    }
830 
831    if (xfb_info) {
832       assert(xfb_info->output_count < ARRAY_SIZE(shader->info.xfb_outputs));
833 
834       memcpy(&shader->info.xfb_info, xfb_info,
835              nir_xfb_info_size(xfb_info->output_count));
836 
837       typed_memcpy(shader->info.xfb_stride, nir->info.xfb_stride, 4);
838    }
839 
840    if (nir->constant_data_size > 0) {
841       uint32_t data_size = align(nir->constant_data_size, HK_MIN_UBO_ALIGNMENT);
842 
843       void *data = malloc(data_size);
844       if (data == NULL) {
845          ralloc_free(nir);
846          return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
847       }
848 
849       memcpy(data, nir->constant_data, nir->constant_data_size);
850 
851       assert(nir->constant_data_size <= data_size);
852       memset(data + nir->constant_data_size, 0,
853              data_size - nir->constant_data_size);
854 
855       shader->data_ptr = data;
856       shader->data_size = data_size;
857    }
858 
859    ralloc_free(nir);
860 
861    VkResult result = hk_init_link_ht(shader, sw_stage);
862    if (result != VK_SUCCESS)
863       return vk_error(dev, result);
864 
865    hk_upload_shader(dev, shader);
866    return VK_SUCCESS;
867 }
868 
869 static const struct vk_shader_ops hk_shader_ops;
870 
871 static void
872 hk_destroy_linked_shader(struct hk_device *dev, struct hk_linked_shader *linked)
873 {
874    agx_bo_unreference(&dev->dev, linked->b.bo);
875    ralloc_free(linked);
876 }
877 
878 static void
879 hk_shader_destroy(struct hk_device *dev, struct hk_shader *s)
880 {
881    free((void *)s->code_ptr);
882    free((void *)s->data_ptr);
883    agx_bo_unreference(&dev->dev, s->bo);
884 
885    simple_mtx_destroy(&s->linked.lock);
886 
887    if (s->only_linked)
888       hk_destroy_linked_shader(dev, s->only_linked);
889 
890    if (s->linked.ht) {
891       hash_table_foreach(s->linked.ht, entry) {
892          hk_destroy_linked_shader(dev, entry->data);
893       }
894       _mesa_hash_table_destroy(s->linked.ht, NULL);
895    }
896 }
897 
898 void
899 hk_api_shader_destroy(struct vk_device *vk_dev, struct vk_shader *vk_shader,
900                       const VkAllocationCallbacks *pAllocator)
901 {
902    struct hk_device *dev = container_of(vk_dev, struct hk_device, vk);
903    struct hk_api_shader *obj =
904       container_of(vk_shader, struct hk_api_shader, vk);
905 
906    hk_foreach_variant(obj, shader) {
907       hk_shader_destroy(dev, shader);
908    }
909 
910    vk_shader_free(&dev->vk, pAllocator, &obj->vk);
911 }
912 
913 static void
914 hk_lower_hw_vs(nir_shader *nir, struct hk_shader *shader)
915 {
916    /* Point size must be clamped; excessively large points don't render
917     * properly on G13.
918     *
919     * Must be synced with pointSizeRange.
920     */
921    NIR_PASS(_, nir, nir_lower_point_size, 1.0f, 511.95f);
922 
923    /* TODO: Optimize out for monolithic? */
924    NIR_PASS(_, nir, hk_nir_insert_psiz_write);
925 
926    NIR_PASS(_, nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
927    NIR_PASS(_, nir, agx_nir_lower_cull_distance_vs);
928 
929    NIR_PASS(_, nir, agx_nir_lower_uvs, &shader->info.uvs);
930 
931    shader->info.vs.cull_distance_array_size =
932       nir->info.cull_distance_array_size;
933 }
934 
935 VkResult
936 hk_compile_shader(struct hk_device *dev, struct vk_shader_compile_info *info,
937                   const struct vk_graphics_pipeline_state *state,
938                   const VkAllocationCallbacks *pAllocator,
939                   struct hk_api_shader **shader_out)
940 {
941    VkResult result;
942 
943    /* We consume the NIR, regardless of success or failure */
944    nir_shader *nir = info->nir;
945 
946    size_t size = sizeof(struct hk_api_shader) +
947                  sizeof(struct hk_shader) * hk_num_variants(info->stage);
948    struct hk_api_shader *obj =
949       vk_shader_zalloc(&dev->vk, &hk_shader_ops, info->stage, pAllocator, size);
950 
951    if (obj == NULL) {
952       ralloc_free(nir);
953       return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
954    }
955 
956    /* TODO: Multiview with ESO */
957    const bool is_multiview = state && state->rp->view_mask != 0;
958 
959    hk_lower_nir(dev, nir, info->robustness, is_multiview,
960                 info->set_layout_count, info->set_layouts);
961 
962    gl_shader_stage sw_stage = nir->info.stage;
963 
964    struct hk_fs_key fs_key_tmp, *fs_key = NULL;
965    if (sw_stage == MESA_SHADER_FRAGMENT) {
966       hk_populate_fs_key(&fs_key_tmp, state);
967       fs_key = &fs_key_tmp;
968 
969       nir->info.fs.uses_sample_shading |= fs_key->force_sample_shading;
970 
971       /* Force late-Z for Z/S self-deps. TODO: There's probably a less silly way
972        * to do this.
973        */
974       if (fs_key->zs_self_dep) {
975          nir_builder b =
976             nir_builder_at(nir_before_impl(nir_shader_get_entrypoint(nir)));
977          nir_discard_if(&b, nir_imm_false(&b));
978          nir->info.fs.uses_discard = true;
979       }
980 
981       NIR_PASS(_, nir, agx_nir_lower_sample_intrinsics, false);
982    } else if (sw_stage == MESA_SHADER_TESS_CTRL) {
983       NIR_PASS_V(nir, agx_nir_lower_tcs, dev->dev.libagx);
984    }
985 
986    /* Compile all variants up front */
987    if (sw_stage == MESA_SHADER_GEOMETRY) {
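      /* agx_nir_lower_gs splits the geometry shader into several programs: the
       * main GS, a count shader, a pre-GS shader, and a rasterization shader
       * (used only when rasterizer discard is off). We do this twice, once per
       * rasterizer-discard state.
       */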
988       for (unsigned rast_disc = 0; rast_disc < 2; ++rast_disc) {
989          struct hk_shader *count_variant = hk_count_gs_variant(obj, rast_disc);
990          nir_shader *clone = nir_shader_clone(NULL, nir);
991 
992          enum mesa_prim out_prim = MESA_PRIM_MAX;
993          nir_shader *count = NULL, *rast = NULL, *pre_gs = NULL;
994 
995          NIR_PASS(_, clone, agx_nir_lower_gs, dev->dev.libagx, rast_disc,
996                   &count, &rast, &pre_gs, &out_prim,
997                   &count_variant->info.gs.count_words);
998 
999          if (!rast_disc) {
1000             struct hk_shader *shader = &obj->variants[HK_GS_VARIANT_RAST];
1001 
1002             hk_lower_hw_vs(rast, shader);
1003             shader->info.gs.out_prim = out_prim;
1004          }
1005 
1006          struct {
1007             nir_shader *in;
1008             struct hk_shader *out;
1009          } variants[] = {
1010             {clone, hk_main_gs_variant(obj, rast_disc)},
1011             {pre_gs, hk_pre_gs_variant(obj, rast_disc)},
1012             {count, count_variant},
1013             {rast_disc ? NULL : rast, &obj->variants[HK_GS_VARIANT_RAST]},
1014          };
1015 
1016          for (unsigned v = 0; v < ARRAY_SIZE(variants); ++v) {
1017             if (variants[v].in) {
1018                result = hk_compile_nir(dev, pAllocator, variants[v].in,
1019                                        info->flags, info->robustness, NULL,
1020                                        variants[v].out, sw_stage, true, NULL);
1021                if (result != VK_SUCCESS) {
1022                   hk_api_shader_destroy(&dev->vk, &obj->vk, pAllocator);
1023                   ralloc_free(nir);
1024                   return result;
1025                }
1026             }
1027          }
1028       }
1029    } else if (sw_stage == MESA_SHADER_VERTEX ||
1030               sw_stage == MESA_SHADER_TESS_EVAL) {
1031 
1032       if (sw_stage == MESA_SHADER_VERTEX) {
1033          assert(
1034             !(nir->info.inputs_read & BITFIELD64_MASK(VERT_ATTRIB_GENERIC0)) &&
1035             "Fixed-function attributes not used in Vulkan");
1036 
1037          NIR_PASS(_, nir, nir_recompute_io_bases, nir_var_shader_in);
1038       }
1039 
1040       /* the shader_out portion of this is load-bearing even for tess eval */
1041       NIR_PASS(_, nir, nir_io_add_const_offset_to_base,
1042                nir_var_shader_in | nir_var_shader_out);
1043 
1044       for (enum hk_vs_variant v = 0; v < HK_VS_VARIANTS; ++v) {
1045          struct hk_shader *shader = &obj->variants[v];
1046          bool hw = v == HK_VS_VARIANT_HW;
1047 
1048          /* TODO: Optimize single variant when we know nextStage */
1049          nir_shader *clone = nir_shader_clone(NULL, nir);
1050 
1051          if (sw_stage == MESA_SHADER_VERTEX) {
1052             NIR_PASS(_, clone, agx_nir_lower_vs_input_to_prolog,
1053                      shader->info.vs.attrib_components_read);
1054 
1055             shader->info.vs.attribs_read =
1056                nir->info.inputs_read >> VERT_ATTRIB_GENERIC0;
1057          }
1058 
1059          if (hw) {
1060             hk_lower_hw_vs(clone, shader);
1061          } else {
1062             NIR_PASS(_, clone, agx_nir_lower_vs_before_gs, dev->dev.libagx);
1063          }
1064 
1065          result = hk_compile_nir(dev, pAllocator, clone, info->flags,
1066                                  info->robustness, fs_key, shader, sw_stage, hw,
1067                                  nir->xfb_info);
1068          if (result != VK_SUCCESS) {
1069             hk_api_shader_destroy(&dev->vk, &obj->vk, pAllocator);
1070             ralloc_free(nir);
1071             return result;
1072          }
1073       }
1074    } else {
1075       struct hk_shader *shader = hk_only_variant(obj);
1076 
1077       /* hk_compile_nir takes ownership of nir */
1078       result =
1079          hk_compile_nir(dev, pAllocator, nir, info->flags, info->robustness,
1080                         fs_key, shader, sw_stage, true, NULL);
1081       if (result != VK_SUCCESS) {
1082          hk_api_shader_destroy(&dev->vk, &obj->vk, pAllocator);
1083          return result;
1084       }
1085    }
1086 
1087    *shader_out = obj;
1088    return VK_SUCCESS;
1089 }
1090 
1091 static VkResult
1092 hk_compile_shaders(struct vk_device *vk_dev, uint32_t shader_count,
1093                    struct vk_shader_compile_info *infos,
1094                    const struct vk_graphics_pipeline_state *state,
1095                    const VkAllocationCallbacks *pAllocator,
1096                    struct vk_shader **shaders_out)
1097 {
1098    struct hk_device *dev = container_of(vk_dev, struct hk_device, vk);
1099 
1100    for (uint32_t i = 0; i < shader_count; i++) {
1101       VkResult result =
1102          hk_compile_shader(dev, &infos[i], state, pAllocator,
1103                            (struct hk_api_shader **)&shaders_out[i]);
1104       if (result != VK_SUCCESS) {
1105          /* Clean up all the shaders before this point */
1106          for (uint32_t j = 0; j < i; j++)
1107             hk_api_shader_destroy(&dev->vk, shaders_out[j], pAllocator);
1108 
1109          /* Clean up all the NIR after this point */
1110          for (uint32_t j = i + 1; j < shader_count; j++)
1111             ralloc_free(infos[j].nir);
1112 
1113          /* Memset the output array */
1114          memset(shaders_out, 0, shader_count * sizeof(*shaders_out));
1115 
1116          return result;
1117       }
1118    }
1119 
1120    return VK_SUCCESS;
1121 }
1122 
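/* Inverse of hk_shader_serialize: read back the shader info, backend info,
 * binary, and constant data from the blob, then re-upload the shader.
 */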
1123 static VkResult
1124 hk_deserialize_shader(struct hk_device *dev, struct blob_reader *blob,
1125                       struct hk_shader *shader)
1126 {
1127    struct hk_shader_info info;
1128    blob_copy_bytes(blob, &info, sizeof(info));
1129 
1130    struct agx_shader_info b_info;
1131    blob_copy_bytes(blob, &b_info, sizeof(b_info));
1132 
1133    const uint32_t code_size = blob_read_uint32(blob);
1134    const uint32_t data_size = blob_read_uint32(blob);
1135    if (blob->overrun)
1136       return vk_error(dev, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
1137 
1138    VkResult result = hk_init_link_ht(shader, info.stage);
1139    if (result != VK_SUCCESS)
1140       return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
1141 
1142    simple_mtx_init(&shader->linked.lock, mtx_plain);
1143 
1144    shader->b.info = b_info;
1145    shader->info = info;
1146    shader->code_size = code_size;
1147    shader->data_size = data_size;
1148    shader->b.binary_size = code_size;
1149 
1150    shader->code_ptr = malloc(code_size);
1151    if (shader->code_ptr == NULL)
1152       return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
1153 
1154    shader->data_ptr = malloc(data_size);
1155    if (shader->data_ptr == NULL)
1156       return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
1157 
1158    blob_copy_bytes(blob, (void *)shader->code_ptr, shader->code_size);
1159    blob_copy_bytes(blob, (void *)shader->data_ptr, shader->data_size);
1160    if (blob->overrun)
1161       return vk_error(dev, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
1162 
1163    shader->b.binary = (void *)shader->code_ptr;
1164    hk_upload_shader(dev, shader);
1165    return VK_SUCCESS;
1166 }
1167 
1168 static VkResult
1169 hk_deserialize_api_shader(struct vk_device *vk_dev, struct blob_reader *blob,
1170                           uint32_t binary_version,
1171                           const VkAllocationCallbacks *pAllocator,
1172                           struct vk_shader **shader_out)
1173 {
1174    struct hk_device *dev = container_of(vk_dev, struct hk_device, vk);
1175 
1176    gl_shader_stage stage = blob_read_uint8(blob);
1177    if (blob->overrun)
1178       return vk_error(dev, VK_ERROR_INCOMPATIBLE_SHADER_BINARY_EXT);
1179 
1180    size_t size = sizeof(struct hk_api_shader) +
1181                  sizeof(struct hk_shader) * hk_num_variants(stage);
1182 
1183    struct hk_api_shader *obj =
1184       vk_shader_zalloc(&dev->vk, &hk_shader_ops, stage, pAllocator, size);
1185 
1186    if (obj == NULL)
1187       return vk_error(dev, VK_ERROR_OUT_OF_HOST_MEMORY);
1188 
1189    hk_foreach_variant(obj, shader) {
1190       VkResult result = hk_deserialize_shader(dev, blob, shader);
1191 
1192       if (result != VK_SUCCESS) {
1193          hk_api_shader_destroy(&dev->vk, &obj->vk, pAllocator);
1194          return result;
1195       }
1196    }
1197 
1198    *shader_out = &obj->vk;
1199    return VK_SUCCESS;
1200 }
1201 
1202 static void
1203 hk_shader_serialize(struct vk_device *vk_dev, const struct hk_shader *shader,
1204                     struct blob *blob)
1205 {
1206    blob_write_bytes(blob, &shader->info, sizeof(shader->info));
1207    blob_write_bytes(blob, &shader->b.info, sizeof(shader->b.info));
1208 
1209    blob_write_uint32(blob, shader->code_size);
1210    blob_write_uint32(blob, shader->data_size);
1211    blob_write_bytes(blob, shader->code_ptr, shader->code_size);
1212    blob_write_bytes(blob, shader->data_ptr, shader->data_size);
1213 }
1214 
1215 static bool
1216 hk_api_shader_serialize(struct vk_device *vk_dev,
1217                         const struct vk_shader *vk_shader, struct blob *blob)
1218 {
1219    struct hk_api_shader *obj =
1220       container_of(vk_shader, struct hk_api_shader, vk);
1221 
1222    blob_write_uint8(blob, vk_shader->stage);
1223 
1224    hk_foreach_variant(obj, shader) {
1225       hk_shader_serialize(vk_dev, shader, blob);
1226    }
1227 
1228    return !blob->out_of_memory;
1229 }
1230 
1231 #define WRITE_STR(field, ...)                                                  \
1232    ({                                                                          \
1233       memset(field, 0, sizeof(field));                                         \
1234       UNUSED int i = snprintf(field, sizeof(field), __VA_ARGS__);              \
1235       assert(i > 0 && i < sizeof(field));                                      \
1236    })
1237 
1238 static VkResult
1239 hk_shader_get_executable_properties(
1240    UNUSED struct vk_device *device, const struct vk_shader *vk_shader,
1241    uint32_t *executable_count, VkPipelineExecutablePropertiesKHR *properties)
1242 {
1243    struct hk_api_shader *obj =
1244       container_of(vk_shader, struct hk_api_shader, vk);
1245 
1246    VK_OUTARRAY_MAKE_TYPED(VkPipelineExecutablePropertiesKHR, out, properties,
1247                           executable_count);
1248 
1249    vk_outarray_append_typed(VkPipelineExecutablePropertiesKHR, &out, props)
1250    {
1251       props->stages = mesa_to_vk_shader_stage(obj->vk.stage);
1252       props->subgroupSize = 32;
1253       WRITE_STR(props->name, "%s", _mesa_shader_stage_to_string(obj->vk.stage));
1254       WRITE_STR(props->description, "%s shader",
1255                 _mesa_shader_stage_to_string(obj->vk.stage));
1256    }
1257 
1258    return vk_outarray_status(&out);
1259 }
1260 
1261 static VkResult
1262 hk_shader_get_executable_statistics(
1263    UNUSED struct vk_device *device, const struct vk_shader *vk_shader,
1264    uint32_t executable_index, uint32_t *statistic_count,
1265    VkPipelineExecutableStatisticKHR *statistics)
1266 {
1267    struct hk_api_shader *obj =
1268       container_of(vk_shader, struct hk_api_shader, vk);
1269 
1270    VK_OUTARRAY_MAKE_TYPED(VkPipelineExecutableStatisticKHR, out, statistics,
1271                           statistic_count);
1272 
1273    assert(executable_index == 0);
1274 
1275    /* TODO: find a sane way to report multiple variants and have that play nice
1276     * with zink.
1277     */
1278    struct hk_shader *shader = hk_any_variant(obj);
1279 
1280    vk_outarray_append_typed(VkPipelineExecutableStatisticKHR, &out, stat)
1281    {
1282       WRITE_STR(stat->name, "Code Size");
1283       WRITE_STR(stat->description,
1284                 "Size of the compiled shader binary, in bytes");
1285       stat->format = VK_PIPELINE_EXECUTABLE_STATISTIC_FORMAT_UINT64_KHR;
1286       stat->value.u64 = shader->code_size;
1287    }
1288 
1289    vk_outarray_append_typed(VkPipelineExecutableStatisticKHR, &out, stat)
1290    {
1291       WRITE_STR(stat->name, "Number of GPRs");
1292       WRITE_STR(stat->description, "Number of GPRs used by this pipeline");
1293       stat->format = VK_PIPELINE_EXECUTABLE_STATISTIC_FORMAT_UINT64_KHR;
1294       stat->value.u64 = shader->b.info.nr_gprs;
1295    }
1296 
1297    return vk_outarray_status(&out);
1298 }
1299 
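/* Standard Vulkan two-call idiom: if pData is NULL, just report the required
 * size; otherwise copy as much as fits and return false if the output was
 * truncated so the caller can report VK_INCOMPLETE.
 */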
1300 static bool
1301 write_ir_text(VkPipelineExecutableInternalRepresentationKHR *ir,
1302               const char *data)
1303 {
1304    ir->isText = VK_TRUE;
1305 
1306    size_t data_len = strlen(data) + 1;
1307 
1308    if (ir->pData == NULL) {
1309       ir->dataSize = data_len;
1310       return true;
1311    }
1312 
1313    strncpy(ir->pData, data, ir->dataSize);
1314    if (ir->dataSize < data_len)
1315       return false;
1316 
1317    ir->dataSize = data_len;
1318    return true;
1319 }
1320 
1321 static VkResult
1322 hk_shader_get_executable_internal_representations(
1323    UNUSED struct vk_device *device, const struct vk_shader *vk_shader,
1324    uint32_t executable_index, uint32_t *internal_representation_count,
1325    VkPipelineExecutableInternalRepresentationKHR *internal_representations)
1326 {
1327    VK_OUTARRAY_MAKE_TYPED(VkPipelineExecutableInternalRepresentationKHR, out,
1328                           internal_representations,
1329                           internal_representation_count);
1330    bool incomplete_text = false;
1331 
1332    assert(executable_index == 0);
1333 
1334    /* TODO */
1335 #if 0
1336    vk_outarray_append_typed(VkPipelineExecutableInternalRepresentationKHR, &out, ir) {
1337       WRITE_STR(ir->name, "AGX assembly");
1338       WRITE_STR(ir->description, "AGX assembly");
1339       if (!write_ir_text(ir, TODO))
1340          incomplete_text = true;
1341    }
1342 #endif
1343 
1344    return incomplete_text ? VK_INCOMPLETE : vk_outarray_status(&out);
1345 }
1346 
1347 static const struct vk_shader_ops hk_shader_ops = {
1348    .destroy = hk_api_shader_destroy,
1349    .serialize = hk_api_shader_serialize,
1350    .get_executable_properties = hk_shader_get_executable_properties,
1351    .get_executable_statistics = hk_shader_get_executable_statistics,
1352    .get_executable_internal_representations =
1353       hk_shader_get_executable_internal_representations,
1354 };
1355 
1356 const struct vk_device_shader_ops hk_device_shader_ops = {
1357    .get_nir_options = hk_get_nir_options,
1358    .get_spirv_options = hk_get_spirv_options,
1359    .preprocess_nir = hk_preprocess_nir,
1360    .hash_graphics_state = hk_hash_graphics_state,
1361    .compile = hk_compile_shaders,
1362    .deserialize = hk_deserialize_api_shader,
1363    .cmd_set_dynamic_graphics_state = vk_cmd_set_dynamic_graphics_state,
1364    .cmd_bind_shaders = hk_cmd_bind_shaders,
1365 };
1366 
1367 struct hk_linked_shader *
1368 hk_fast_link(struct hk_device *dev, bool fragment, struct hk_shader *main,
1369              struct agx_shader_part *prolog, struct agx_shader_part *epilog,
1370              unsigned nr_samples_shaded)
1371 {
1372    struct hk_linked_shader *s = rzalloc(NULL, struct hk_linked_shader);
1373    agx_fast_link(&s->b, &dev->dev, fragment, &main->b, prolog, epilog,
1374                  nr_samples_shaded);
1375 
1376    if (fragment) {
1377       agx_pack(&s->fs_counts, FRAGMENT_SHADER_WORD_0, cfg) {
1378          cfg.cf_binding_count = s->b.cf.nr_bindings;
1379          cfg.uniform_register_count = main->b.info.push_count;
1380          cfg.preshader_register_count = main->b.info.nr_preamble_gprs;
1381          cfg.sampler_state_register_count =
1382             agx_translate_sampler_state_count(s->b.uses_txf ? 1 : 0, false);
1383       }
1384    }
1385 
1386    /* Now that we've linked, bake the USC words to bind this program */
1387    struct agx_usc_builder b = agx_usc_builder(s->usc.data, sizeof(s->usc.data));
1388 
1389    if (main && main->b.info.immediate_size_16) {
1390       unreachable("todo");
1391 #if 0
1392       /* XXX: do ahead of time */
1393       uint64_t ptr = agx_pool_upload_aligned(
1394          &cmd->pool, s->b.info.immediates, s->b.info.immediate_size_16 * 2, 64);
1395 
1396       for (unsigned range = 0; range < constant_push_ranges; ++range) {
1397          unsigned offset = 64 * range;
1398          assert(offset < s->b.info.immediate_size_16);
1399 
1400          agx_usc_uniform(&b, s->b.info.immediate_base_uniform + offset,
1401                          MIN2(64, s->b.info.immediate_size_16 - offset),
1402                          ptr + (offset * 2));
1403       }
1404 #endif
1405    }
1406 
1407    agx_usc_push_packed(&b, UNIFORM, dev->rodata.image_heap);
1408 
1409    if (s->b.uses_txf)
1410       agx_usc_push_packed(&b, SAMPLER, dev->rodata.txf_sampler);
1411 
1412    agx_usc_shared_non_fragment(&b, &main->b.info, 0);
1413    agx_usc_push_packed(&b, SHADER, s->b.shader);
1414    agx_usc_push_packed(&b, REGISTERS, s->b.regs);
1415 
1416    if (fragment)
1417       agx_usc_push_packed(&b, FRAGMENT_PROPERTIES, s->b.fragment_props);
1418 
1419    if (main && main->b.info.has_preamble) {
1420       agx_usc_pack(&b, PRESHADER, cfg) {
1421          cfg.code = agx_usc_addr(&dev->dev, main->preamble_addr);
1422       }
1423    } else {
1424       agx_usc_pack(&b, NO_PRESHADER, cfg)
1425          ;
1426    }
1427 
1428    s->usc.size = b.head - s->usc.data;
1429    return s;
1430 }
1431