/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir_builder.h"
#include "agx_compiler.h"

/* Results of pattern matching */
struct match {
   nir_scalar base, offset;
   bool sign_extend;

   /* Signed shift. A negative shift indicates that the offset needs ushr
    * applied. It's cheaper to fold the iadd and materialize an extra ushr
    * than to leave the iadd untouched, so this is still a win.
    */
   int8_t shift;
};
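
/*
 * Taken together, this describes an access of roughly the form
 * base + (offset << shift), where the hardware memory instructions
 * additionally scale the 32-bit offset by the format size (see pass() below).
 */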

/*
 * Try to match a multiplication with an immediate value. This generalizes to
 * both imul and ishl. If successful, returns true and sets the output
 * variables. Otherwise, returns false.
 */
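/*
 * For example, "x << 4" matches with variable = x and imm = 16, and
 * "x * 12" matches with variable = x and imm = 12.
 */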
static bool
match_imul_imm(nir_scalar scalar, nir_scalar *variable, uint32_t *imm)
{
   if (!nir_scalar_is_alu(scalar))
      return false;

   nir_op op = nir_scalar_alu_op(scalar);
   if (op != nir_op_imul && op != nir_op_ishl)
      return false;

   nir_scalar inputs[] = {
      nir_scalar_chase_alu_src(scalar, 0),
      nir_scalar_chase_alu_src(scalar, 1),
   };

   /* For imul check both operands for an immediate, since imul is commutative.
    * For ishl, only check the operand on the right.
    */
   bool commutes = (op == nir_op_imul);

   for (unsigned i = commutes ? 0 : 1; i < ARRAY_SIZE(inputs); ++i) {
      if (!nir_scalar_is_const(inputs[i]))
         continue;

      *variable = inputs[1 - i];

      uint32_t value = nir_scalar_as_uint(inputs[i]);

      if (op == nir_op_imul)
         *imm = value;
      else
         *imm = (1 << value);

      return true;
   }

   return false;
}

/*
 * Try to rewrite (a << (#b + #c)) + #d as ((a << #b) + #d') << #c,
 * assuming that #d is a multiple of 1 << #c. This takes advantage of
 * the hardware's implicit << #c and avoids a right-shift.
 *
 * Similarly, try to rewrite (a * (#b << #c)) + #d as ((a * #b) + #d') << #c.
 *
 * This pattern occurs with a struct-of-array layout.
 */
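/*
 * For example, with #c = format_shift = 2: (a << (3 + 2)) + 16 becomes
 * ((a << 3) + 4) << 2, and (a * 12) + 8 becomes ((a * 3) + 2) << 2.
 */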
static bool
match_soa(nir_builder *b, struct match *match, unsigned format_shift)
{
   if (!nir_scalar_is_alu(match->offset) ||
       nir_scalar_alu_op(match->offset) != nir_op_iadd)
      return false;

   nir_scalar summands[] = {
      nir_scalar_chase_alu_src(match->offset, 0),
      nir_scalar_chase_alu_src(match->offset, 1),
   };

   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
      if (!nir_scalar_is_const(summands[i]))
         continue;

      /* Note: This is treated as signed regardless of the sign of the match.
       * The final addition into the base can be signed or unsigned, but when
       * we shift right by the format shift below we need to always sign extend
       * to ensure that any negative offset remains negative when added into
       * the index. That is, in:
       *
       * addr = base + (u64)((index + offset) << shift)
       *
       * `index` and `offset` are always 32 bits, and a negative `offset` needs
       * to subtract from the index, so it needs to be sign extended when we
       * apply the format shift regardless of the fact that the later conversion
       * to 64 bits does not sign extend.
       *
       * TODO: We need to confirm how the hardware handles 32-bit overflow when
       * applying the format shift, which might need rework here again.
       */
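      /* For instance, with format_shift = 2, an offset of -4 must shift down
       * to -1 rather than 0x3FFFFFFF, hence the arithmetic shift below.
       */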
      int offset = nir_scalar_as_int(summands[i]);
      nir_scalar variable;
      uint32_t multiplier;

      /* The other operand must multiply */
      if (!match_imul_imm(summands[1 - i], &variable, &multiplier))
         return false;

      int offset_shifted = offset >> format_shift;
      uint32_t multiplier_shifted = multiplier >> format_shift;

      /* If the multiplier or the offset is not aligned, we can't rewrite */
      if (multiplier != (multiplier_shifted << format_shift))
         return false;

      if (offset != (offset_shifted << format_shift))
         return false;

      /* Otherwise, rewrite! */
      nir_def *unmultiplied = nir_vec_scalars(b, &variable, 1);

      nir_def *rewrite = nir_iadd_imm(
         b, nir_imul_imm(b, unmultiplied, multiplier_shifted), offset_shifted);

      match->offset = nir_get_scalar(rewrite, 0);
      match->shift = 0;
      return true;
   }

   return false;
}

/* Try to pattern match address calculation */
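/*
 * For example, with a 32-bit access (format_shift = 2), an address computed
 * as base + (u64)(i * 16) matches with offset i and shift 2, since the
 * power-of-two stride folds entirely into the shift.
 */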
static struct match
match_address(nir_builder *b, nir_scalar base, int8_t format_shift)
{
   struct match match = {.base = base};

   /* All address calculations are iadd at the root */
   if (!nir_scalar_is_alu(base) || nir_scalar_alu_op(base) != nir_op_iadd)
      return match;

   /* Only 64+32 addition is supported, so look for an extension */
   nir_scalar summands[] = {
      nir_scalar_chase_alu_src(base, 0),
      nir_scalar_chase_alu_src(base, 1),
   };

   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
      /* We can add a small constant to the 64-bit base for free */
      if (nir_scalar_is_const(summands[i]) &&
          nir_scalar_as_uint(summands[i]) < (1ull << 32)) {

         uint32_t value = nir_scalar_as_uint(summands[i]);

         return (struct match){
            .base = summands[1 - i],
            .offset = nir_get_scalar(nir_imm_int(b, value), 0),
            .shift = -format_shift,
            .sign_extend = false,
         };
      }

      /* Otherwise, we can only add an offset extended from 32 bits */
      if (!nir_scalar_is_alu(summands[i]))
         continue;

      nir_op op = nir_scalar_alu_op(summands[i]);

      if (op != nir_op_u2u64 && op != nir_op_i2i64)
         continue;

      /* We've found a summand, so commit to it */
      match.base = summands[1 - i];
      match.offset = nir_scalar_chase_alu_src(summands[i], 0);
      match.sign_extend = (op == nir_op_i2i64);

      /* Undo the implicit shift from being used as an offset */
      match.shift = -format_shift;
      break;
   }

   /* If we didn't find something to fold in, there's nothing else we can do */
   if (!match.offset.def)
      return match;

   /* But if we did, we can try to fold in a multiply */
   nir_scalar multiplied;
   uint32_t multiplier;

   if (match_imul_imm(match.offset, &multiplied, &multiplier)) {
      int8_t new_shift = match.shift;

      /* Try to fold in either a full power-of-two, or just the power-of-two
       * part of a non-power-of-two stride.
       */
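      /* For instance, a stride of 12 with format_shift = 2 folds a factor of
       * 4 into the shift, leaving a multiply by 3.
       */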
      if (util_is_power_of_two_nonzero(multiplier)) {
         new_shift += util_logbase2(multiplier);
         multiplier = 1;
      } else if (((multiplier >> format_shift) << format_shift) == multiplier) {
         new_shift += format_shift;
         multiplier >>= format_shift;
      } else {
         return match;
      }

      nir_def *multiplied_ssa = nir_vec_scalars(b, &multiplied, 1);

      /* Only fold in if we wouldn't overflow the lsl field */
      if (new_shift <= 2) {
         match.offset =
            nir_get_scalar(nir_imul_imm(b, multiplied_ssa, multiplier), 0);
         match.shift = new_shift;
      } else if (new_shift > 0) {
         /* For large shifts, we do need a multiply, but we can
          * shrink the shift to avoid generating an ishr.
          */
         assert(new_shift >= 3);

         nir_def *rewrite =
            nir_imul_imm(b, multiplied_ssa, multiplier << new_shift);

         match.offset = nir_get_scalar(rewrite, 0);
         match.shift = 0;
      }
   } else {
      /* Try to match struct-of-arrays pattern, updating match if possible */
      match_soa(b, &match, format_shift);
   }

   return match;
}

static enum pipe_format
format_for_bitsize(unsigned bitsize)
{
   switch (bitsize) {
   case 8:
      return PIPE_FORMAT_R8_UINT;
   case 16:
      return PIPE_FORMAT_R16_UINT;
   case 32:
      return PIPE_FORMAT_R32_UINT;
   default:
      unreachable("should have been lowered");
   }
}

static bool
pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_global &&
       intr->intrinsic != nir_intrinsic_load_global_constant &&
       intr->intrinsic != nir_intrinsic_global_atomic &&
       intr->intrinsic != nir_intrinsic_global_atomic_swap &&
       intr->intrinsic != nir_intrinsic_store_global)
      return false;

   b->cursor = nir_before_instr(&intr->instr);

   unsigned bitsize = intr->intrinsic == nir_intrinsic_store_global
                         ? nir_src_bit_size(intr->src[0])
                         : intr->def.bit_size;
   enum pipe_format format = format_for_bitsize(bitsize);
   unsigned format_shift = util_logbase2(util_format_get_blocksize(format));

   nir_src *orig_offset = nir_get_io_offset_src(intr);
   nir_scalar base = nir_scalar_resolved(orig_offset->ssa, 0);
   struct match match = match_address(b, base, format_shift);

   nir_def *offset = match.offset.def != NULL
                        ? nir_channel(b, match.offset.def, match.offset.comp)
                        : nir_imm_int(b, 0);

   /* If we were unable to fold in the shift, insert a right-shift now to undo
    * the implicit left shift of the instruction.
    */
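   /* For example, a constant byte offset folded in match_address still carries
    * shift = -format_shift, so a 32-bit access right-shifts it by 2 here and
    * lets the instruction's implicit scaling restore the (aligned) byte value.
    */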
   if (match.shift < 0) {
      if (match.sign_extend)
         offset = nir_ishr_imm(b, offset, -match.shift);
      else
         offset = nir_ushr_imm(b, offset, -match.shift);

      match.shift = 0;
   }

   /* Hardware offsets must be 32 bits. Upconvert if the source code used
    * smaller integers.
    */
   if (offset->bit_size != 32) {
      assert(offset->bit_size < 32);

      if (match.sign_extend)
         offset = nir_i2i32(b, offset);
      else
         offset = nir_u2u32(b, offset);
   }

   assert(match.shift >= 0);
   nir_def *new_base = nir_channel(b, match.base.def, match.base.comp);

   nir_def *repl = NULL;
   bool has_dest = (intr->intrinsic != nir_intrinsic_store_global);
   unsigned num_components = has_dest ? intr->def.num_components : 0;
   unsigned bit_size = has_dest ? intr->def.bit_size : 0;

   if (intr->intrinsic == nir_intrinsic_load_global) {
      repl =
         nir_load_agx(b, num_components, bit_size, new_base, offset,
                      .access = nir_intrinsic_access(intr), .base = match.shift,
                      .format = format, .sign_extend = match.sign_extend);

   } else if (intr->intrinsic == nir_intrinsic_load_global_constant) {
      repl = nir_load_constant_agx(b, num_components, bit_size, new_base,
                                   offset, .access = nir_intrinsic_access(intr),
                                   .base = match.shift, .format = format,
                                   .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic) {
      offset = nir_ishl_imm(b, offset, match.shift);
      repl =
         nir_global_atomic_agx(b, bit_size, new_base, offset, intr->src[1].ssa,
                               .atomic_op = nir_intrinsic_atomic_op(intr),
                               .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic_swap) {
      offset = nir_ishl_imm(b, offset, match.shift);
      repl = nir_global_atomic_swap_agx(
         b, bit_size, new_base, offset, intr->src[1].ssa, intr->src[2].ssa,
         .atomic_op = nir_intrinsic_atomic_op(intr),
         .sign_extend = match.sign_extend);
   } else {
      nir_store_agx(b, intr->src[0].ssa, new_base, offset,
                    .access = nir_intrinsic_access(intr), .base = match.shift,
                    .format = format, .sign_extend = match.sign_extend);
   }

   if (repl)
      nir_def_rewrite_uses(&intr->def, repl);

   nir_instr_remove(&intr->instr);
   return true;
}

bool
agx_nir_lower_address(nir_shader *shader)
{
   return nir_shader_intrinsics_pass(shader, pass, nir_metadata_control_flow,
                                     NULL);
}