xref: /aosp_15_r20/external/mesa3d/src/amd/vulkan/nir/radv_nir_lower_primitive_shading_rate.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2023 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "nir.h"
8 #include "nir_builder.h"
9 #include "radv_nir.h"
10 
11 bool
radv_nir_lower_primitive_shading_rate(nir_shader * nir,enum amd_gfx_level gfx_level)12 radv_nir_lower_primitive_shading_rate(nir_shader *nir, enum amd_gfx_level gfx_level)
13 {
14    nir_function_impl *impl = nir_shader_get_entrypoint(nir);
15    bool progress = false;
16 
17    nir_builder b = nir_builder_create(impl);
18 
19    /* Iterate in reverse order since there should be only one deref store to PRIMITIVE_SHADING_RATE
20     * after lower_io_to_temporaries for vertex shaders.
21     */
22    nir_foreach_block_reverse (block, impl) {
23       nir_foreach_instr_reverse (instr, block) {
24          if (instr->type != nir_instr_type_intrinsic)
25             continue;
26 
27          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
28          if (intr->intrinsic != nir_intrinsic_store_deref)
29             continue;
30 
31          nir_variable *var = nir_intrinsic_get_var(intr, 0);
32          if (var->data.mode != nir_var_shader_out || var->data.location != VARYING_SLOT_PRIMITIVE_SHADING_RATE)
33             continue;
34 
35          b.cursor = nir_before_instr(instr);
36 
37          nir_def *val = intr->src[1].ssa;
38 
39          /* x_rate = (shadingRate & (Horizontal2Pixels | Horizontal4Pixels)) ? 0x1 : 0x0; */
40          nir_def *x_rate = nir_iand_imm(&b, val, 12);
41          x_rate = nir_b2i32(&b, nir_ine_imm(&b, x_rate, 0));
42 
43          /* y_rate = (shadingRate & (Vertical2Pixels | Vertical4Pixels)) ? 0x1 : 0x0; */
44          nir_def *y_rate = nir_iand_imm(&b, val, 3);
45          y_rate = nir_b2i32(&b, nir_ine_imm(&b, y_rate, 0));
46 
47          nir_def *out = NULL;
48 
49          /* MS:
50           * Primitive shading rate is a per-primitive output, it is
51           * part of the second channel of the primitive export.
52           * Bits [28:31] = VRS rate
53           * This will be added to the other bits of that channel in the backend.
54           *
55           * VS, TES, GS:
56           * Primitive shading rate is a per-vertex output pos export.
57           * Bits [2:5] = VRS rate
58           * HW shading rate = (xRate << 2) | (yRate << 4)
59           *
60           * GFX11: 4-bit VRS_SHADING_RATE enum
61           * GFX10: X = low 2 bits, Y = high 2 bits
62           */
63          unsigned x_rate_shift = 2;
64          unsigned y_rate_shift = 4;
65 
66          if (gfx_level >= GFX11) {
67             x_rate_shift = 4;
68             y_rate_shift = 2;
69          }
70          if (nir->info.stage == MESA_SHADER_MESH) {
71             x_rate_shift += 26;
72             y_rate_shift += 26;
73          }
74 
75          out = nir_ior(&b, nir_ishl_imm(&b, x_rate, x_rate_shift), nir_ishl_imm(&b, y_rate, y_rate_shift));
76 
77          nir_src_rewrite(&intr->src[1], out);
78 
79          progress = true;
80          if (nir->info.stage == MESA_SHADER_VERTEX)
81             break;
82       }
83       if (nir->info.stage == MESA_SHADER_VERTEX && progress)
84          break;
85    }
86 
87    if (progress)
88       nir_metadata_preserve(impl, nir_metadata_control_flow);
89    else
90       nir_metadata_preserve(impl, nir_metadata_all);
91 
92    return progress;
93 }
94