xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_opt_load_store_vectorize.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2019 Valve Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 /**
25  * Although it's called a load/store "vectorization" pass, this also combines
26  * intersecting and identical loads/stores. It currently supports derefs, ubo,
27  * ssbo and push constant loads/stores.
28  *
29  * This doesn't handle copy_deref intrinsics and assumes that
30  * nir_lower_alu_to_scalar() has been called and that the IR is free from ALU
31  * modifiers. It also assumes that derefs have explicitly laid out types.
32  *
33  * After vectorization, the backend may want to call nir_lower_alu_to_scalar()
34  * and nir_lower_pack(). Also this creates cast instructions taking derefs as a
35  * source and some parts of NIR may not be able to handle that well.
36  *
37  * There are a few situations where this doesn't vectorize as well as it could:
38  * - It won't turn four consecutive vec3 loads into 3 vec4 loads.
39  * - It doesn't do global vectorization.
40  * Handling these cases probably wouldn't provide much benefit though.
41  *
42  * This probably doesn't handle big-endian GPUs correctly.
43  */
44 
45 #include "util/u_dynarray.h"
46 #include "nir.h"
47 #include "nir_builder.h"
48 #include "nir_deref.h"
49 #include "nir_worklist.h"
50 
51 #include <stdlib.h>
52 
/* Static description of a load/store/atomic intrinsic: which variable mode it
 * accesses and where its resource/offset/deref/value operands live. */
struct intrinsic_info {
   nir_variable_mode mode; /* 0 if the mode is obtained from the deref. */
   nir_intrinsic_op op;
   bool is_atomic;
   /* Indices into nir_intrinsic::src[] or -1 if not applicable. */
   int resource_src; /* resource (e.g. from vulkan_resource_index) */
   int base_src;     /* offset which it loads/stores from */
   int deref_src;    /* deref which it loads/stores from */
   int value_src;    /* the data it is storing */

   /* Number of bytes for an offset delta of 1. */
   unsigned offset_scale;
};
66 
/* Map an intrinsic opcode to its static description, or return NULL if this
 * pass does not know how to vectorize/combine the intrinsic. */
