/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "agx_nir_lower_vbo.h"
#include "asahi/layout/layout.h"
#include "compiler/nir/nir_builder.h"
#include "compiler/nir/nir_format_convert.h"
#include "util/bitset.h"
#include "util/u_math.h"
#include "shader_enums.h"

struct ctx {
   struct agx_attribute *attribs;
   struct agx_robustness rs;
};

static bool
is_rgb10_a2(const struct util_format_description *desc)
{
   return desc->channel[0].shift == 0 && desc->channel[0].size == 10 &&
          desc->channel[1].shift == 10 && desc->channel[1].size == 10 &&
          desc->channel[2].shift == 20 && desc->channel[2].size == 10 &&
          desc->channel[3].shift == 30 && desc->channel[3].size == 2;
}

static enum pipe_format
agx_vbo_internal_format(enum pipe_format format)
{
   const struct util_format_description *desc = util_format_description(format);

   /* RGB10A2 formats are native for UNORM and unpacked otherwise */
   if (is_rgb10_a2(desc)) {
      if (desc->is_unorm)
         return PIPE_FORMAT_R10G10B10A2_UNORM;
      else
         return PIPE_FORMAT_R32_UINT;
   }

   /* R11G11B10F is native and special */
   if (format == PIPE_FORMAT_R11G11B10_FLOAT)
      return format;

   /* No other non-array formats handled */
   if (!desc->is_array)
      return PIPE_FORMAT_NONE;

   /* Otherwise look at one (any) channel */
   int idx = util_format_get_first_non_void_channel(format);
   if (idx < 0)
      return PIPE_FORMAT_NONE;

   /* We only handle RGB formats (we could do SRGB if we wanted though?) */
   if ((desc->colorspace != UTIL_FORMAT_COLORSPACE_RGB) ||
       (desc->layout != UTIL_FORMAT_LAYOUT_PLAIN))
      return PIPE_FORMAT_NONE;

   /* We have native 8-bit and 16-bit normalized formats */
   struct util_format_channel_description chan = desc->channel[idx];

   if (chan.normalized) {
      if (chan.size == 8)
         return desc->is_unorm ? PIPE_FORMAT_R8_UNORM : PIPE_FORMAT_R8_SNORM;
      else if (chan.size == 16)
         return desc->is_unorm ? PIPE_FORMAT_R16_UNORM : PIPE_FORMAT_R16_SNORM;
   }

   /* Otherwise map to the corresponding integer format */
   switch (chan.size) {
   case 32:
      return PIPE_FORMAT_R32_UINT;
   case 16:
      return PIPE_FORMAT_R16_UINT;
   case 8:
      return PIPE_FORMAT_R8_UINT;
   default:
      return PIPE_FORMAT_NONE;
   }
}
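
/* A few illustrative examples of the mapping above (not exhaustive):
 *
 *    PIPE_FORMAT_R8G8B8A8_UNORM    -> PIPE_FORMAT_R8_UNORM (native unorm)
 *    PIPE_FORMAT_B8G8R8A8_UNORM    -> PIPE_FORMAT_R8_UNORM (swizzled below)
 *    PIPE_FORMAT_R16G16_SSCALED    -> PIPE_FORMAT_R16_UINT (converted below)
 *    PIPE_FORMAT_R32G32B32_FLOAT   -> PIPE_FORMAT_R32_UINT (loaded bit-exact)
 *    PIPE_FORMAT_R10G10B10A2_SNORM -> PIPE_FORMAT_R32_UINT (unpacked below)
 */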

bool
agx_vbo_supports_format(enum pipe_format format)
{
   return agx_vbo_internal_format(format) != PIPE_FORMAT_NONE;
}

static nir_def *
apply_swizzle_channel(nir_builder *b, nir_def *vec, unsigned swizzle,
                      bool is_int)
{
   switch (swizzle) {
   case PIPE_SWIZZLE_X:
      return nir_channel(b, vec, 0);
   case PIPE_SWIZZLE_Y:
      return nir_channel(b, vec, 1);
   case PIPE_SWIZZLE_Z:
      return nir_channel(b, vec, 2);
   case PIPE_SWIZZLE_W:
      return nir_channel(b, vec, 3);
   case PIPE_SWIZZLE_0:
      return nir_imm_intN_t(b, 0, vec->bit_size);
   case PIPE_SWIZZLE_1:
      return is_int ? nir_imm_intN_t(b, 1, vec->bit_size)
                    : nir_imm_floatN_t(b, 1.0, vec->bit_size);
   default:
      unreachable("Invalid swizzle channel");
   }
}

static bool
pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_input)
      return false;

   struct ctx *ctx = data;
   struct agx_attribute *attribs = ctx->attribs;
   b->cursor = nir_instr_remove(&intr->instr);

   nir_src *offset_src = nir_get_io_offset_src(intr);
   assert(nir_src_is_const(*offset_src) && "no attribute indirects");
   unsigned index = nir_intrinsic_base(intr) + nir_src_as_uint(*offset_src);

   struct agx_attribute attrib = attribs[index];
   uint32_t stride = attrib.stride;
   uint16_t offset = attrib.src_offset;

   const struct util_format_description *desc =
      util_format_description(attrib.format);
   int chan = util_format_get_first_non_void_channel(attrib.format);
   assert(chan >= 0);

   bool is_float = desc->channel[chan].type == UTIL_FORMAT_TYPE_FLOAT;
   bool is_unsigned = desc->channel[chan].type == UTIL_FORMAT_TYPE_UNSIGNED;
   bool is_signed = desc->channel[chan].type == UTIL_FORMAT_TYPE_SIGNED;
   bool is_fixed = desc->channel[chan].type == UTIL_FORMAT_TYPE_FIXED;
   bool is_int = util_format_is_pure_integer(attrib.format);

   assert((is_float ^ is_unsigned ^ is_signed ^ is_fixed) && "Invalid format");

   enum pipe_format interchange_format = agx_vbo_internal_format(attrib.format);
   assert(interchange_format != PIPE_FORMAT_NONE);

   unsigned interchange_align = util_format_get_blocksize(interchange_format);
   unsigned interchange_comps = util_format_get_nr_components(attrib.format);

   /* In the hardware, uint formats zero-extend and float formats convert.
    * However, non-uint formats using a uint interchange format shouldn't be
    * zero extended.
    */
   unsigned interchange_register_size =
      util_format_is_pure_uint(interchange_format) &&
            !util_format_is_pure_uint(attrib.format)
         ? (interchange_align * 8)
         : intr->def.bit_size;
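
   /* For example (illustrative): R16G16_SSCALED interchanges through
    * R16_UINT, so the raw 16-bit integers are loaded into 16-bit registers
    * and converted to float afterwards rather than being zero-extended to the
    * 32-bit destination. A true uint format like R16G16_UINT instead loads at
    * the destination bit size and relies on the hardware's zero-extension.
    */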

   /* Non-UNORM R10G10B10A2 loaded as a scalar and unpacked */
   if (interchange_format == PIPE_FORMAT_R32_UINT && !desc->is_array)
      interchange_comps = 1;

   /* Calculate the element to fetch the vertex for. Divide the instance ID by
    * the divisor for per-instance data. Divisor=0 specifies per-vertex data.
    */
   nir_def *el;
   if (attrib.instanced) {
      if (attrib.divisor > 0)
         el = nir_udiv_imm(b, nir_load_instance_id(b), attrib.divisor);
      else
         el = nir_imm_int(b, 0);

      el = nir_iadd(b, el, nir_load_base_instance(b));

      BITSET_SET(b->shader->info.system_values_read,
                 SYSTEM_VALUE_BASE_INSTANCE);
   } else {
      el = nir_load_vertex_id(b);
   }
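
   /* For example (illustrative): with divisor = 4, instances 0-3 all fetch
    * element 0, instances 4-7 fetch element 1, and so on (plus the base
    * instance). An instanced attribute with divisor = 0 fetches element 0 for
    * every instance.
    */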

   /* VBO bases are per-attribute rather than per-buffer. This allows memory
    * sinks to work properly with robustness, allows folding the src_offset
    * into the VBO base to save an add in the shader, and reduces the size of
    * the vertex fetch key. That last piece allows reusing a linked VS with
    * both separate and interleaved attributes.
    */
   nir_def *buf_handle = nir_imm_int(b, index);

   /* Robustness is handled at the ID level */
   nir_def *bounds = nir_load_attrib_clamp_agx(b, buf_handle);

   /* For now, robustness is always applied. This gives GL robustness semantics.
    * For robustBufferAccess2, we'll want to check for out-of-bounds access
    * (where el > bounds), and replace base with the address of a zero sink.
    * With soft fault and a large enough sink, we don't need to clamp the index,
    * allowing that robustness behaviour to be implemented in 2 cmpsel
    * before the load. That is faster than the 4 cmpsel required after the load,
    * and it avoids waiting on the load which should help prolog performance.
    *
    * TODO: Optimize.
    */
   nir_def *oob = nir_ult(b, bounds, el);

   /* TODO: We clamp to handle null descriptors. This should be optimized
    * further. However, with the fix up after the load for D3D robustness, we
    * don't need this clamp if we can ignore the fault.
    */
   if (!(ctx->rs.level >= AGX_ROBUSTNESS_D3D && ctx->rs.soft_fault)) {
      el = nir_bcsel(b, oob, nir_imm_int(b, 0), el);
   }

   nir_def *base = nir_load_vbo_base_agx(b, buf_handle);

   assert((stride % interchange_align) == 0 && "must be aligned");
   assert((offset % interchange_align) == 0 && "must be aligned");

   unsigned stride_el = stride / interchange_align;
   unsigned offset_el = offset / interchange_align;
   unsigned shift = 0;

   /* Try to use the small shift on the load itself when possible. This can save
    * an instruction. Shifts are only available for regular interchange formats,
    * i.e. the set of formats that support masking.
    */
   if (offset_el == 0 && (stride_el == 2 || stride_el == 4) &&
       ail_isa_format_supports_mask((enum ail_isa_format)interchange_format)) {

      shift = util_logbase2(stride_el);
      stride_el = 1;
   }
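
   /* For example (illustrative): a tightly packed vec2 of 32-bit data
    * interchanges through R32_UINT, a regular interchange format, with stride
    * 8 bytes = 2 elements. Rather than multiplying the element index by 2,
    * stride_el is reduced to 1 and the load itself shifts the index left by 1.
    */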

   nir_def *stride_offset_el =
      nir_iadd_imm(b, nir_imul_imm(b, el, stride_el), offset_el);

   /* Load the raw vector */
   nir_def *memory = nir_load_constant_agx(
      b, interchange_comps, interchange_register_size, base, stride_offset_el,
      .format = interchange_format, .base = shift);

   /* TODO: Optimize per above */
   if (ctx->rs.level >= AGX_ROBUSTNESS_D3D) {
      nir_def *zero = nir_imm_zero(b, memory->num_components, memory->bit_size);
      memory = nir_bcsel(b, oob, zero, memory);
   }

   unsigned dest_size = intr->def.bit_size;

   /* Unpack but do not convert non-native non-array formats */
   if (is_rgb10_a2(desc) && interchange_format == PIPE_FORMAT_R32_UINT) {
      unsigned bits[] = {10, 10, 10, 2};

      if (is_signed)
         memory = nir_format_unpack_sint(b, memory, bits, 4);
      else
         memory = nir_format_unpack_uint(b, memory, bits, 4);
   }

   if (desc->channel[chan].normalized) {
      /* 8/16-bit normalized formats are native, others converted here */
      if (is_rgb10_a2(desc) && is_signed) {
         unsigned bits[] = {10, 10, 10, 2};
         memory = nir_format_snorm_to_float(b, memory, bits);
      } else if (desc->channel[chan].size == 32) {
         assert(desc->is_array && "no non-array 32-bit norm formats");
         unsigned bits[] = {32, 32, 32, 32};

         if (is_signed)
            memory = nir_format_snorm_to_float(b, memory, bits);
         else
            memory = nir_format_unorm_to_float(b, memory, bits);
      }
   } else if (desc->channel[chan].pure_integer) {
      /* Zero-extension is native, may need to sign extend */
      if (is_signed)
         memory = nir_i2iN(b, memory, dest_size);
   } else {
      if (is_unsigned)
         memory = nir_u2fN(b, memory, dest_size);
      else if (is_signed || is_fixed)
         memory = nir_i2fN(b, memory, dest_size);
      else
         memory = nir_f2fN(b, memory, dest_size);

      /* 16.16 fixed-point weirdo GL formats need to be scaled */
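      /* For example (illustrative): the 16.16 value 0x00010000 (65536)
       * represents 1.0, hence the multiply by 1.0 / 65536.0 below.
       */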
      if (is_fixed) {
         assert(desc->is_array && desc->channel[chan].size == 32);
         assert(dest_size == 32 && "overflow if smaller");
         memory = nir_fmul_imm(b, memory, 1.0 / 65536.0);
      }
   }

   /* We now have a properly formatted vector of the components in memory. Apply
    * the format swizzle forwards to trim/pad/reorder as needed.
    */
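   /* For example (illustrative): reading a vec4 from R16G16_FLOAT yields
    * (r, g, 0, 1.0), since the format swizzle pads the missing blue and alpha
    * channels with PIPE_SWIZZLE_0 and PIPE_SWIZZLE_1 respectively.
    */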
   nir_def *channels[4] = {NULL};

   for (unsigned i = 0; i < intr->num_components; ++i) {
      unsigned c = nir_intrinsic_component(intr) + i;
      channels[i] = apply_swizzle_channel(b, memory, desc->swizzle[c], is_int);
   }

   nir_def *logical = nir_vec(b, channels, intr->num_components);
   nir_def_rewrite_uses(&intr->def, logical);
   return true;
}

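/* Lower vertex attribute loads (nir_intrinsic_load_input) in a vertex shader
 * to hardware vertex fetches. The caller supplies one agx_attribute per
 * attribute index referenced by the shader, describing its format, stride,
 * src_offset, and instancing state, together with the robustness mode that
 * decides how out-of-bounds elements are handled.
 */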
bool
agx_nir_lower_vbo(nir_shader *shader, struct agx_attribute *attribs,
                  struct agx_robustness robustness)
{
   assert(shader->info.stage == MESA_SHADER_VERTEX);

   struct ctx ctx = {.attribs = attribs, .rs = robustness};
   return nir_shader_intrinsics_pass(shader, pass, nir_metadata_control_flow,
                                     &ctx);
}