1 /*
2 * Copyright © 2023 Valve Corporation
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "radv_nir.h"
10
11 bool
radv_nir_lower_primitive_shading_rate(nir_shader * nir,enum amd_gfx_level gfx_level)12 radv_nir_lower_primitive_shading_rate(nir_shader *nir, enum amd_gfx_level gfx_level)
13 {
14 nir_function_impl *impl = nir_shader_get_entrypoint(nir);
15 bool progress = false;
16
17 nir_builder b = nir_builder_create(impl);
18
19 /* Iterate in reverse order since there should be only one deref store to PRIMITIVE_SHADING_RATE
20 * after lower_io_to_temporaries for vertex shaders.
21 */
22 nir_foreach_block_reverse (block, impl) {
23 nir_foreach_instr_reverse (instr, block) {
24 if (instr->type != nir_instr_type_intrinsic)
25 continue;
26
27 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
28 if (intr->intrinsic != nir_intrinsic_store_deref)
29 continue;
30
31 nir_variable *var = nir_intrinsic_get_var(intr, 0);
32 if (var->data.mode != nir_var_shader_out || var->data.location != VARYING_SLOT_PRIMITIVE_SHADING_RATE)
33 continue;
34
35 b.cursor = nir_before_instr(instr);
36
37 nir_def *val = intr->src[1].ssa;
38
39 /* x_rate = (shadingRate & (Horizontal2Pixels | Horizontal4Pixels)) ? 0x1 : 0x0; */
40 nir_def *x_rate = nir_iand_imm(&b, val, 12);
41 x_rate = nir_b2i32(&b, nir_ine_imm(&b, x_rate, 0));
42
43 /* y_rate = (shadingRate & (Vertical2Pixels | Vertical4Pixels)) ? 0x1 : 0x0; */
44 nir_def *y_rate = nir_iand_imm(&b, val, 3);
45 y_rate = nir_b2i32(&b, nir_ine_imm(&b, y_rate, 0));
46
47 nir_def *out = NULL;
48
49 /* MS:
50 * Primitive shading rate is a per-primitive output, it is
51 * part of the second channel of the primitive export.
52 * Bits [28:31] = VRS rate
53 * This will be added to the other bits of that channel in the backend.
54 *
55 * VS, TES, GS:
56 * Primitive shading rate is a per-vertex output pos export.
57 * Bits [2:5] = VRS rate
58 * HW shading rate = (xRate << 2) | (yRate << 4)
59 *
60 * GFX11: 4-bit VRS_SHADING_RATE enum
61 * GFX10: X = low 2 bits, Y = high 2 bits
62 */
63 unsigned x_rate_shift = 2;
64 unsigned y_rate_shift = 4;
65
66 if (gfx_level >= GFX11) {
67 x_rate_shift = 4;
68 y_rate_shift = 2;
69 }
70 if (nir->info.stage == MESA_SHADER_MESH) {
71 x_rate_shift += 26;
72 y_rate_shift += 26;
73 }
74
75 out = nir_ior(&b, nir_ishl_imm(&b, x_rate, x_rate_shift), nir_ishl_imm(&b, y_rate, y_rate_shift));
76
77 nir_src_rewrite(&intr->src[1], out);
78
79 progress = true;
80 if (nir->info.stage == MESA_SHADER_VERTEX)
81 break;
82 }
83 if (nir->info.stage == MESA_SHADER_VERTEX && progress)
84 break;
85 }
86
87 if (progress)
88 nir_metadata_preserve(impl, nir_metadata_control_flow);
89 else
90 nir_metadata_preserve(impl, nir_metadata_all);
91
92 return progress;
93 }
94