static const struct intrinsic_info *
get_info(nir_intrinsic_op op)
{
   switch (op) {
      /* Each INFO() expansion defines a function-local static table entry and
       * returns its address, so callers always get a stable pointer. */
#define INFO(mode, op, atomic, res, base, deref, val, scale)                                                             \
   case nir_intrinsic_##op: {                                                                                            \
      static const struct intrinsic_info op##_info = { mode, nir_intrinsic_##op, atomic, res, base, deref, val, scale }; \
      return &op##_info;                                                                                                 \
   }
#define LOAD(mode, op, res, base, deref, scale)       INFO(mode, load_##op, false, res, base, deref, -1, scale)
#define STORE(mode, op, res, base, deref, val, scale) INFO(mode, store_##op, false, res, base, deref, val, scale)
#define ATOMIC(mode, type, res, base, deref, val, scale)         \
   INFO(mode, type##_atomic, true, res, base, deref, val, scale) \
   INFO(mode, type##_atomic_swap, true, res, base, deref, val, scale)

      /* arguments: mode, name, resource_src, base_src, deref_src, [value_src,] offset_scale */
      LOAD(nir_var_mem_push_const, push_constant, -1, 0, -1, 1)
      LOAD(nir_var_mem_ubo, ubo, 0, 1, -1, 1)
      LOAD(nir_var_mem_ssbo, ssbo, 0, 1, -1, 1)
      STORE(nir_var_mem_ssbo, ssbo, 1, 2, -1, 0, 1)
      LOAD(0, deref, -1, -1, 0, 1)
      STORE(0, deref, -1, -1, 0, 1, 1)
      LOAD(nir_var_mem_shared, shared, -1, 0, -1, 1)
      STORE(nir_var_mem_shared, shared, -1, 1, -1, 0, 1)
      LOAD(nir_var_mem_global, global, -1, 0, -1, 1)
      STORE(nir_var_mem_global, global, -1, 1, -1, 0, 1)
      LOAD(nir_var_mem_global, global_constant, -1, 0, -1, 1)
      LOAD(nir_var_mem_task_payload, task_payload, -1, 0, -1, 1)
      STORE(nir_var_mem_task_payload, task_payload, -1, 1, -1, 0, 1)
      ATOMIC(nir_var_mem_ssbo, ssbo, 0, 1, -1, 2, 1)
      ATOMIC(0, deref, -1, -1, 0, 1, 1)
      ATOMIC(nir_var_mem_shared, shared, -1, 0, -1, 1, 1)
      ATOMIC(nir_var_mem_global, global, -1, 0, -1, 1, 1)
      ATOMIC(nir_var_mem_task_payload, task_payload, -1, 0, -1, 1, 1)
      LOAD(nir_var_shader_temp, stack, -1, -1, -1, 1)
      STORE(nir_var_shader_temp, stack, -1, -1, -1, 0, 1)
      LOAD(nir_var_shader_temp, scratch, -1, 0, -1, 1)
      STORE(nir_var_shader_temp, scratch, -1, 1, -1, 0, 1)
      LOAD(nir_var_mem_ubo, ubo_uniform_block_intel, 0, 1, -1, 1)
      LOAD(nir_var_mem_ssbo, ssbo_uniform_block_intel, 0, 1, -1, 1)
      LOAD(nir_var_mem_shared, shared_uniform_block_intel, -1, 0, -1, 1)
      LOAD(nir_var_mem_global, global_constant_uniform_block_intel, -1, 0, -1, 1)
      INFO(nir_var_mem_ubo, ldc_nv, false, 0, 1, -1, -1, 1)
      INFO(nir_var_mem_ubo, ldcx_nv, false, 0, 1, -1, -1, 1)
      LOAD(nir_var_uniform, const_ir3, -1, 0, -1, 4)
      STORE(nir_var_uniform, const_ir3, -1, -1, -1, 0, 4)
      INFO(nir_var_mem_shared, shared_append_amd, true, -1, -1, -1, -1, 1)
      INFO(nir_var_mem_shared, shared_consume_amd, true, -1, -1, -1, -1, 1)
   default:
      break;
#undef ATOMIC
#undef STORE
#undef LOAD
#undef INFO
   }
   return NULL;
}
123 
/*
 * Information used to compare memory operations.
 * It canonically represents an offset as:
 * `offset_defs[0]*offset_defs_mul[0] + offset_defs[1]*offset_defs_mul[1] + ...`
 * "offset_defs" is kept in a canonical order by the ssa definition's index
 * (maintained by add_to_entry_key) so that hashing and comparison are
 * deterministic.
 * "resource" or "var" may be NULL.
 */
struct entry_key {
   nir_def *resource;         /* e.g. the result of vulkan_resource_index */
   nir_variable *var;         /* variable for deref-based accesses */
   unsigned offset_def_count; /* number of variable offset terms */
   nir_scalar *offset_defs;   /* variable parts of the offset expression */
   uint64_t *offset_defs_mul; /* sign-extended multiplier per offset def */
};
138 
/* Information on a single memory operation. */
struct entry {
   struct list_head head; /* intrusive list link */
   unsigned index;        /* NOTE(review): assigned outside this chunk; presumably program order */

   struct entry_key *key;
   union {
      uint64_t offset; /* sign-extended */
      int64_t offset_signed;
   };
   uint32_t align_mul;    /* guaranteed alignment of the access (see calc_alignment) */
   uint32_t align_offset; /* offset of the access within align_mul */

   nir_instr *instr;
   nir_intrinsic_instr *intrin;
   const struct intrinsic_info *info;
   enum gl_access_qualifier access;
   bool is_store; /* true when info->value_src >= 0 */

   nir_deref_instr *deref; /* only set for deref-based intrinsics */
};
160 
/* Per-pass state. The per-mode arrays are indexed via mode_to_index(), which
 * folds global memory into the SSBO slot so the two are tracked together. */
struct vectorize_ctx {
   nir_shader *shader;
   const nir_load_store_vectorize_options *options;
   struct list_head entries[nir_num_variable_modes];
   struct hash_table *loads[nir_num_variable_modes];
   struct hash_table *stores[nir_num_variable_modes];
};
168 
169 static uint32_t
hash_entry_key(const void * key_)170 hash_entry_key(const void *key_)
171 {
172    /* this is careful to not include pointers in the hash calculation so that
173     * the order of the hash table walk is deterministic */
174    struct entry_key *key = (struct entry_key *)key_;
175 
176    uint32_t hash = 0;
177    if (key->resource)
178       hash = XXH32(&key->resource->index, sizeof(key->resource->index), hash);
179    if (key->var) {
180       hash = XXH32(&key->var->index, sizeof(key->var->index), hash);
181       unsigned mode = key->var->data.mode;
182       hash = XXH32(&mode, sizeof(mode), hash);
183    }
184 
185    for (unsigned i = 0; i < key->offset_def_count; i++) {
186       hash = XXH32(&key->offset_defs[i].def->index, sizeof(key->offset_defs[i].def->index), hash);
187       hash = XXH32(&key->offset_defs[i].comp, sizeof(key->offset_defs[i].comp), hash);
188    }
189 
190    hash = XXH32(key->offset_defs_mul, key->offset_def_count * sizeof(uint64_t), hash);
191 
192    return hash;
193 }
194 
195 static bool
entry_key_equals(const void * a_,const void * b_)196 entry_key_equals(const void *a_, const void *b_)
197 {
198    struct entry_key *a = (struct entry_key *)a_;
199    struct entry_key *b = (struct entry_key *)b_;
200 
201    if (a->var != b->var || a->resource != b->resource)
202       return false;
203 
204    if (a->offset_def_count != b->offset_def_count)
205       return false;
206 
207    for (unsigned i = 0; i < a->offset_def_count; i++) {
208       if (!nir_scalar_equal(a->offset_defs[i], b->offset_defs[i]))
209          return false;
210    }
211 
212    size_t offset_def_mul_size = a->offset_def_count * sizeof(uint64_t);
213    if (a->offset_def_count &&
214        memcmp(a->offset_defs_mul, b->offset_defs_mul, offset_def_mul_size))
215       return false;
216 
217    return true;
218 }
219 
220 static void
delete_entry_dynarray(struct hash_entry * entry)221 delete_entry_dynarray(struct hash_entry *entry)
222 {
223    struct util_dynarray *arr = (struct util_dynarray *)entry->data;
224    ralloc_free(arr);
225 }
226 
227 static int
sort_entries(const void * a_,const void * b_)228 sort_entries(const void *a_, const void *b_)
229 {
230    struct entry *a = *(struct entry *const *)a_;
231    struct entry *b = *(struct entry *const *)b_;
232 
233    if (a->offset_signed > b->offset_signed)
234       return 1;
235    else if (a->offset_signed < b->offset_signed)
236       return -1;
237    else
238       return 0;
239 }
240 
241 static unsigned
get_bit_size(struct entry * entry)242 get_bit_size(struct entry *entry)
243 {
244    unsigned size = entry->is_store ? entry->intrin->src[entry->info->value_src].ssa->bit_size : entry->intrin->def.bit_size;
245    return size == 1 ? 32u : size;
246 }
247 
248 static unsigned
get_write_mask(const nir_intrinsic_instr * intrin)249 get_write_mask(const nir_intrinsic_instr *intrin)
250 {
251    if (nir_intrinsic_has_write_mask(intrin))
252       return nir_intrinsic_write_mask(intrin);
253 
254    const struct intrinsic_info *info = get_info(intrin->intrinsic);
255    assert(info->value_src >= 0);
256    return nir_component_mask(intrin->src[info->value_src].ssa->num_components);
257 }
258 
/* If "def" is from an alu instruction with the opcode "op" and one of its
 * sources is a constant, update "def" to be the non-constant source, fill "c"
 * with the constant and return true. */
static bool
parse_alu(nir_scalar *def, nir_op op, uint64_t *c)
{
   if (!nir_scalar_is_alu(*def) || nir_scalar_alu_op(*def) != op)
      return false;

   nir_scalar src0 = nir_scalar_chase_alu_src(*def, 0);
   nir_scalar src1 = nir_scalar_chase_alu_src(*def, 1);
   /* ishl is not commutative: only a constant second source (the shift
    * amount) may be folded into "c", never a constant shiftee. */
   if (op != nir_op_ishl && nir_scalar_is_const(src0)) {
      *c = nir_scalar_as_uint(src0);
      *def = src1;
   } else if (nir_scalar_is_const(src1)) {
      *c = nir_scalar_as_uint(src1);
      *def = src0;
   } else {
      return false;
   }
   return true;
}
281 
/* Parses an offset expression such as "a * 16 + 4" and "(a * 16 + 4) * 64 + 32".
 * On return, "base" is the remaining variable part (base->def == NULL if the
 * whole expression was constant), "*base_mul" its multiplier and "*offset"
 * the accumulated constant part. */
static void
parse_offset(nir_scalar *base, uint64_t *base_mul, uint64_t *offset)
{
   if (nir_scalar_is_const(*base)) {
      *offset = nir_scalar_as_uint(*base);
      base->def = NULL;
      return;
   }

   uint64_t mul = 1;
   uint64_t add = 0;
   bool progress = false;
   /* Repeatedly peel imul/ishl/iadd-by-constant layers (and movs) off the
    * expression until it no longer changes. */
   do {
      uint64_t mul2 = 1, add2 = 0;

      progress = parse_alu(base, nir_op_imul, &mul2);
      mul *= mul2;

      /* For a shift, "no progress" must leave mul unchanged, so the default
       * shift amount is 0 (mul <<= 0), not 1. */
      mul2 = 0;
      progress |= parse_alu(base, nir_op_ishl, &mul2);
      mul <<= mul2;

      progress |= parse_alu(base, nir_op_iadd, &add2);
      /* The addend sits inside all multiplications peeled so far, so scale it
       * by the accumulated multiplier. */
      add += add2 * mul;

      if (nir_scalar_is_alu(*base) && nir_scalar_alu_op(*base) == nir_op_mov) {
         *base = nir_scalar_chase_alu_src(*base, 0);
         progress = true;
      }
   } while (progress);

   /* Treat a load_vulkan_descriptor result like a constant base: clear the
    * def so only the parsed constant offset remains. */
   if (base->def->parent_instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(base->def->parent_instr);
      if (intrin->intrinsic == nir_intrinsic_load_vulkan_descriptor)
         base->def = NULL;
   }

   *base_mul = mul;
   *offset = add;
}
323 
/* Byte size of one scalar element of "type"; booleans count as 4 bytes. */
static unsigned
type_scalar_size_bytes(const struct glsl_type *type)
{
   assert(glsl_type_is_vector_or_scalar(type) ||
          glsl_type_is_matrix(type));
   if (glsl_type_is_boolean(type))
      return 4u;
   return glsl_get_bit_size(type) / 8u;
}
331 
/* Insert "def" (scaled by "mul") into the sorted offset_defs/offset_defs_mul
 * arrays, or fold it into an existing identical def by summing the
 * multipliers. Returns the number of array slots consumed (0 or 1). */
static unsigned
add_to_entry_key(nir_scalar *offset_defs, uint64_t *offset_defs_mul,
                 unsigned offset_def_count, nir_scalar def, uint64_t mul)
{
   mul = util_mask_sign_extend(mul, def.def->bit_size);

   /* The arrays are kept sorted by def index (larger indices first); walk to
    * the insertion point, merging on an exact match along the way. */
   for (unsigned i = 0; i <= offset_def_count; i++) {
      if (i == offset_def_count || def.def->index > offset_defs[i].def->index) {
         /* insert before i */
         memmove(offset_defs + i + 1, offset_defs + i,
                 (offset_def_count - i) * sizeof(nir_scalar));
         memmove(offset_defs_mul + i + 1, offset_defs_mul + i,
                 (offset_def_count - i) * sizeof(uint64_t));
         offset_defs[i] = def;
         offset_defs_mul[i] = mul;
         return 1;
      } else if (nir_scalar_equal(def, offset_defs[i])) {
         /* merge with offset_def at i */
         offset_defs_mul[i] += mul;
         return 0;
      }
   }
   unreachable("Unreachable.");
   return 0;
}
357 
/* Build an entry_key by walking a deref path, accumulating the constant part
 * of the offset into "*offset_base" and the variable parts into the key. */
static struct entry_key *
create_entry_key_from_deref(void *mem_ctx,
                            nir_deref_path *path,
                            uint64_t *offset_base)
{
   unsigned path_len = 0;
   while (path->path[path_len])
      path_len++;

   /* Use stack storage for typical paths; fall back to the heap for paths
    * deeper than 32 derefs. */
   nir_scalar offset_defs_stack[32];
   uint64_t offset_defs_mul_stack[32];
   nir_scalar *offset_defs = offset_defs_stack;
   uint64_t *offset_defs_mul = offset_defs_mul_stack;
   if (path_len > 32) {
      offset_defs = malloc(path_len * sizeof(nir_scalar));
      offset_defs_mul = malloc(path_len * sizeof(uint64_t));
   }
   unsigned offset_def_count = 0;

   struct entry_key *key = ralloc(mem_ctx, struct entry_key);
   key->resource = NULL;
   key->var = NULL;
   *offset_base = 0;

   for (unsigned i = 0; i < path_len; i++) {
      nir_deref_instr *parent = i ? path->path[i - 1] : NULL;
      nir_deref_instr *deref = path->path[i];

      switch (deref->deref_type) {
      case nir_deref_type_var: {
         assert(!parent);
         key->var = deref->var;
         break;
      }
      case nir_deref_type_array:
      case nir_deref_type_ptr_as_array: {
         assert(parent);
         nir_def *index = deref->arr.index.ssa;
         uint32_t stride = nir_deref_instr_array_stride(deref);

         /* Split the index into constant and variable parts; the constant
          * part goes into offset_base, the rest becomes a key term. */
         nir_scalar base = { .def = index, .comp = 0 };
         uint64_t offset = 0, base_mul = 1;
         parse_offset(&base, &base_mul, &offset);
         offset = util_mask_sign_extend(offset, index->bit_size);

         *offset_base += offset * stride;
         if (base.def) {
            offset_def_count += add_to_entry_key(offset_defs, offset_defs_mul,
                                                 offset_def_count,
                                                 base, base_mul * stride);
         }
         break;
      }
      case nir_deref_type_struct: {
         assert(parent);
         int offset = glsl_get_struct_field_offset(parent->type, deref->strct.index);
         *offset_base += offset;
         break;
      }
      case nir_deref_type_cast: {
         /* A cast at the root of the path supplies the resource pointer. */
         if (!parent)
            key->resource = deref->parent.ssa;
         break;
      }
      default:
         unreachable("Unhandled deref type");
      }
   }

   /* Copy the collected terms into ralloc'ed storage owned by mem_ctx. */
   key->offset_def_count = offset_def_count;
   key->offset_defs = ralloc_array(mem_ctx, nir_scalar, offset_def_count);
   key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, offset_def_count);
   memcpy(key->offset_defs, offset_defs, offset_def_count * sizeof(nir_scalar));
   memcpy(key->offset_defs_mul, offset_defs_mul, offset_def_count * sizeof(uint64_t));

   if (offset_defs != offset_defs_stack)
      free(offset_defs);
   if (offset_defs_mul != offset_defs_mul_stack)
      free(offset_defs_mul);

   return key;
}
440 
/* Recursively parse an offset expression rooted at "base", splitting iadd
 * chains into separate key terms. "size" is the number of terms already in
 * "key", "left" the number of free slots remaining; the constant part is
 * accumulated into "*offset". Returns the number of terms added. */
