1 /*
2 * Copyright © 2021 Google
3 * Copyright © 2023 Valve Corporation
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "radv_constants.h"
10 #include "radv_nir.h"
11
12 struct lower_hit_attrib_deref_args {
13 nir_variable_mode mode;
14 uint32_t base_offset;
15 };
16
17 static bool
lower_hit_attrib_deref(nir_builder * b,nir_instr * instr,void * data)18 lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
19 {
20 if (instr->type != nir_instr_type_intrinsic)
21 return false;
22
23 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
24 if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
25 return false;
26
27 struct lower_hit_attrib_deref_args *args = data;
28 nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
29 if (!nir_deref_mode_is(deref, args->mode))
30 return false;
31
32 assert(deref->deref_type == nir_deref_type_var);
33
34 b->cursor = nir_after_instr(instr);
35
36 if (intrin->intrinsic == nir_intrinsic_load_deref) {
37 uint32_t num_components = intrin->def.num_components;
38 uint32_t bit_size = intrin->def.bit_size;
39
40 nir_def *components[NIR_MAX_VEC_COMPONENTS];
41
42 for (uint32_t comp = 0; comp < num_components; comp++) {
43 uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
44 uint32_t base = offset / 4;
45 uint32_t comp_offset = offset % 4;
46
47 if (bit_size == 64) {
48 components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
49 nir_load_hit_attrib_amd(b, .base = base + 1));
50 } else if (bit_size == 32) {
51 components[comp] = nir_load_hit_attrib_amd(b, .base = base);
52 } else if (bit_size == 16) {
53 components[comp] =
54 nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
55 } else if (bit_size == 8) {
56 components[comp] =
57 nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
58 } else {
59 assert(bit_size == 1);
60 components[comp] = nir_i2b(b, nir_load_hit_attrib_amd(b, .base = base));
61 }
62 }
63
64 nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
65 } else {
66 nir_def *value = intrin->src[1].ssa;
67 uint32_t num_components = value->num_components;
68 uint32_t bit_size = value->bit_size;
69
70 for (uint32_t comp = 0; comp < num_components; comp++) {
71 uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
72 uint32_t base = offset / 4;
73 uint32_t comp_offset = offset % 4;
74
75 nir_def *component = nir_channel(b, value, comp);
76
77 if (bit_size == 64) {
78 nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
79 nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
80 } else if (bit_size == 32) {
81 nir_store_hit_attrib_amd(b, component, .base = base);
82 } else if (bit_size == 16) {
83 nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
84 nir_def *components[2];
85 for (uint32_t word = 0; word < 2; word++)
86 components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
87 nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
88 } else if (bit_size == 8) {
89 nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
90 nir_def *components[4];
91 for (uint32_t byte = 0; byte < 4; byte++)
92 components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
93 nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
94 } else {
95 assert(bit_size == 1);
96 nir_store_hit_attrib_amd(b, nir_b2i32(b, component), .base = base);
97 }
98 }
99 }
100
101 nir_instr_remove(instr);
102 return true;
103 }
104
105 static bool
radv_nir_lower_rt_vars(nir_shader * shader,nir_variable_mode mode,uint32_t base_offset)106 radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset)
107 {
108 bool progress = false;
109
110 progress |= nir_split_struct_vars(shader, mode);
111 progress |= nir_lower_indirect_derefs(shader, mode, UINT32_MAX);
112 progress |= nir_split_array_vars(shader, mode);
113
114 progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes);
115
116 struct lower_hit_attrib_deref_args args = {
117 .mode = mode,
118 .base_offset = base_offset,
119 };
120
121 progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, nir_metadata_control_flow, &args);
122
123 if (progress) {
124 nir_remove_dead_derefs(shader);
125 nir_remove_dead_variables(shader, mode, NULL);
126 }
127
128 return progress;
129 }
130
131 bool
radv_nir_lower_hit_attrib_derefs(nir_shader * shader)132 radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
133 {
134 return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0);
135 }
136
137 bool
radv_nir_lower_ray_payload_derefs(nir_shader * shader,uint32_t offset)138 radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset)
139 {
140 bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + offset);
141 progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset);
142 return progress;
143 }
144