/*
 * Copyright 2022 Alyssa Rosenzweig
 * SPDX-License-Identifier: MIT
 */

#include "compiler/nir/nir_builder.h"
#include "agx_compiler.h"
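
/*
 * Lower global memory access intrinsics (loads, stores, and atomics) to
 * AGX-specific intrinsics that take a 64-bit base address plus a 32-bit
 * offset, optionally sign-extended and, for loads and stores, scaled by an
 * implicit shift determined by the access format. The pass pattern matches
 * the address arithmetic feeding each access and folds as much of it as
 * possible into the access itself.
 */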

/* Results of pattern matching */
struct match {
   nir_scalar base, offset;
   bool sign_extend;

   /* Signed shift. A negative shift indicates that the offset needs a ushr
    * applied. It's cheaper to fold the iadd and materialize an extra ushr than
    * to leave the iadd untouched, so this is a good trade.
    */
   int8_t shift;
};
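
/*
 * For example, with a 32-bit access (format shift of 2), an address of the
 * form
 *
 *    base + u2u64(i * 4)
 *
 * can be matched as { .base = base, .offset = i, .shift = 0,
 * .sign_extend = false }, since the multiply by 4 is absorbed into the
 * implicit scaling of the hardware addressing mode.
 */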

/*
 * Try to match a multiplication with an immediate value. This generalizes to
 * both imul and ishl. If successful, returns true and sets the output
 * variables. Otherwise, returns false.
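 *
 * For example, (x << 3) matches with *variable = x and *imm = 8, while
 * (y * 5) matches with *variable = y and *imm = 5.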
 */
static bool
match_imul_imm(nir_scalar scalar, nir_scalar *variable, uint32_t *imm)
{
   if (!nir_scalar_is_alu(scalar))
      return false;

   nir_op op = nir_scalar_alu_op(scalar);
   if (op != nir_op_imul && op != nir_op_ishl)
      return false;

   nir_scalar inputs[] = {
      nir_scalar_chase_alu_src(scalar, 0),
      nir_scalar_chase_alu_src(scalar, 1),
   };

   /* For imul, check both operands for an immediate, since imul commutes. For
    * ishl, only check the right-hand operand (the shift amount).
    */
   bool commutes = (op == nir_op_imul);

   for (unsigned i = commutes ? 0 : 1; i < ARRAY_SIZE(inputs); ++i) {
      if (!nir_scalar_is_const(inputs[i]))
         continue;

      *variable = inputs[1 - i];

      uint32_t value = nir_scalar_as_uint(inputs[i]);

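      /* A left-shift by a constant is equivalent to a multiply by a power of
       * two, so normalize it to a multiplier.
       */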
      if (op == nir_op_imul)
         *imm = value;
      else
         *imm = (1 << value);

      return true;
   }

   return false;
}

/*
 * Try to rewrite (a << (#b + #c)) + #d as ((a << #b) + #d') << #c,
 * assuming that #d is a multiple of 1 << #c. This takes advantage of
 * the hardware's implicit << #c and avoids a right-shift.
 *
 * Similarly, try to rewrite (a * (#b << #c)) + #d as ((a * #b) + #d') << #c.
 *
 * This pattern occurs with a struct-of-arrays layout.
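 *
 * For example, with #c = 2 (a 32-bit format),
 *
 *    (a * 12) + 8  becomes  ((a * 3) + 2) << 2
 *
 * where the trailing << 2 is provided by the hardware for free.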
 */
static bool
match_soa(nir_builder *b, struct match *match, unsigned format_shift)
{
   if (!nir_scalar_is_alu(match->offset) ||
       nir_scalar_alu_op(match->offset) != nir_op_iadd)
      return false;

   nir_scalar summands[] = {
      nir_scalar_chase_alu_src(match->offset, 0),
      nir_scalar_chase_alu_src(match->offset, 1),
   };

   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
      if (!nir_scalar_is_const(summands[i]))
         continue;

      /* Note: the constant is treated as signed regardless of the sign of the
       * match. The final addition into the base can be signed or unsigned, but
       * when we shift right by the format shift below we must always sign
       * extend, so that a negative offset stays negative when it is added into
       * the index. That is, in:
       *
       *    addr = base + (u64)((index + offset) << shift)
       *
       * `index` and `offset` are always 32 bits, and a negative `offset` needs
       * to subtract from the index, so it must be sign extended when we apply
       * the format shift, even though the later conversion to 64 bits does not
       * sign extend.
       *
       * TODO: Confirm how the hardware handles 32-bit overflow when applying
       * the format shift; that might require reworking this again.
       */
      int offset = nir_scalar_as_int(summands[i]);
      nir_scalar variable;
      uint32_t multiplier;

      /* The other operand must be a multiplication by an immediate */
      if (!match_imul_imm(summands[1 - i], &variable, &multiplier))
         return false;

      int offset_shifted = offset >> format_shift;
      uint32_t multiplier_shifted = multiplier >> format_shift;

      /* If the multiplier or the offset is not aligned, we can't rewrite */
      if (multiplier != (multiplier_shifted << format_shift))
         return false;

      if (offset != (offset_shifted << format_shift))
         return false;

      /* Otherwise, rewrite! */
      nir_def *unmultiplied = nir_vec_scalars(b, &variable, 1);

      nir_def *rewrite = nir_iadd_imm(
         b, nir_imul_imm(b, unmultiplied, multiplier_shifted), offset_shifted);

      match->offset = nir_get_scalar(rewrite, 0);
      match->shift = 0;
      return true;
   }

   return false;
}

/* Try to pattern match the address calculation */
static struct match
match_address(nir_builder *b, nir_scalar base, int8_t format_shift)
{
   struct match match = {.base = base};

   /* All address calculations are iadd at the root */
   if (!nir_scalar_is_alu(base) || nir_scalar_alu_op(base) != nir_op_iadd)
      return match;

   /* Only 64-bit + 32-bit addition is supported, so look for a 32-to-64-bit
    * extension among the summands.
    */
   nir_scalar summands[] = {
      nir_scalar_chase_alu_src(base, 0),
      nir_scalar_chase_alu_src(base, 1),
   };

   for (unsigned i = 0; i < ARRAY_SIZE(summands); ++i) {
      /* We can add a constant to the 64-bit base for free, provided it fits
       * in 32 bits.
       */
      if (nir_scalar_is_const(summands[i]) &&
          nir_scalar_as_uint(summands[i]) < (1ull << 32)) {

         uint32_t value = nir_scalar_as_uint(summands[i]);

         return (struct match){
            .base = summands[1 - i],
            .offset = nir_get_scalar(nir_imm_int(b, value), 0),
            .shift = -format_shift,
            .sign_extend = false,
         };
      }

      /* Otherwise, we can only add an offset extended from 32 bits */
      if (!nir_scalar_is_alu(summands[i]))
         continue;

      nir_op op = nir_scalar_alu_op(summands[i]);

      if (op != nir_op_u2u64 && op != nir_op_i2i64)
         continue;

      /* We've found a suitable summand, so commit to it */
      match.base = summands[1 - i];
      match.offset = nir_scalar_chase_alu_src(summands[i], 0);
      match.sign_extend = (op == nir_op_i2i64);

      /* Undo the implicit shift from using the value as an offset */
      match.shift = -format_shift;
      break;
   }

   /* If we didn't find something to fold in, there's nothing else we can do */
   if (!match.offset.def)
      return match;

   /* But if we did, we can try to fold in a multiply */
   nir_scalar multiplied;
   uint32_t multiplier;

   if (match_imul_imm(match.offset, &multiplied, &multiplier)) {
      int8_t new_shift = match.shift;

      /* Try to fold in either a full power-of-two, or just the power-of-two
       * part of a non-power-of-two stride.
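       *
       * For example, with a format shift of 2, a stride of 24 = 6 << 2 leaves
       * only an imul by 6, with the << 2 supplied by the hardware's implicit
       * scale.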
       */
      if (util_is_power_of_two_nonzero(multiplier)) {
         new_shift += util_logbase2(multiplier);
         multiplier = 1;
      } else if (((multiplier >> format_shift) << format_shift) == multiplier) {
         new_shift += format_shift;
         multiplier >>= format_shift;
      } else {
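         /* The multiplier can't absorb the format shift, so keep the matched
          * offset as-is and let the caller undo the shift with a right-shift.
          */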
         return match;
      }

      nir_def *multiplied_ssa = nir_vec_scalars(b, &multiplied, 1);

      /* Only fold the shift in if we wouldn't overflow the lsl field */
      if (new_shift <= 2) {
         match.offset =
            nir_get_scalar(nir_imul_imm(b, multiplied_ssa, multiplier), 0);
         match.shift = new_shift;
      } else if (new_shift > 0) {
         /* For large shifts, we do need a multiply, but we can fold the shift
          * into it rather than emitting a separate right-shift.
          */
         assert(new_shift >= 3);

         nir_def *rewrite =
            nir_imul_imm(b, multiplied_ssa, multiplier << new_shift);

         match.offset = nir_get_scalar(rewrite, 0);
         match.shift = 0;
      }
   } else {
      /* Try to match the struct-of-arrays pattern, updating match if possible */
      match_soa(b, &match, format_shift);
   }

   return match;
}

static enum pipe_format
format_for_bitsize(unsigned bitsize)
{
   switch (bitsize) {
   case 8:
      return PIPE_FORMAT_R8_UINT;
   case 16:
      return PIPE_FORMAT_R16_UINT;
   case 32:
      return PIPE_FORMAT_R32_UINT;
   default:
      unreachable("should have been lowered");
   }
}

static bool
pass(struct nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
   if (intr->intrinsic != nir_intrinsic_load_global &&
       intr->intrinsic != nir_intrinsic_load_global_constant &&
       intr->intrinsic != nir_intrinsic_global_atomic &&
       intr->intrinsic != nir_intrinsic_global_atomic_swap &&
       intr->intrinsic != nir_intrinsic_store_global)
      return false;

   b->cursor = nir_before_instr(&intr->instr);

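   /* The access is encoded with a format whose block size determines the
    * implicit shift, so choose a format matching the access bit size.
    */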
   unsigned bitsize = intr->intrinsic == nir_intrinsic_store_global
                         ? nir_src_bit_size(intr->src[0])
                         : intr->def.bit_size;
   enum pipe_format format = format_for_bitsize(bitsize);
   unsigned format_shift = util_logbase2(util_format_get_blocksize(format));

   nir_src *orig_offset = nir_get_io_offset_src(intr);
   nir_scalar base = nir_scalar_resolved(orig_offset->ssa, 0);
   struct match match = match_address(b, base, format_shift);

   nir_def *offset = match.offset.def != NULL
                        ? nir_channel(b, match.offset.def, match.offset.comp)
                        : nir_imm_int(b, 0);

   /* If we were unable to fold in the shift, insert a right-shift now to undo
    * the implicit left shift of the instruction.
    */
   if (match.shift < 0) {
      if (match.sign_extend)
         offset = nir_ishr_imm(b, offset, -match.shift);
      else
         offset = nir_ushr_imm(b, offset, -match.shift);

      match.shift = 0;
   }

   /* Hardware offsets must be 32 bits. Upconvert if the shader used smaller
    * integers.
    */
   if (offset->bit_size != 32) {
      assert(offset->bit_size < 32);

      if (match.sign_extend)
         offset = nir_i2i32(b, offset);
      else
         offset = nir_u2u32(b, offset);
   }

   assert(match.shift >= 0);
   nir_def *new_base = nir_channel(b, match.base.def, match.base.comp);

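   /* Stores produce no value; for everything else, remember the replacement
    * def so its uses can be rewritten below.
    */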
   nir_def *repl = NULL;
   bool has_dest = (intr->intrinsic != nir_intrinsic_store_global);
   unsigned num_components = has_dest ? intr->def.num_components : 0;
   unsigned bit_size = has_dest ? intr->def.bit_size : 0;

   if (intr->intrinsic == nir_intrinsic_load_global) {
      repl =
         nir_load_agx(b, num_components, bit_size, new_base, offset,
                      .access = nir_intrinsic_access(intr), .base = match.shift,
                      .format = format, .sign_extend = match.sign_extend);

   } else if (intr->intrinsic == nir_intrinsic_load_global_constant) {
      repl = nir_load_constant_agx(b, num_components, bit_size, new_base,
                                   offset, .access = nir_intrinsic_access(intr),
                                   .base = match.shift, .format = format,
                                   .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic) {
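      /* The atomic lowering applies the shift to the offset explicitly rather
       * than relying on an implicit scale (likewise for the swap below).
       */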
      offset = nir_ishl_imm(b, offset, match.shift);
      repl =
         nir_global_atomic_agx(b, bit_size, new_base, offset, intr->src[1].ssa,
                               .atomic_op = nir_intrinsic_atomic_op(intr),
                               .sign_extend = match.sign_extend);
   } else if (intr->intrinsic == nir_intrinsic_global_atomic_swap) {
      offset = nir_ishl_imm(b, offset, match.shift);
      repl = nir_global_atomic_swap_agx(
         b, bit_size, new_base, offset, intr->src[1].ssa, intr->src[2].ssa,
         .atomic_op = nir_intrinsic_atomic_op(intr),
         .sign_extend = match.sign_extend);
   } else {
      nir_store_agx(b, intr->src[0].ssa, new_base, offset,
                    .access = nir_intrinsic_access(intr), .base = match.shift,
                    .format = format, .sign_extend = match.sign_extend);
   }

   if (repl)
      nir_def_rewrite_uses(&intr->def, repl);

   nir_instr_remove(&intr->instr);
   return true;
}

bool
agx_nir_lower_address(nir_shader *shader)
{
   return nir_shader_intrinsics_pass(shader, pass, nir_metadata_control_flow,
                                     NULL);
}