static unsigned
parse_entry_key_from_offset(struct entry_key *key, unsigned size, unsigned left,
                            nir_scalar base, uint64_t base_mul, uint64_t *offset)
{
   uint64_t new_mul;
   uint64_t new_offset;
   parse_offset(&base, &new_mul, &new_offset);
   *offset += new_offset * base_mul;

   /* parse_offset() reduced the whole expression to a constant. */
   if (!base.def)
      return 0;

   base_mul *= new_mul;

   assert(left >= 1);

   if (left >= 2) {
      /* Split "a + b" and parse both operands as independent terms; the
       * second call starts after whatever the first one appended. */
      if (nir_scalar_is_alu(base) && nir_scalar_alu_op(base) == nir_op_iadd) {
         nir_scalar src0 = nir_scalar_chase_alu_src(base, 0);
         nir_scalar src1 = nir_scalar_chase_alu_src(base, 1);
         unsigned amount = parse_entry_key_from_offset(key, size, left - 1, src0, base_mul, offset);
         amount += parse_entry_key_from_offset(key, size + amount, left - amount, src1, base_mul, offset);
         return amount;
      }
   }

   return add_to_entry_key(key->offset_defs, key->offset_defs_mul, size, base, base_mul);
}
469 
/* Build an entry_key for an offset-based (non-deref) intrinsic. "*offset"
 * accumulates the constant part of the offset expression. */
static struct entry_key *
create_entry_key_from_offset(void *mem_ctx, nir_def *base, uint64_t base_mul, uint64_t *offset)
{
   struct entry_key *key = ralloc(mem_ctx, struct entry_key);
   key->resource = NULL;
   key->var = NULL;
   if (base) {
      nir_scalar offset_defs[32];
      uint64_t offset_defs_mul[32];
      /* Temporarily point the key at these stack arrays so that
       * parse_entry_key_from_offset() can fill them in place. */
      key->offset_defs = offset_defs;
      key->offset_defs_mul = offset_defs_mul;

      nir_scalar scalar = { .def = base, .comp = 0 };
      key->offset_def_count = parse_entry_key_from_offset(key, 0, 32, scalar, base_mul, offset);

      /* Re-home the final terms into ralloc'ed storage that outlives this
       * stack frame. */
      key->offset_defs = ralloc_array(mem_ctx, nir_scalar, key->offset_def_count);
      key->offset_defs_mul = ralloc_array(mem_ctx, uint64_t, key->offset_def_count);
      memcpy(key->offset_defs, offset_defs, key->offset_def_count * sizeof(nir_scalar));
      memcpy(key->offset_defs_mul, offset_defs_mul, key->offset_def_count * sizeof(uint64_t));
   } else {
      key->offset_def_count = 0;
      key->offset_defs = NULL;
      key->offset_defs_mul = NULL;
   }
   return key;
}
496 
497 static nir_variable_mode
get_variable_mode(struct entry * entry)498 get_variable_mode(struct entry *entry)
499 {
500    if (entry->info->mode)
501       return entry->info->mode;
502    assert(entry->deref && util_bitcount(entry->deref->modes) == 1);
503    return entry->deref->modes;
504 }
505 
506 static unsigned
mode_to_index(nir_variable_mode mode)507 mode_to_index(nir_variable_mode mode)
508 {
509    assert(util_bitcount(mode) == 1);
510 
511    /* Globals and SSBOs should be tracked together */
512    if (mode == nir_var_mem_global)
513       mode = nir_var_mem_ssbo;
514 
515    return ffs(mode) - 1;
516 }
517 
518 static nir_variable_mode
aliasing_modes(nir_variable_mode modes)519 aliasing_modes(nir_variable_mode modes)
520 {
521    /* Global and SSBO can alias */
522    if (modes & (nir_var_mem_ssbo | nir_var_mem_global))
523       modes |= nir_var_mem_ssbo | nir_var_mem_global;
524    return modes;
525 }
526 
/* Derive align_mul/align_offset for "entry" from its offset expression,
 * keeping the intrinsic's own ALIGN_MUL/ALIGN_OFFSET when that is stronger. */
static void
calc_alignment(struct entry *entry)
{
   /* "align_mul" here is a 1-based bit position as returned by ffsll();
    * the initial value 31 caps the derived alignment at 1u << 30. */
   uint32_t align_mul = 31;
   for (unsigned i = 0; i < entry->key->offset_def_count; i++) {
      /* The lowest set bit of each multiplier bounds the alignment its
       * variable term can guarantee. */
      if (entry->key->offset_defs_mul[i])
         align_mul = MIN2(align_mul, ffsll(entry->key->offset_defs_mul[i]));
   }

   entry->align_mul = 1u << (align_mul - 1);
   bool has_align = nir_intrinsic_infos[entry->intrin->intrinsic].index_map[NIR_INTRINSIC_ALIGN_MUL];
   if (!has_align || entry->align_mul >= nir_intrinsic_align_mul(entry->intrin)) {
      /* The offset expression gives at least as much alignment information
       * as the intrinsic's indices. */
      entry->align_offset = entry->offset % entry->align_mul;
   } else {
      entry->align_mul = nir_intrinsic_align_mul(entry->intrin);
      entry->align_offset = nir_intrinsic_align_offset(entry->intrin);
   }
}
545 
/* Create the pass's bookkeeping entry for one load/store/atomic intrinsic:
 * build its canonical key, constant offset, access flags and alignment. */
