xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2021 Google
3  * Copyright © 2023 Valve Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "radv_constants.h"
10 #include "radv_nir.h"
11 
12 struct lower_hit_attrib_deref_args {
13    nir_variable_mode mode;
14    uint32_t base_offset;
15 };
16 
17 static bool
lower_hit_attrib_deref(nir_builder * b,nir_instr * instr,void * data)18 lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
19 {
20    if (instr->type != nir_instr_type_intrinsic)
21       return false;
22 
23    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
24    if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
25       return false;
26 
27    struct lower_hit_attrib_deref_args *args = data;
28    nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
29    if (!nir_deref_mode_is(deref, args->mode))
30       return false;
31 
32    assert(deref->deref_type == nir_deref_type_var);
33 
34    b->cursor = nir_after_instr(instr);
35 
36    if (intrin->intrinsic == nir_intrinsic_load_deref) {
37       uint32_t num_components = intrin->def.num_components;
38       uint32_t bit_size = intrin->def.bit_size;
39 
40       nir_def *components[NIR_MAX_VEC_COMPONENTS];
41 
42       for (uint32_t comp = 0; comp < num_components; comp++) {
43          uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
44          uint32_t base = offset / 4;
45          uint32_t comp_offset = offset % 4;
46 
47          if (bit_size == 64) {
48             components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
49                                                       nir_load_hit_attrib_amd(b, .base = base + 1));
50          } else if (bit_size == 32) {
51             components[comp] = nir_load_hit_attrib_amd(b, .base = base);
52          } else if (bit_size == 16) {
53             components[comp] =
54                nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
55          } else if (bit_size == 8) {
56             components[comp] =
57                nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
58          } else {
59             assert(bit_size == 1);
60             components[comp] = nir_i2b(b, nir_load_hit_attrib_amd(b, .base = base));
61          }
62       }
63 
64       nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
65    } else {
66       nir_def *value = intrin->src[1].ssa;
67       uint32_t num_components = value->num_components;
68       uint32_t bit_size = value->bit_size;
69 
70       for (uint32_t comp = 0; comp < num_components; comp++) {
71          uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
72          uint32_t base = offset / 4;
73          uint32_t comp_offset = offset % 4;
74 
75          nir_def *component = nir_channel(b, value, comp);
76 
77          if (bit_size == 64) {
78             nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
79             nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
80          } else if (bit_size == 32) {
81             nir_store_hit_attrib_amd(b, component, .base = base);
82          } else if (bit_size == 16) {
83             nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
84             nir_def *components[2];
85             for (uint32_t word = 0; word < 2; word++)
86                components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
87             nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
88          } else if (bit_size == 8) {
89             nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
90             nir_def *components[4];
91             for (uint32_t byte = 0; byte < 4; byte++)
92                components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
93             nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
94          } else {
95             assert(bit_size == 1);
96             nir_store_hit_attrib_amd(b, nir_b2i32(b, component), .base = base);
97          }
98       }
99    }
100 
101    nir_instr_remove(instr);
102    return true;
103 }
104 
105 static bool
radv_nir_lower_rt_vars(nir_shader * shader,nir_variable_mode mode,uint32_t base_offset)106 radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset)
107 {
108    bool progress = false;
109 
110    progress |= nir_split_struct_vars(shader, mode);
111    progress |= nir_lower_indirect_derefs(shader, mode, UINT32_MAX);
112    progress |= nir_split_array_vars(shader, mode);
113 
114    progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes);
115 
116    struct lower_hit_attrib_deref_args args = {
117       .mode = mode,
118       .base_offset = base_offset,
119    };
120 
121    progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, nir_metadata_control_flow, &args);
122 
123    if (progress) {
124       nir_remove_dead_derefs(shader);
125       nir_remove_dead_variables(shader, mode, NULL);
126    }
127 
128    return progress;
129 }
130 
131 bool
radv_nir_lower_hit_attrib_derefs(nir_shader * shader)132 radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
133 {
134    return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0);
135 }
136 
137 bool
radv_nir_lower_ray_payload_derefs(nir_shader * shader,uint32_t offset)138 radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset)
139 {
140    bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + offset);
141    progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset);
142    return progress;
143 }
144