xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_opt_mqsad.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Valve Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 #include "nir.h"
6 #include "nir_builder.h"
7 #include "nir_worklist.h"
8 
9 /*
10  * This pass recognizes groups of four nir_op_msad_4x8 instructions (whose second source may be a
11  * nir_op_shfr by 8, 16 or 24 bits) and replaces each group with a single nir_op_mqsad_4x8 instruction.
12  */
13 
14 struct mqsad {
15    nir_scalar ref;
16    nir_scalar src[2];
17 
18    nir_scalar accum[4];
19    nir_alu_instr *msad[4];
20    unsigned first_msad_index;
21    uint8_t mask;
22 };
23 
24 static bool
is_mqsad_compatible(struct mqsad * mqsad,nir_scalar ref,nir_scalar src0,nir_scalar src1,unsigned idx,nir_alu_instr * msad)25 is_mqsad_compatible(struct mqsad *mqsad, nir_scalar ref, nir_scalar src0, nir_scalar src1,
26                     unsigned idx, nir_alu_instr *msad)
27 {
28    if (!nir_scalar_equal(ref, mqsad->ref) || !nir_scalar_equal(src0, mqsad->src[0]))
29       return false;
30    if ((mqsad->mask & 0b1110) && idx && !nir_scalar_equal(src1, mqsad->src[1]))
31       return false;
32 
33    /* Ensure that this MSAD doesn't depend on any previous MSAD. */
34    nir_instr_worklist *wl = nir_instr_worklist_create();
35    nir_instr_worklist_add_ssa_srcs(wl, &msad->instr);
36    nir_foreach_instr_in_worklist(instr, wl) {
37       if (instr->block != msad->instr.block || instr->index < mqsad->first_msad_index)
38          continue;
39 
40       u_foreach_bit(i, mqsad->mask) {
41          if (instr == &mqsad->msad[i]->instr) {
42             nir_instr_worklist_destroy(wl);
43             return false;
44          }
45       }
46 
47       nir_instr_worklist_add_ssa_srcs(wl, instr);
48    }
49    nir_instr_worklist_destroy(wl);
50 
51    return true;
52 }
53 
54 static void
parse_msad(nir_alu_instr * msad,struct mqsad * mqsad)55 parse_msad(nir_alu_instr *msad, struct mqsad *mqsad)
56 {
57    if (msad->def.num_components != 1)
58       return;
59 
60    nir_scalar msad_s = nir_get_scalar(&msad->def, 0);
61    nir_scalar ref = nir_scalar_chase_alu_src(msad_s, 0);
62    nir_scalar accum = nir_scalar_chase_alu_src(msad_s, 2);
63 
64    unsigned idx = 0;
65    nir_scalar src0 = nir_scalar_chase_alu_src(msad_s, 1);
66    nir_scalar src1;
67    if (nir_scalar_is_alu(src0) && nir_scalar_alu_op(src0) == nir_op_shfr) {
68       nir_scalar amount_s = nir_scalar_chase_alu_src(src0, 2);
69       uint32_t amount = nir_scalar_is_const(amount_s) ? nir_scalar_as_uint(amount_s) : 0;
70       if (amount == 8 || amount == 16 || amount == 24) {
71          idx = amount / 8;
72          src1 = nir_scalar_chase_alu_src(src0, 0);
73          src0 = nir_scalar_chase_alu_src(src0, 1);
74       }
75    }
76 
77    if (mqsad->mask && !is_mqsad_compatible(mqsad, ref, src0, src1, idx, msad))
78       memset(mqsad, 0, sizeof(*mqsad));
79 
80    /* Add this instruction to the in-progress MQSAD. */
81    mqsad->ref = ref;
82    mqsad->src[0] = src0;
83    if (idx)
84       mqsad->src[1] = src1;
85 
86    mqsad->accum[idx] = accum;
87    mqsad->msad[idx] = msad;
88    if (!mqsad->mask)
89       mqsad->first_msad_index = msad->instr.index;
90    mqsad->mask |= 1 << idx;
91 }
92 
93 static void
create_msad(nir_builder * b,struct mqsad * mqsad)94 create_msad(nir_builder *b, struct mqsad *mqsad)
95 {
96    nir_def *mqsad_def = nir_mqsad_4x8(b, nir_channel(b, mqsad->ref.def, mqsad->ref.comp),
97                                       nir_vec_scalars(b, mqsad->src, 2),
98                                       nir_vec_scalars(b, mqsad->accum, 4));
99 
100    for (unsigned i = 0; i < 4; i++)
101       nir_def_rewrite_uses(&mqsad->msad[i]->def, nir_channel(b, mqsad_def, i));
102 
103    memset(mqsad, 0, sizeof(*mqsad));
104 }
105 
106 bool
nir_opt_mqsad(nir_shader * shader)107 nir_opt_mqsad(nir_shader *shader)
108 {
109    bool progress = false;
110    nir_foreach_function_impl(impl, shader) {
111       bool progress_impl = false;
112 
113       nir_metadata_require(impl, nir_metadata_instr_index);
114 
115       nir_foreach_block(block, impl) {
116          struct mqsad mqsad;
117          memset(&mqsad, 0, sizeof(mqsad));
118 
119          nir_foreach_instr(instr, block) {
120             if (instr->type != nir_instr_type_alu)
121                continue;
122 
123             nir_alu_instr *alu = nir_instr_as_alu(instr);
124             if (alu->op != nir_op_msad_4x8)
125                continue;
126 
127             parse_msad(alu, &mqsad);
128 
129             if (mqsad.mask == 0xf) {
130                nir_builder b = nir_builder_at(nir_before_instr(instr));
131                create_msad(&b, &mqsad);
132                progress_impl = true;
133             }
134          }
135       }
136 
137       if (progress_impl) {
138          nir_metadata_preserve(impl, nir_metadata_control_flow);
139          progress = true;
140       } else {
141          nir_metadata_preserve(impl, nir_metadata_block_index);
142       }
143    }
144 
145    return progress;
146 }
147