static struct entry *
create_entry(void *mem_ctx,
             const struct intrinsic_info *info,
             nir_intrinsic_instr *intrin)
{
   struct entry *entry = rzalloc(mem_ctx, struct entry);
   entry->intrin = intrin;
   entry->instr = &intrin->instr;
   entry->info = info;
   entry->is_store = entry->info->value_src >= 0;

   if (entry->info->deref_src >= 0) {
      /* Deref-based access: derive key and offset from the deref path. */
      entry->deref = nir_src_as_deref(intrin->src[entry->info->deref_src]);
      nir_deref_path path;
      nir_deref_path_init(&path, entry->deref, NULL);
      entry->key = create_entry_key_from_deref(entry, &path, &entry->offset);
      nir_deref_path_finish(&path);
   } else {
      /* Offset-based access: derive them from the base source (if any) plus
       * the BASE index, scaled to bytes. */
      nir_def *base = entry->info->base_src >= 0 ? intrin->src[entry->info->base_src].ssa : NULL;
      uint64_t offset = 0;
      if (nir_intrinsic_has_base(intrin))
         offset += nir_intrinsic_base(intrin) * info->offset_scale;
      entry->key = create_entry_key_from_offset(entry, base, info->offset_scale, &offset);
      entry->offset = offset;

      if (base)
         entry->offset = util_mask_sign_extend(entry->offset, base->bit_size);
   }

   if (entry->info->resource_src >= 0)
      entry->key->resource = intrin->src[entry->info->resource_src].ssa;

   if (nir_intrinsic_has_access(intrin))
      entry->access = nir_intrinsic_access(intrin);
   else if (entry->key->var)
      entry->access = entry->key->var->data.access;

   if (nir_intrinsic_can_reorder(intrin))
      entry->access |= ACCESS_CAN_REORDER;

   /* Accesses in these modes cannot alias through other pointers, so they
    * behave as if declared "restrict". */
   uint32_t restrict_modes = nir_var_shader_in | nir_var_shader_out;
   restrict_modes |= nir_var_shader_temp | nir_var_function_temp;
   restrict_modes |= nir_var_uniform | nir_var_mem_push_const;
   restrict_modes |= nir_var_system_value | nir_var_mem_shared;
   restrict_modes |= nir_var_mem_task_payload;
   if (get_variable_mode(entry) & restrict_modes)
      entry->access |= ACCESS_RESTRICT;

   calc_alignment(entry);

   return entry;
}
598 
/* Return a deref whose type is a vector of "num_components" unsigned scalars
 * of "bit_size" bits, casting "deref" if its current type doesn't match. */
static nir_deref_instr *
cast_deref(nir_builder *b, unsigned num_components, unsigned bit_size, nir_deref_instr *deref)
{
   if (glsl_get_components(deref->type) == num_components &&
       type_scalar_size_bytes(deref->type) * 8u == bit_size)
      return deref;

   /* Select the uint type matching bit_size: index 0..3 covers 8/16/32/64
    * bits via the position of the lowest set bit of the byte size. */
   enum glsl_base_type types[] = {
      GLSL_TYPE_UINT8, GLSL_TYPE_UINT16, GLSL_TYPE_UINT, GLSL_TYPE_UINT64
   };
   enum glsl_base_type base = types[ffs(bit_size / 8u) - 1u];
   const struct glsl_type *type = glsl_vector_type(base, num_components);

   if (deref->type == type)
      return deref;

   return nir_build_deref_cast(b, &deref->def, deref->modes, type, 0);
}
617 
/* Return true if "new_bit_size" is a usable bit size for a vectorized load/store
 * of "low" and "high". "size" is the total size of the combined access in bits. */
static bool
new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
                       struct entry *low, struct entry *high, unsigned size)
{
   if (size % new_bit_size != 0)
      return false;

   unsigned new_num_components = size / new_bit_size;
   if (!nir_num_components_valid(new_num_components))
      return false;

   unsigned high_offset = high->offset_signed - low->offset_signed;

   /* check nir_extract_bits limitations: the data must be re-expressible in
    * components of the largest bit size that divides both accesses and the
    * byte distance between them */
   unsigned common_bit_size = MIN2(get_bit_size(low), get_bit_size(high));
   common_bit_size = MIN2(common_bit_size, new_bit_size);
   if (high_offset > 0)
      common_bit_size = MIN2(common_bit_size, (1u << (ffs(high_offset * 8) - 1)));
   if (new_bit_size / common_bit_size > NIR_MAX_VEC_COMPONENTS)
      return false;

   /* let the backend veto the combination (alignment, format support, ...) */
   if (!ctx->options->callback(low->align_mul,
                               low->align_offset,
                               new_bit_size, new_num_components,
                               low->intrin, high->intrin,
                               ctx->options->cb_data))
      return false;

   if (low->is_store) {
      unsigned low_size = low->intrin->num_components * get_bit_size(low);
      unsigned high_size = high->intrin->num_components * get_bit_size(high);

      if (low_size % new_bit_size != 0)
         return false;
      if (high_size % new_bit_size != 0)
         return false;

      /* both write masks must be expressible at the new component size */
      unsigned write_mask = get_write_mask(low->intrin);
      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(low), new_bit_size))
         return false;

      write_mask = get_write_mask(high->intrin);
      if (!nir_component_mask_can_reinterpret(write_mask, get_bit_size(high), new_bit_size))
         return false;
   }

   return true;
}
668 
/* Build a deref addressing "offset" bytes before "deref". */
static nir_deref_instr *
subtract_deref(nir_builder *b, nir_deref_instr *deref, int64_t offset)
{
   /* avoid adding another deref to the path */
   if (deref->deref_type == nir_deref_type_ptr_as_array &&
       nir_src_is_const(deref->arr.index) &&
       offset % nir_deref_instr_array_stride(deref) == 0) {
      unsigned stride = nir_deref_instr_array_stride(deref);
      nir_def *index = nir_imm_intN_t(b, nir_src_as_int(deref->arr.index) - offset / stride,
                                      deref->def.bit_size);
      return nir_build_deref_ptr_as_array(b, nir_deref_instr_parent(deref), index);
   }

   if (deref->deref_type == nir_deref_type_array &&
       nir_src_is_const(deref->arr.index)) {
      nir_deref_instr *parent = nir_deref_instr_parent(deref);
      /* the pass requires explicitly laid out types, so the parent array is
       * expected to have an explicit (non-zero) stride */
      unsigned stride = glsl_get_explicit_stride(parent->type);
      if (offset % stride == 0)
         return nir_build_deref_array_imm(
            b, parent, nir_src_as_int(deref->arr.index) - offset / stride);
   }

   /* fallback: cast to a uint8 pointer and index backwards byte-wise */
   deref = nir_build_deref_cast(b, &deref->def, deref->modes,
                                glsl_scalar_type(GLSL_TYPE_UINT8), 1);
   return nir_build_deref_ptr_as_array(
      b, deref, nir_imm_intN_t(b, -offset, deref->def.bit_size));
}
696 
/* Merge two loads into one: "first" (the earlier instruction in program
 * order) is widened to cover both, the original results are re-extracted
 * from it, and "second" is removed. "low"/"high" order the pair by offset;
 * "high_start" is the bit offset of "high"'s data within the combined value. */
