/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "intel_nir.h"
#include "compiler/nir/nir_builder.h"

/*
 * Implements a small peephole optimization that looks for a multiply that
 * is only ever used in an add and replaces both with an fma.
 */

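/*
 * As an illustrative sketch (hand-written here, not literal NIR printer
 * output), the pass rewrites
 *
 *    b = fmul x, y
 *    d = fadd b, c
 *
 * into
 *
 *    d = ffma x, y, c
 *
 * provided the fmul result only ever feeds fadd instructions, possibly
 * through mov/fneg/fabs chains; are_all_uses_fadd() below enforces that
 * requirement.
 */
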
static inline bool
are_all_uses_fadd(nir_def *def)
{
   nir_foreach_use_including_if(use_src, def) {
      if (nir_src_is_if(use_src))
         return false;

      nir_instr *use_instr = nir_src_parent_instr(use_src);
      if (use_instr->type != nir_instr_type_alu)
         return false;

      nir_alu_instr *use_alu = nir_instr_as_alu(use_instr);
      switch (use_alu->op) {
      case nir_op_fadd:
         break; /* This one's ok */

      case nir_op_mov:
      case nir_op_fneg:
      case nir_op_fabs:
         if (!are_all_uses_fadd(&use_alu->def))
            return false;
         break;

      default:
         return false;
      }
   }

   return true;
}

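/* Starting from `src`, walk through mov/fneg/fabs copies looking for the fmul
 * that ultimately feeds it.  The negate/abs flags are accumulated along the
 * way and `swizzle` is rewritten so that it maps the eventual ffma sources
 * onto the fmul's operands.  Returns NULL when there is no suitable, inexact
 * fmul to fuse.
 */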
static nir_alu_instr *
get_mul_for_src(nir_alu_src *src, unsigned num_components,
                uint8_t *swizzle, bool *negate, bool *abs)
{
   uint8_t swizzle_tmp[NIR_MAX_VEC_COMPONENTS];

   nir_instr *instr = src->src.ssa->parent_instr;
   if (instr->type != nir_instr_type_alu)
      return NULL;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   /* We want to bail if any of the other ALU operations involved is labeled
    * exact.  One reason for this is that, while the value that is changing is
    * actually the result of the add and not the multiply, the intention of
    * the user when they specify an exact multiply is that they want *that*
    * value and what they don't care about is the add.  Another reason is that
    * SPIR-V explicitly requires this behaviour.
    */
   if (alu->exact)
      return NULL;

   switch (alu->op) {
   case nir_op_mov:
      alu = get_mul_for_src(&alu->src[0], alu->def.num_components,
                            swizzle, negate, abs);
      break;

   case nir_op_fneg:
      alu = get_mul_for_src(&alu->src[0], alu->def.num_components,
                            swizzle, negate, abs);
      *negate = !*negate;
      break;

   case nir_op_fabs:
      alu = get_mul_for_src(&alu->src[0], alu->def.num_components,
                            swizzle, negate, abs);
      *negate = false;
      *abs = true;
      break;

   case nir_op_fmul:
      /* Only absorb an fmul into an ffma if the fmul is only used in fadd
       * operations.  Fusing more aggressively than this can actually lead
       * to more instructions.
       */
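      /* Illustrative example: if the fmul result also fed, say, an fsqrt,
       * the fmul would still have to be emitted for that use, so fusing here
       * would merely trade the fadd for a three-source ffma without
       * eliminating the multiply.
       */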
      if (!are_all_uses_fadd(&alu->def))
         return NULL;
      break;

   default:
      return NULL;
   }

   if (!alu)
      return NULL;

   /* Copy swizzle data before overwriting it to avoid setting a wrong swizzle.
    *
    * Example:
    *   Former swizzle[] = xyzw
    *   src->swizzle[] = zyxx
    *
    *   Expected output swizzle = zyxx
    *   If we reuse swizzle in the loop, then output swizzle would be zyzz.
    */
   memcpy(swizzle_tmp, swizzle, NIR_MAX_VEC_COMPONENTS*sizeof(uint8_t));
   for (int i = 0; i < num_components; i++)
      swizzle[i] = swizzle_tmp[src->swizzle[i]];

   return alu;
}

/**
 * Given a list of (at least two) nir_alu_src's, tells whether any of the
 * first two is a constant value that is used only once.
 */
static bool
any_alu_src_is_a_constant(nir_alu_src srcs[])
{
   for (unsigned i = 0; i < 2; i++) {
      if (srcs[i].src.ssa->parent_instr->type == nir_instr_type_load_const) {
         nir_load_const_instr *load_const =
            nir_instr_as_load_const(srcs[i].src.ssa->parent_instr);

         if (list_is_singular(&load_const->def.uses))
            return true;
      }
   }

   return false;
}

static bool
intel_nir_opt_peephole_ffma_instr(nir_builder *b,
                                  nir_instr *instr,
                                  UNUSED void *cb_data)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *add = nir_instr_as_alu(instr);
   if (add->op != nir_op_fadd)
      return false;

   if (add->exact)
      return false;

   /* This is the case a + a.  We would rather handle this with an
    * algebraic reduction than fuse it.  Also, we only want to fuse
    * things where the multiply is used only once and, in this case,
    * it would be used twice by the same instruction.
    */
   if (add->src[0].src.ssa == add->src[1].src.ssa)
      return false;

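   /* Try each of the two fadd sources in turn, looking for a multiply that
    * can be fused.  swizzle/negate/abs are reset before every attempt because
    * get_mul_for_src() accumulates into them while walking the source chain.
    */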
   nir_alu_instr *mul;
   uint8_t add_mul_src, swizzle[NIR_MAX_VEC_COMPONENTS];
   bool negate, abs;
   for (add_mul_src = 0; add_mul_src < 2; add_mul_src++) {
      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
         swizzle[i] = i;

      negate = false;
      abs = false;

      mul = get_mul_for_src(&add->src[add_mul_src],
                            add->def.num_components,
                            swizzle, &negate, &abs);

      if (mul != NULL)
         break;
   }

   if (mul == NULL)
      return false;

   unsigned bit_size = add->def.bit_size;

   nir_def *mul_src[2];
   mul_src[0] = mul->src[0].src.ssa;
   mul_src[1] = mul->src[1].src.ssa;

   /* If any of the operands of the fmul and any of the operands of the fadd
    * is a constant, we bypass the fusion because keeping them separate is
    * more efficient: the constants will be propagated as operands,
    * potentially saving two load_const instructions.
    */
   if (any_alu_src_is_a_constant(mul->src) &&
       any_alu_src_is_a_constant(add->src)) {
      return false;
   }

   b->cursor = nir_before_instr(&add->instr);

   if (abs) {
      for (unsigned i = 0; i < 2; i++)
         mul_src[i] = nir_fabs(b, mul_src[i]);
   }

   if (negate)
      mul_src[0] = nir_fneg(b, mul_src[0]);

   nir_alu_instr *ffma = nir_alu_instr_create(b->shader, nir_op_ffma);

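   /* The first two ffma sources are the fmul's operands; the swizzle that
    * get_mul_for_src() accumulated across any intervening mov/fneg/fabs is
    * composed with each fmul source swizzle.  The third source is the other
    * operand of the original fadd.
    */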
   for (unsigned i = 0; i < 2; i++) {
      ffma->src[i].src = nir_src_for_ssa(mul_src[i]);
      for (unsigned j = 0; j < add->def.num_components; j++)
         ffma->src[i].swizzle[j] = mul->src[i].swizzle[swizzle[j]];
   }
   nir_alu_src_copy(&ffma->src[2], &add->src[1 - add_mul_src]);

   nir_def_init(&ffma->instr, &ffma->def,
                add->def.num_components, bit_size);
   nir_def_rewrite_uses(&add->def, &ffma->def);

   nir_builder_instr_insert(b, &ffma->instr);
   assert(list_is_empty(&add->def.uses));
   nir_instr_remove(&add->instr);

   return true;
}

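/* Run the peephole over every instruction in the shader.  When an ffma is
 * emitted, nir_shader_instructions_pass() keeps only the metadata named here
 * (nir_metadata_control_flow, i.e. block indices and dominance information).
 */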
bool
intel_nir_opt_peephole_ffma(nir_shader *shader)
{
   return nir_shader_instructions_pass(shader, intel_nir_opt_peephole_ffma_instr,
                                       nir_metadata_control_flow,
                                       NULL);
}
253