static void
vectorize_loads(nir_builder *b, struct vectorize_ctx *ctx,
                struct entry *low, struct entry *high,
                struct entry *first, struct entry *second,
                unsigned new_bit_size, unsigned new_num_components,
                unsigned high_start)
{
   unsigned low_bit_size = get_bit_size(low);
   unsigned high_bit_size = get_bit_size(high);
   bool low_bool = low->intrin->def.bit_size == 1;
   bool high_bool = high->intrin->def.bit_size == 1;
   nir_def *data = &first->intrin->def;

   b->cursor = nir_after_instr(first->instr);

   /* update the load's destination size and extract data for each of the original loads */
   data->num_components = new_num_components;
   data->bit_size = new_bit_size;

   nir_def *low_def = nir_extract_bits(
      b, &data, 1, 0, low->intrin->num_components, low_bit_size);
   nir_def *high_def = nir_extract_bits(
      b, &data, 1, high_start, high->intrin->num_components, high_bit_size);

   /* convert booleans */
   low_def = low_bool ? nir_i2b(b, low_def) : nir_mov(b, low_def);
   high_def = high_bool ? nir_i2b(b, high_def) : nir_mov(b, high_def);

   /* update uses; the kept load's own def must only be rewritten for uses
    * after the extraction code just emitted, or the extracts would consume
    * their own result */
   if (first == low) {
      nir_def_rewrite_uses_after(&low->intrin->def, low_def,
                                 high_def->parent_instr);
      nir_def_rewrite_uses(&high->intrin->def, high_def);
   } else {
      nir_def_rewrite_uses(&low->intrin->def, low_def);
      nir_def_rewrite_uses_after(&high->intrin->def, high_def,
                                 high_def->parent_instr);
   }

   /* update the intrinsic */
   first->intrin->num_components = new_num_components;

   const struct intrinsic_info *info = first->info;

   /* update the offset: if the kept load was the higher one, rebase it down
    * to the lower address */
   if (first != low && info->base_src >= 0) {
      /* let nir_opt_algebraic() remove this addition. this doesn't have much
       * issues with subtracting 16 from expressions like "(i + 1) * 16" because
       * nir_opt_algebraic() turns them into "i * 16 + 16" */
      b->cursor = nir_before_instr(first->instr);

      nir_def *new_base = first->intrin->src[info->base_src].ssa;
      new_base = nir_iadd_imm(b, new_base, -(int)(high_start / 8u / first->info->offset_scale));

      nir_src_rewrite(&first->intrin->src[info->base_src], new_base);
   }

   /* update the deref: rebase (if needed) and cast to the new vector type */
   if (info->deref_src >= 0) {
      b->cursor = nir_before_instr(first->instr);

      nir_deref_instr *deref = nir_src_as_deref(first->intrin->src[info->deref_src]);
      if (first != low && high_start != 0)
         deref = subtract_deref(b, deref, high_start / 8u / first->info->offset_scale);
      first->deref = cast_deref(b, new_num_components, new_bit_size, deref);

      nir_src_rewrite(&first->intrin->src[info->deref_src],
                      &first->deref->def);
   }

   /* update the range/base indices to cover both original accesses */
   if (nir_intrinsic_has_range_base(first->intrin)) {
      uint32_t low_base = nir_intrinsic_range_base(low->intrin);
      uint32_t high_base = nir_intrinsic_range_base(high->intrin);
      uint32_t low_end = low_base + nir_intrinsic_range(low->intrin);
      uint32_t high_end = high_base + nir_intrinsic_range(high->intrin);

      nir_intrinsic_set_range_base(first->intrin, low_base);
      nir_intrinsic_set_range(first->intrin, MAX2(low_end, high_end) - low_base);
   } else if (nir_intrinsic_has_base(first->intrin) && info->base_src == -1 && info->deref_src == -1) {
      nir_intrinsic_set_base(first->intrin, nir_intrinsic_base(low->intrin));
   }

   /* the merged entry now describes the lower access */
   first->key = low->key;
   first->offset = low->offset;

   first->align_mul = low->align_mul;
   first->align_offset = low->align_offset;

   nir_instr_remove(second->instr);
}
788 
/* Combine the stores "low"/"high" (the two entries ordered by offset;
 * "first"/"second" are the same two entries ordered by position in the
 * program) into one vectorized store of "new_num_components" components of
 * "new_bit_size" bits each. "high_start" is high's offset relative to low,
 * in bits.
 *
 * The merged store replaces "second" (the later instruction), so that the
 * combined write happens where the last of the two stores was; "first" is
 * removed from both the IR and the entry list.
 */
static void
vectorize_stores(nir_builder *b, struct vectorize_ctx *ctx,
                 struct entry *low, struct entry *high,
                 struct entry *first, struct entry *second,
                 unsigned new_bit_size, unsigned new_num_components,
                 unsigned high_start)
{
   ASSERTED unsigned low_size = low->intrin->num_components * get_bit_size(low);
   assert(low_size % new_bit_size == 0);

   b->cursor = nir_before_instr(second->instr);

   /* get new writemasks, re-expressed in units of new_bit_size channels */
   uint32_t low_write_mask = get_write_mask(low->intrin);
   uint32_t high_write_mask = get_write_mask(high->intrin);
   low_write_mask = nir_component_mask_reinterpret(low_write_mask,
                                                   get_bit_size(low),
                                                   new_bit_size);
   high_write_mask = nir_component_mask_reinterpret(high_write_mask,
                                                    get_bit_size(high),
                                                    new_bit_size);
   high_write_mask <<= high_start / new_bit_size;

   uint32_t write_mask = low_write_mask | high_write_mask;

   /* convert booleans to 32-bit integers so the data can be bit-extracted */
   nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
   nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
   low_val = low_val->bit_size == 1 ? nir_b2iN(b, low_val, 32) : low_val;
   high_val = high_val->bit_size == 1 ? nir_b2iN(b, high_val, 32) : high_val;

   /* combine the data channel-by-channel; where both stores write the same
    * channel, the later store ("second") wins, preserving program order */
   nir_def *data_channels[NIR_MAX_VEC_COMPONENTS];
   for (unsigned i = 0; i < new_num_components; i++) {
      bool set_low = low_write_mask & (1 << i);
      bool set_high = high_write_mask & (1 << i);

      if (set_low && (!set_high || low == second)) {
         unsigned offset = i * new_bit_size;
         data_channels[i] = nir_extract_bits(b, &low_val, 1, offset, 1, new_bit_size);
      } else if (set_high) {
         assert(!set_low || high == second);
         unsigned offset = i * new_bit_size - high_start;
         data_channels[i] = nir_extract_bits(b, &high_val, 1, offset, 1, new_bit_size);
      } else {
         /* channel written by neither store */
         data_channels[i] = nir_undef(b, 1, new_bit_size);
      }
   }
   nir_def *data = nir_vec(b, data_channels, new_num_components);

   /* update the intrinsic */
   if (nir_intrinsic_has_write_mask(second->intrin))
      nir_intrinsic_set_write_mask(second->intrin, write_mask);
   second->intrin->num_components = data->num_components;

   const struct intrinsic_info *info = second->info;
   assert(info->value_src >= 0);
   nir_src_rewrite(&second->intrin->src[info->value_src], data);

   /* update the offset: the merged store addresses from low's base */
   if (second != low && info->base_src >= 0)
      nir_src_rewrite(&second->intrin->src[info->base_src],
                      low->intrin->src[info->base_src].ssa);

   /* update the deref to a cast matching the new vector type */
   if (info->deref_src >= 0) {
      b->cursor = nir_before_instr(second->instr);
      second->deref = cast_deref(b, new_num_components, new_bit_size,
                                 nir_src_as_deref(low->intrin->src[info->deref_src]));
      nir_src_rewrite(&second->intrin->src[info->deref_src],
                      &second->deref->def);
   }

   /* update base/align */
   if (second != low && nir_intrinsic_has_base(second->intrin))
      nir_intrinsic_set_base(second->intrin, nir_intrinsic_base(low->intrin));

   /* the surviving entry now represents the combined access at low's offset */
   second->key = low->key;
   second->offset = low->offset;

   second->align_mul = low->align_mul;
   second->align_offset = low->align_offset;

   list_del(&first->head);
   nir_instr_remove(first->instr);
}
875 
/* Returns true if it can prove that "a" and "b" point to different bindings
 * and either one uses ACCESS_RESTRICT. */
static bool
bindings_different_restrict(nir_shader *shader, struct entry *a, struct entry *b)
{
   bool different_bindings = false;
   nir_variable *a_var = NULL, *b_var = NULL;
   if (a->key->resource && b->key->resource) {
      /* both accesses go through descriptor resources: compare the chased
       * bindings (set/binding plus any constant array indices) */
      nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
      nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
      if (!a_res.success || !b_res.success)
         return false;

      if (a_res.num_indices != b_res.num_indices ||
          a_res.desc_set != b_res.desc_set ||
          a_res.binding != b_res.binding)
         different_bindings = true;

      /* non-constant indices prove nothing; only differing constants do */
      for (unsigned i = 0; i < a_res.num_indices; i++) {
         if (nir_src_is_const(a_res.indices[i]) && nir_src_is_const(b_res.indices[i]) &&
             nir_src_as_uint(a_res.indices[i]) != nir_src_as_uint(b_res.indices[i]))
            different_bindings = true;
      }

      if (different_bindings) {
         a_var = nir_get_binding_variable(shader, a_res);
         b_var = nir_get_binding_variable(shader, b_res);
      }
   } else if (a->key->var && b->key->var) {
      /* variable-based accesses: distinct variables are distinct bindings */
      a_var = a->key->var;
      b_var = b->key->var;
      different_bindings = a_var != b_var;
   } else if (!!a->key->resource != !!b->key->resource) {
      /* comparing global and ssbo access */
      different_bindings = true;

      if (a->key->resource) {
         nir_binding a_res = nir_chase_binding(nir_src_for_ssa(a->key->resource));
         a_var = nir_get_binding_variable(shader, a_res);
      }

      if (b->key->resource) {
         nir_binding b_res = nir_chase_binding(nir_src_for_ssa(b->key->resource));
         b_var = nir_get_binding_variable(shader, b_res);
      }
   } else {
      /* nothing comparable to chase; can't prove anything */
      return false;
   }

   /* combine per-instruction access flags with the variable's declared ones */
   unsigned a_access = a->access | (a_var ? a_var->data.access : 0);
   unsigned b_access = b->access | (b_var ? b_var->data.access : 0);

   return different_bindings &&
          ((a_access | b_access) & ACCESS_RESTRICT);
}
931 
932 static int64_t
compare_entries(struct entry * a,struct entry * b)933 compare_entries(struct entry *a, struct entry *b)
934 {
935    if (!entry_key_equals(a->key, b->key))
936       return INT64_MAX;
937    return b->offset_signed - a->offset_signed;
938 }
939 
/* Conservatively determine whether the accesses "a" and "b" may touch
 * overlapping memory. Returns false only when non-aliasing can be proven. */
static bool
may_alias(nir_shader *shader, struct entry *a, struct entry *b)
{
   assert(mode_to_index(get_variable_mode(a)) ==
          mode_to_index(get_variable_mode(b)));

   /* reorderable accesses never constrain each other */
   if ((a->access | b->access) & ACCESS_CAN_REORDER)
      return false;

   /* if the resources/variables are definitively different and both have
    * ACCESS_RESTRICT, we can assume they do not alias. */
   if (bindings_different_restrict(shader, a, b))
      return false;

   /* we can't compare offsets if the resources/variables might be different */
   if (a->key->var != b->key->var || a->key->resource != b->key->resource)
      return true;

   /* use adjacency information */
   /* TODO: we can look closer at the entry keys */
   int64_t diff = compare_entries(a, b);
   if (diff != INT64_MAX) {
      /* with atomics, intrin->num_components can be 0 */
      if (diff < 0)
         return llabs(diff) < MAX2(b->intrin->num_components, 1u) * (get_bit_size(b) / 8u);
      else
         return diff < MAX2(a->intrin->num_components, 1u) * (get_bit_size(a) / 8u);
   }

   /* TODO: we can use deref information */

   return true;
}
973 
/* Returns true if an aliasing access sits between "first" and "second" in
 * the per-mode entry list, which would make combining them unsafe. Read-only
 * modes can never alias with stores and always return false. */
static bool
check_for_aliasing(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
{
   nir_variable_mode mode = get_variable_mode(first);
   if (mode & (nir_var_uniform | nir_var_system_value |
               nir_var_mem_push_const | nir_var_mem_ubo))
      return false;

   unsigned mode_index = mode_to_index(mode);
   if (first->is_store) {
      /* find first entry that aliases "first" */
      list_for_each_entry_from(struct entry, next, first, &ctx->entries[mode_index], head) {
         if (next == first)
            continue;
         if (next == second)
            return false;
         if (may_alias(ctx->shader, first, next))
            return true;
      }
   } else {
      /* find previous store that aliases this load */
      list_for_each_entry_from_rev(struct entry, prev, second, &ctx->entries[mode_index], head) {
         if (prev == second)
            continue;
         if (prev == first)
            return false;
         if (prev->is_store && may_alias(ctx->shader, second, prev))
            return true;
      }
   }

   return false;
}
1007 
/* Greatest common divisor of "a" and "b", via Euclid's algorithm.
 * calc_gcd(a, 0) == a and calc_gcd(0, b) == b.
 */
static uint64_t
calc_gcd(uint64_t a, uint64_t b)
{
   while (b != 0) {
      /* the temporary must be 64-bit: a plain "int" here silently truncated
       * operands larger than INT_MAX and could yield a wrong GCD */
      uint64_t rem = a % b;
      a = b;
      b = rem;
   }
   return a;
}
1018 
/* Round "a" down to the nearest multiple of "b" (b must be non-zero). */
static uint64_t
round_down(uint64_t a, uint64_t b)
{
   return a - (a % b);
}
1024 
/* Whether computing (a + b) in "bits"-wide unsigned arithmetic wraps. */
static bool
addition_wraps(uint64_t a, uint64_t b, unsigned bits)
{
   const uint64_t mask = BITFIELD64_MASK(bits);
   uint64_t truncated_sum = (a + b) & mask;
   return truncated_sum < (a & mask);
}
1031 
/* Return true if the addition of "low"'s offset and "high_offset" could wrap
 * around.
 *
 * This is to prevent a situation where the hardware considers the high load
 * out-of-bounds after vectorization if the low load is out-of-bounds, even if
 * the wrap-around from the addition could make the high load in-bounds.
 */
static bool
check_for_robustness(struct vectorize_ctx *ctx, struct entry *low, uint64_t high_offset)
{
   /* only modes the driver declared as robust need this check */
   nir_variable_mode mode = get_variable_mode(low);
   if (!(mode & ctx->options->robust_modes))
      return false;

   unsigned scale = low->info->offset_scale;

   /* First, try to use alignment information in case the application provided some. If the addition
    * of the maximum offset of the low load and "high_offset" wraps around, we can't combine the low
    * and high loads.
    */
   uint64_t max_low = round_down(UINT64_MAX, low->align_mul) + low->align_offset;
   if (!addition_wraps(max_low / scale, high_offset / scale, 64))
      return false;

   /* We can't obtain addition_bits */
   if (low->info->base_src < 0)
      return true;

   /* Second, use information about the factors from address calculation (offset_defs_mul). These
    * are not guaranteed to be power-of-2.
    */
   uint64_t stride = 0;
   for (unsigned i = 0; i < low->key->offset_def_count; i++)
      stride = calc_gcd(low->key->offset_defs_mul[i], stride);

   unsigned addition_bits = low->intrin->src[low->info->base_src].ssa->bit_size;
   /* low's offset must be a multiple of "stride" plus "low->offset". */
   max_low = low->offset;
   if (stride)
      max_low = round_down(BITFIELD64_MASK(addition_bits), stride) + (low->offset % stride);
   return addition_wraps(max_low / scale, high_offset / scale, addition_bits);
}
1074 
/* Whether "type" is a vector with an explicit stride that differs from its
 * natural (tightly packed) scalar size, e.g. a row-major matrix column. */
static bool
is_strided_vector(const struct glsl_type *type)
{
   if (!glsl_type_is_vector(type))
      return false;

   unsigned stride = glsl_get_explicit_stride(type);
   if (stride == 0)
      return false;

   return stride != type_scalar_size_bytes(glsl_get_array_element(type));
}
1086 
1087 static bool
can_vectorize(struct vectorize_ctx * ctx,struct entry * first,struct entry * second)1088 can_vectorize(struct vectorize_ctx *ctx, struct entry *first, struct entry *second)
1089 {
1090    if ((first->access | second->access) & ACCESS_KEEP_SCALAR)
1091       return false;
1092 
1093    if (!(get_variable_mode(first) & ctx->options->modes) ||
1094        !(get_variable_mode(second) & ctx->options->modes))
1095       return false;
1096 
1097    if (check_for_aliasing(ctx, first, second))
1098       return false;
1099 
1100    /* we can only vectorize non-volatile loads/stores of the same type and with
1101     * the same access */
1102    if (first->info != second->info || first->access != second->access ||
1103        (first->access & ACCESS_VOLATILE) || first->info->is_atomic)
1104       return false;
1105 
1106    return true;
1107 }
1108 
/* Try to merge "low"/"high" (ordered by offset; "first"/"second" are the same
 * entries in program order) into a single vectorized access. Picks a bit size
 * that covers both accesses, then rewrites the IR via vectorize_loads() or
 * vectorize_stores(). Returns true on success. */
static bool
try_vectorize(nir_function_impl *impl, struct vectorize_ctx *ctx,
              struct entry *low, struct entry *high,
              struct entry *first, struct entry *second)
{
   if (!can_vectorize(ctx, first, second))
      return false;

   uint64_t diff = high->offset_signed - low->offset_signed;
   if (check_for_robustness(ctx, low, diff))
      return false;

   /* don't attempt to vectorize accesses of row-major matrix columns */
   if (first->deref) {
      const struct glsl_type *first_type = first->deref->type;
      const struct glsl_type *second_type = second->deref->type;
      if (is_strided_vector(first_type) || is_strided_vector(second_type))
         return false;
   }

   /* gather information: sizes are in bits, diff in bytes */
   unsigned low_bit_size = get_bit_size(low);
   unsigned high_bit_size = get_bit_size(high);
   unsigned low_size = low->intrin->num_components * low_bit_size;
   unsigned high_size = high->intrin->num_components * high_bit_size;
   unsigned new_size = MAX2(diff * 8u + high_size, low_size);

   /* find a good bit size for the new load/store: prefer the existing bit
    * sizes, then fall back to trying 64 down to 8 */
   unsigned new_bit_size = 0;
   if (new_bitsize_acceptable(ctx, low_bit_size, low, high, new_size)) {
      new_bit_size = low_bit_size;
   } else if (low_bit_size != high_bit_size &&
              new_bitsize_acceptable(ctx, high_bit_size, low, high, new_size)) {
      new_bit_size = high_bit_size;
   } else {
      new_bit_size = 64;
      for (; new_bit_size >= 8; new_bit_size /= 2) {
         /* don't repeat trying out bitsizes */
         if (new_bit_size == low_bit_size || new_bit_size == high_bit_size)
            continue;
         if (new_bitsize_acceptable(ctx, new_bit_size, low, high, new_size))
            break;
      }
      if (new_bit_size < 8)
         return false;
   }
   unsigned new_num_components = new_size / new_bit_size;

   /* vectorize the loads/stores */
   nir_builder b = nir_builder_create(impl);

   if (first->is_store)
      vectorize_stores(&b, ctx, low, high, first, second,
                       new_bit_size, new_num_components, diff * 8u);
   else
      vectorize_loads(&b, ctx, low, high, first, second,
                      new_bit_size, new_num_components, diff * 8u);

   return true;
}
1169 
/* Try to combine two non-adjacent shared-memory accesses into a single AMD
 * load_shared2_amd/store_shared2_amd, which accesses two 4- or 8-byte slots
 * at "offset + offset0/offset1 * stride" (stride is 4/8, or 64x that with
 * st64). Returns true on success. */
static bool
try_vectorize_shared2(struct vectorize_ctx *ctx,
                      struct entry *low, struct entry *high,
                      struct entry *first, struct entry *second)
{
   if (!can_vectorize(ctx, first, second) || first->deref)
      return false;

   /* both accesses must be whole, aligned 4- or 8-byte slots of equal size */
   unsigned low_bit_size = get_bit_size(low);
   unsigned high_bit_size = get_bit_size(high);
   unsigned low_size = low->intrin->num_components * low_bit_size / 8;
   unsigned high_size = high->intrin->num_components * high_bit_size / 8;
   if ((low_size != 4 && low_size != 8) || (high_size != 4 && high_size != 8))
      return false;
   if (low_size != high_size)
      return false;
   if (low->align_mul % low_size || low->align_offset % low_size)
      return false;
   if (high->align_mul % low_size || high->align_offset % low_size)
      return false;

   /* the second slot's offset is encoded as an 8-bit multiple of the stride */
   uint64_t diff = high->offset_signed - low->offset_signed;
   bool st64 = diff % (64 * low_size) == 0;
   unsigned stride = st64 ? 64 * low_size : low_size;
   if (diff % stride || diff > 255 * stride)
      return false;

   /* try to avoid creating accesses we can't combine additions/offsets into */
   if (high->offset > 255 * stride || (st64 && high->offset % stride))
      return false;

   if (first->is_store) {
      /* partial writes can't be expressed by the combined intrinsic */
      if (get_write_mask(low->intrin) != BITFIELD_MASK(low->intrin->num_components))
         return false;
      if (get_write_mask(high->intrin) != BITFIELD_MASK(high->intrin->num_components))
         return false;
   }

   /* vectorize the accesses: stores are emitted at the later instruction,
    * loads at the earlier one */
   nir_builder b = nir_builder_at(nir_after_instr(first->is_store ? second->instr : first->instr));

   nir_def *offset = first->intrin->src[first->is_store].ssa;
   offset = nir_iadd_imm(&b, offset, nir_intrinsic_base(first->intrin));
   if (first != low)
      offset = nir_iadd_imm(&b, offset, -(int)diff);

   if (first->is_store) {
      nir_def *low_val = low->intrin->src[low->info->value_src].ssa;
      nir_def *high_val = high->intrin->src[high->info->value_src].ssa;
      nir_def *val = nir_vec2(&b, nir_bitcast_vector(&b, low_val, low_size * 8u),
                              nir_bitcast_vector(&b, high_val, low_size * 8u));
      nir_store_shared2_amd(&b, val, offset, .offset1 = diff / stride, .st64 = st64);
   } else {
      nir_def *new_def = nir_load_shared2_amd(&b, low_size * 8u, offset, .offset1 = diff / stride,
                                              .st64 = st64);
      nir_def_rewrite_uses(&low->intrin->def,
                           nir_bitcast_vector(&b, nir_channel(&b, new_def, 0), low_bit_size));
      nir_def_rewrite_uses(&high->intrin->def,
                           nir_bitcast_vector(&b, nir_channel(&b, new_def, 1), high_bit_size));
   }

   nir_instr_remove(first->instr);
   nir_instr_remove(second->instr);

   return true;
}
1236 
1237 static bool
update_align(struct entry * entry)1238 update_align(struct entry *entry)
1239 {
1240    if (nir_intrinsic_has_align_mul(entry->intrin) &&
1241        (entry->align_mul != nir_intrinsic_align_mul(entry->intrin) ||
1242         entry->align_offset != nir_intrinsic_align_offset(entry->intrin))) {
1243       nir_intrinsic_set_align(entry->intrin, entry->align_mul, entry->align_offset);
1244       return true;
1245    }
1246    return false;
1247 }
1248 
/* One pass over an offset-sorted array of same-key entries, greedily trying
 * to combine each entry with the ones that follow it. Combined-away slots are
 * set to NULL. Returns true if any pair was merged (the caller loops until a
 * fixed point). */
static bool
vectorize_sorted_entries(struct vectorize_ctx *ctx, nir_function_impl *impl,
                         struct util_dynarray *arr)
{
   unsigned num_entries = util_dynarray_num_elements(arr, struct entry *);

   bool progress = false;
   for (unsigned first_idx = 0; first_idx < num_entries; first_idx++) {
      struct entry *low = *util_dynarray_element(arr, struct entry *, first_idx);
      if (!low)
         continue;

      for (unsigned second_idx = first_idx + 1; second_idx < num_entries; second_idx++) {
         struct entry *high = *util_dynarray_element(arr, struct entry *, second_idx);
         if (!high)
            continue;

         /* low/high order by offset, first/second by program order */
         struct entry *first = low->index < high->index ? low : high;
         struct entry *second = low->index < high->index ? high : low;

         uint64_t diff = high->offset_signed - low->offset_signed;
         bool separate = diff > get_bit_size(low) / 8u * low->intrin->num_components;
         if (separate) {
            /* a gap between the accesses: only shared2_amd can bridge it */
            if (!ctx->options->has_shared2_amd ||
                get_variable_mode(first) != nir_var_mem_shared)
               break;

            if (try_vectorize_shared2(ctx, low, high, first, second)) {
               low = NULL;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
               break;
            }
         } else {
            if (try_vectorize(impl, ctx, low, high, first, second)) {
               /* keep the surviving entry: stores keep "second", loads "first" */
               low = low->is_store ? second : first;
               *util_dynarray_element(arr, struct entry *, second_idx) = NULL;
               progress = true;
            }
         }
      }

      *util_dynarray_element(arr, struct entry *, first_idx) = low;
   }

   return progress;
}
1296 
1297 static bool
vectorize_entries(struct vectorize_ctx * ctx,nir_function_impl * impl,struct hash_table * ht)1298 vectorize_entries(struct vectorize_ctx *ctx, nir_function_impl *impl, struct hash_table *ht)
1299 {
1300    if (!ht)
1301       return false;
1302 
1303    bool progress = false;
1304    hash_table_foreach(ht, entry) {
1305       struct util_dynarray *arr = entry->data;
1306       if (!arr->size)
1307          continue;
1308 
1309       qsort(util_dynarray_begin(arr),
1310             util_dynarray_num_elements(arr, struct entry *),
1311             sizeof(struct entry *), &sort_entries);
1312 
1313       while (vectorize_sorted_entries(ctx, impl, arr))
1314          progress = true;
1315 
1316       util_dynarray_foreach(arr, struct entry *, elem) {
1317          if (*elem)
1318             progress |= update_align(*elem);
1319       }
1320    }
1321 
1322    _mesa_hash_table_clear(ht, delete_entry_dynarray);
1323 
1324    return progress;
1325 }
1326 
/* If "instr" is a barrier-like instruction (memory barrier, discard/demote,
 * or a call), flush the pending load/store entries for every affected mode so
 * no access is moved across it. Returns true if the instruction was handled
 * as a barrier (and therefore needs no further processing). */
static bool
handle_barrier(struct vectorize_ctx *ctx, bool *progress, nir_function_impl *impl, nir_instr *instr)
{
   unsigned modes = 0;
   bool acquire = true;
   bool release = true;
   if (instr->type == nir_instr_type_intrinsic) {
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
      switch (intrin->intrinsic) {
      /* prevent speculative loads/stores */
      case nir_intrinsic_terminate_if:
      case nir_intrinsic_terminate:
      case nir_intrinsic_launch_mesh_workgroups:
         modes = nir_var_all;
         break;
      case nir_intrinsic_demote_if:
      case nir_intrinsic_demote:
         /* demote doesn't stop stores from later being observed, so only
          * loads (acquire side) need to stay on their side of it */
         acquire = false;
         modes = nir_var_all;
         break;
      case nir_intrinsic_barrier:
         if (nir_intrinsic_memory_scope(intrin) == SCOPE_NONE)
            break;

         modes = nir_intrinsic_memory_modes(intrin) & (nir_var_mem_ssbo |
                                                       nir_var_mem_shared |
                                                       nir_var_mem_global |
                                                       nir_var_mem_task_payload);
         acquire = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_ACQUIRE;
         release = nir_intrinsic_memory_semantics(intrin) & NIR_MEMORY_RELEASE;
         switch (nir_intrinsic_memory_scope(intrin)) {
         case SCOPE_INVOCATION:
            /* a barrier should never be required for correctness with these scopes */
            modes = 0;
            break;
         default:
            break;
         }
         break;
      default:
         return false;
      }
   } else if (instr->type == nir_instr_type_call) {
      /* a call may access anything */
      modes = nir_var_all;
   } else {
      return false;
   }

   while (modes) {
      unsigned mode_index = u_bit_scan(&modes);
      if ((1 << mode_index) == nir_var_mem_global) {
         /* Global should be rolled in with SSBO */
         assert(list_is_empty(&ctx->entries[mode_index]));
         assert(ctx->loads[mode_index] == NULL);
         assert(ctx->stores[mode_index] == NULL);
         continue;
      }

      /* acquire fences flush pending loads, release fences pending stores */
      if (acquire)
         *progress |= vectorize_entries(ctx, impl, ctx->loads[mode_index]);
      if (release)
         *progress |= vectorize_entries(ctx, impl, ctx->stores[mode_index]);
   }

   return true;
}
1393 
/* Vectorize all loads/stores within a single basic block: collect an entry
 * for every vectorizable intrinsic, group entries by key in per-mode hash
 * tables (flushing at barriers), then combine each group. Returns true on
 * any IR change. */
static bool
process_block(nir_function_impl *impl, struct vectorize_ctx *ctx, nir_block *block)
{
   bool progress = false;

   /* reset per-block state */
   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      list_inithead(&ctx->entries[i]);
      if (ctx->loads[i])
         _mesa_hash_table_clear(ctx->loads[i], delete_entry_dynarray);
      if (ctx->stores[i])
         _mesa_hash_table_clear(ctx->stores[i], delete_entry_dynarray);
   }

   /* create entries */
   unsigned next_index = 0;

   nir_foreach_instr_safe(instr, block) {
      if (handle_barrier(ctx, &progress, impl, instr))
         continue;

      /* gather information */
      if (instr->type != nir_instr_type_intrinsic)
         continue;
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      const struct intrinsic_info *info = get_info(intrin->intrinsic);
      if (!info)
         continue;

      /* deref-based intrinsics carry their mode on the deref */
      nir_variable_mode mode = info->mode;
      if (!mode)
         mode = nir_src_as_deref(intrin->src[info->deref_src])->modes;
      if (!(mode & aliasing_modes(ctx->options->modes)))
         continue;
      unsigned mode_index = mode_to_index(mode);

      /* create entry; index records program order for later first/second
       * disambiguation */
      struct entry *entry = create_entry(ctx, info, intrin);
      entry->index = next_index++;

      list_addtail(&entry->head, &ctx->entries[mode_index]);

      /* add the entry to a hash table, keyed so that only accesses with
       * comparable addressing land in the same group */

      struct hash_table *adj_ht = NULL;
      if (entry->is_store) {
         if (!ctx->stores[mode_index])
            ctx->stores[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->stores[mode_index];
      } else {
         if (!ctx->loads[mode_index])
            ctx->loads[mode_index] = _mesa_hash_table_create(ctx, &hash_entry_key, &entry_key_equals);
         adj_ht = ctx->loads[mode_index];
      }

      uint32_t key_hash = hash_entry_key(entry->key);
      struct hash_entry *adj_entry = _mesa_hash_table_search_pre_hashed(adj_ht, key_hash, entry->key);
      struct util_dynarray *arr;
      if (adj_entry && adj_entry->data) {
         arr = (struct util_dynarray *)adj_entry->data;
      } else {
         arr = ralloc(ctx, struct util_dynarray);
         util_dynarray_init(arr, arr);
         _mesa_hash_table_insert_pre_hashed(adj_ht, key_hash, entry->key, arr);
      }
      util_dynarray_append(arr, struct entry *, entry);
   }

   /* sort and combine entries */
   for (unsigned i = 0; i < nir_num_variable_modes; i++) {
      progress |= vectorize_entries(ctx, impl, ctx->loads[i]);
      progress |= vectorize_entries(ctx, impl, ctx->stores[i]);
   }

   return progress;
}
1470 
/* Entry point: combine adjacent/intersecting loads and stores in "shader"
 * according to "options" (see the file header for preconditions). Returns
 * true if any instruction was changed. */
bool
nir_opt_load_store_vectorize(nir_shader *shader, const nir_load_store_vectorize_options *options)
{
   bool progress = false;

   struct vectorize_ctx *ctx = rzalloc(NULL, struct vectorize_ctx);
   ctx->shader = shader;
   ctx->options = options;

   /* entry keys compare variables by index, so index them up front */
   nir_shader_index_vars(shader, options->modes);

   nir_foreach_function_impl(impl, shader) {
      if (options->modes & nir_var_function_temp)
         nir_function_impl_index_vars(impl);

      nir_foreach_block(block, impl)
         progress |= process_block(impl, ctx, block);

      nir_metadata_preserve(impl,
                            nir_metadata_control_flow |
                            nir_metadata_live_defs);
   }

   ralloc_free(ctx);
   return progress;
}
1497 
1498 static bool
opt_load_store_update_alignments_callback(struct nir_builder * b,nir_intrinsic_instr * intrin,UNUSED void * s)1499 opt_load_store_update_alignments_callback(struct nir_builder *b,
1500                                           nir_intrinsic_instr *intrin,
1501                                           UNUSED void *s)
1502 {
1503    if (!nir_intrinsic_has_align_mul(intrin))
1504       return false;
1505 
1506    const struct intrinsic_info *info = get_info(intrin->intrinsic);
1507    if (!info)
1508       return false;
1509 
1510    struct entry *entry = create_entry(NULL, info, intrin);
1511    const bool progress = update_align(entry);
1512    ralloc_free(entry);
1513 
1514    return progress;
1515 }
1516 
1517 bool
nir_opt_load_store_update_alignments(nir_shader * shader)1518 nir_opt_load_store_update_alignments(nir_shader *shader)
1519 {
1520    return nir_shader_intrinsics_pass(shader,
1521                                      opt_load_store_update_alignments_callback,
1522                                      nir_metadata_control_flow |
1523                                      nir_metadata_live_defs |
1524                                      nir_metadata_instr_index, NULL);
1525 }
1526