1 /*
2  * Copyright 2023 Intel Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 /**
7  * \file brw_nir_lower_cooperative_matrix.c
8  * Lower cooperative matrix to subgroup operations.
9  *
10  * All supported matrix types are assumed to have either 8 rows or 8
11  * columns. The other dimension of the matrix is typically 8 times the number
12  * of data elements that can be stored in a 32-bit dword. Matrix data is
13  * indexed by a combination of an array element and a subgroup invocation ID.
14  *
15  * Two layouts for matrix data are used. In the first layout,
16  * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
17  * called row-major hereafter. In the other layout,
18  * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
19  * be called column-major hereafter. In cases where a single 32-bit value is
20  * stored in each entry, these layouts are identical.
21  *
22  * The subtle difference arises when multiple values are packed into a single
23  * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
24  * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[1][1] and
25  * m[2][1] (in m[row][column] notation). In row-major, that same shuffle holds
26  * m[0][2] and m[0][3].
27  *
28  * There is an alternate way to think about the matrix layouts. Every matrix
29  * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
30  * matrix) or Sx8T (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
31  * layouts are such that a single 8 dword register holds an entire row of the
32  * matrix.
33  *
34  * Consider a matrix stored starting in register g32. In an A matrix, the
35  * packed dwords of g32 contain only the data for a single row of the
36  * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
37  * of g(32+N).X contain only the data for a single column of the
38  * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
39  *
40  * This leads to some shenanigans in \c lower_cmat_load_store.
41  *
42  * In the common case, A, C, and result matrices are stored row major while B
43  * matrices are stored column major. This arrangement facilitates efficient
44  * dot product operations using DPAS or DP4A instructions.
45  *
46  * Future optimizations are possible when row and column major are
47  * flipped. That is, efficient dot products are also possible when A, C, and
48  * result matrices are column major while B is row major.
49  */
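/*
 * A small worked example (editor's illustration; it assumes the 8x16
 * float16 A-matrix shape and a subgroup size of 16, and simply applies the
 * formulas in get_slice_type_from_desc below):
 *
 *    elements_per_invocation = (8 * 16) / 16 = 8
 *    packing_factor          = 32 / 16       = 2
 *    slice vector length     = 8 / 2         = 4   (a uvec4 per invocation)
 *
 * Each invocation holds 8 float16 values packed two per dword, so the whole
 * matrix occupies 16 invocations * 4 dwords = 8 registers of 32 bytes each.
 */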
50 
51 #include "brw_nir.h"
52 
53 struct lower_cmat_state {
54    nir_shader *shader;
55 
56    struct hash_table *slice_coop_types;
57 
58    struct hash_table *vars_to_slice;
59 
60    unsigned subgroup_size;
61 };
62 
63 static void
64 print_coop_types(struct lower_cmat_state *state)
65 {
66    fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
67    hash_table_foreach(state->slice_coop_types, e) {
68       nir_variable *var = (void *)e->key;
69       const struct glsl_type *t = e->data;
70       fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
71    }
72    fprintf(stderr, "\n\n");
73 }
74 
75 static const struct glsl_type *
76 get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
77 {
78    nir_variable *var = nir_deref_instr_get_variable(deref);
79    struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var);
80 
81    assert(entry != NULL);
82 
83    return entry->data;
84 }
85 
86 static bool
87 lower_cmat_filter(const nir_instr *instr, const void *_state)
88 {
89    if (instr->type == nir_instr_type_deref) {
90       nir_deref_instr *deref = nir_instr_as_deref(instr);
91       return glsl_type_is_cmat(deref->type);
92    }
93 
94    if (instr->type != nir_instr_type_intrinsic)
95       return false;
96 
97    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
98    switch (intrin->intrinsic) {
99    case nir_intrinsic_cmat_construct:
100    case nir_intrinsic_cmat_load:
101    case nir_intrinsic_cmat_store:
102    case nir_intrinsic_cmat_length:
103    case nir_intrinsic_cmat_muladd:
104    case nir_intrinsic_cmat_unary_op:
105    case nir_intrinsic_cmat_binary_op:
106    case nir_intrinsic_cmat_scalar_op:
107    case nir_intrinsic_cmat_bitcast:
108    case nir_intrinsic_cmat_insert:
109    case nir_intrinsic_cmat_extract:
110    case nir_intrinsic_cmat_copy:
111       return true;
112 
113    default:
114       return false;
115    }
116 }
117 
118 /**
119  * Get number of matrix elements packed in each component of the slice.
120  */
121 static unsigned
122 get_packing_factor(const struct glsl_cmat_description desc,
123                    const struct glsl_type *slice_type)
124 {
125    const struct glsl_type *slice_element_type = glsl_without_array(slice_type);
126 
127    assert(!glsl_type_is_cmat(slice_type));
128 
129    assert(glsl_get_bit_size(slice_element_type) >= glsl_base_type_get_bit_size(desc.element_type));
130    assert(glsl_get_bit_size(slice_element_type) % glsl_base_type_get_bit_size(desc.element_type) == 0);
131 
132    return glsl_get_bit_size(slice_element_type) / glsl_base_type_get_bit_size(desc.element_type);
133 }
134 
135 static const struct glsl_type *
136 get_slice_type_from_desc(const struct lower_cmat_state *state,
137                          const struct glsl_cmat_description desc)
138 {
139    enum glsl_base_type base_type;
140 
141    /* Number of matrix elements stored by each subgroup invocation. If the
142     * data is packed, the slice size will be less than this.
143     */
144    const unsigned elements_per_invocation =
145       (desc.rows * desc.cols) / state->subgroup_size;
146 
147    assert(elements_per_invocation > 0);
148 
149    const unsigned element_bits = 32;
150    const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
151 
152    /* Each invocation must have at least one dword of data, and that dword
153     * must be tightly packed with values. No matter the matrix dimensions, a
154     * matrix of uint8_t data must pack 4 values in each entry.
155     */
156    const unsigned packing_factor = element_bits / bits;
157 
158    assert(elements_per_invocation >= packing_factor);
159 
160    switch (desc.element_type) {
161    case GLSL_TYPE_FLOAT:
162       base_type = GLSL_TYPE_FLOAT;
163       break;
164    case GLSL_TYPE_UINT:
165    case GLSL_TYPE_FLOAT16:
166    case GLSL_TYPE_UINT8:
167    case GLSL_TYPE_UINT16:
168       base_type = GLSL_TYPE_UINT;
169       break;
170    case GLSL_TYPE_INT:
171    case GLSL_TYPE_INT8:
172    case GLSL_TYPE_INT16:
173       base_type = GLSL_TYPE_INT;
174       break;
175    default:
176       unreachable("Invalid cooperative matrix element type.");
177    }
178 
179    unsigned len = elements_per_invocation / packing_factor;
180 
181    /* Supported matrix sizes are designed to fill either 4 or 8 SIMD8
182     * registers on DG2. That means:
183     *
184     *          4 registers   8 registers
185     * SIMD32     len = 1       len = 2
186     * SIMD16     len = 2       len = 4
187     * SIMD8      len = 4       len = 8
188     *
189     * On Xe2, supported matrix sizes are still designed to fill 4 registers
190     * (e.g., 8x32 uint8_t) or 8 registers (e.g., 16x16 float16). However, the
191     * 16x16 float16 matrix will assign 16 elements per channel at SIMD16.
192     */
193    assert(len == 1 || len == 2 || len == 4 || len == 8 || len == 16);
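   /* A couple of concrete cases (editor's examples): an 8x32 uint8 A matrix
    * at subgroup size 32 gives 256 / 32 = 8 elements per invocation, a
    * packing factor of 4, and len = 2; a 16x16 float16 matrix at subgroup
    * size 16 gives 16 elements per invocation, a packing factor of 2, and
    * len = 8, matching the Xe2 case described above.
    */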
194 
195    const struct glsl_type *slice_type = glsl_vector_type(base_type, len);
196 
197    assert(packing_factor == get_packing_factor(desc, slice_type));
198 
199    return slice_type;
200 }
201 
202 static const struct glsl_type *
203 get_slice_type(const struct lower_cmat_state *state,
204                const struct glsl_type *type)
205 {
206    if (glsl_type_is_array(type)) {
207       const struct glsl_type *slice_type =
208          get_slice_type(state, glsl_get_array_element(type));
209 
210       return glsl_array_type(slice_type, glsl_array_size(type), 0);
211    }
212 
213    assert(glsl_type_is_cmat(type));
214 
215    return get_slice_type_from_desc(state,
216                                    *glsl_get_cmat_description(type));
217 }
218 
219 static nir_deref_instr *
220 create_local_slice(struct lower_cmat_state *state, nir_builder *b,
221                    const struct glsl_type *mat_type, const char *name)
222 {
223    const struct glsl_type *slice_type = get_slice_type(state, mat_type);
224    nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name);
225    _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
226    return nir_build_deref_var(b, slice_var);
227 }
228 
229 static void
230 lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
231                       struct lower_cmat_state *state)
232 {
233    const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
234    const unsigned mat_src = load ? 0 : 1;
235    const unsigned ptr_src = load ? 1 : 0;
236 
237    nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
238    const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
239    const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);
240 
241    nir_def *results[NIR_MAX_VEC_COMPONENTS];
242    const unsigned num_components = glsl_get_vector_elements(slice->type);
243    const unsigned packing_factor = get_packing_factor(*desc, slice->type);
244 
245    nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
246    const unsigned ptr_comp_width = glsl_get_bit_size(pointer->type);
247    const unsigned ptr_num_comps = glsl_get_vector_elements(pointer->type);
248 
249    /* The stride is given in elements of the pointed-to type, which doesn't
250     * necessarily match the matrix element type, so adjust it to account for
251     * the pointed-to type possibly being a vector with a different bit width.
252     */
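   /* For example (illustrative numbers only): with a uvec4 pointer
    * (ptr_comp_width = 32, ptr_num_comps = 4) and an int8 matrix element
    * type, a stride of 3 becomes 3 * (32 * 4) / 8 = 48 element-sized units.
    */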
253    nir_def *stride = nir_udiv_imm(b,
254                                   nir_imul_imm(b,
255                                                intrin->src[2].ssa,
256                                                ptr_comp_width * ptr_num_comps),
257                                   glsl_base_type_get_bit_size(desc->element_type));
258 
259    /* The data that will be packed is in successive columns for A and
260     * accumulator matrices. The data that will be packed for B matrices is in
261     * successive rows.
262     */
263    const unsigned cols =
264       desc->use != GLSL_CMAT_USE_B ? desc->cols / packing_factor : desc->cols;
265 
266    nir_def *invocation = nir_load_subgroup_invocation(b);
267    nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
268    nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);
269 
270    nir_def *i_stride;
271 
272    const bool memory_layout_matches_register_layout =
273       (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
274       (desc->use != GLSL_CMAT_USE_B);
275 
276    if (memory_layout_matches_register_layout) {
277       /* In the row-major arrangement, data is loaded a dword at a time
278        * instead of a single element at a time. For this reason the stride is
279        * divided by the packing factor.
280        */
281       i_stride = nir_udiv_imm(b, stride, packing_factor);
282    } else {
283       /* In the column-major arrangement, data is loaded a single element at a
284        * time. Because the data elements are transposed, the step direction
285        * that moves a single (packed) element in the row-major arrangement has
286        * to explicitly step over the packing factor count of elements. For
287        * this reason the stride is multiplied by the packing factor.
288        *
289        * NOTE: The unscaled stride is also still needed when stepping from one
290        * packed element to the next. This occurs in the for-j loop below.
291        */
292       i_stride = nir_imul_imm(b, stride, packing_factor);
293    }
294 
295    nir_def *base_offset;
296    nir_def *i_step;
297 
298    if (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
299       base_offset = nir_iadd(b,
300                              nir_imul(b,
301                                       invocation_div_cols,
302                                       i_stride),
303                              invocation_mod_cols);
304 
305       i_step = nir_imul_imm(b, i_stride, state->subgroup_size / cols);
306    } else {
307       base_offset = nir_iadd(b,
308                              nir_imul(b,
309                                       invocation_mod_cols,
310                                       i_stride),
311                              invocation_div_cols);
312 
313       i_step = nir_imm_int(b, state->subgroup_size / cols);
314    }
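   /* Editor's illustration of the indexing above, assuming (purely for the
    * sake of example) a row-major 8x8 float32 accumulator and a subgroup
    * size of 8: packing_factor = 1, cols = 8, i_stride = stride, so
    * base_offset = invocation and i_step = stride. Slice component i of
    * invocation k then addresses element k + i * stride, i.e. row i,
    * column k of the matrix, matching the row-major register layout
    * described in the file header.
    */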
315 
316    if (memory_layout_matches_register_layout) {
317       const struct glsl_type *element_type =
318          glsl_scalar_type(glsl_get_base_type(slice->type));
319 
320       pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
321                                      element_type,
322                                      glsl_get_bit_size(element_type) / 8);
323 
324       for (unsigned i = 0; i < num_components; i++) {
325          nir_def *offset = nir_imul_imm(b, i_step, i);
326 
327          nir_deref_instr *memory_deref =
328             nir_build_deref_ptr_as_array(b, pointer,
329                                          nir_i2iN(b,
330                                                   nir_iadd(b,
331                                                            base_offset,
332                                                            offset),
333                                                   pointer->def.bit_size));
334 
335          if (load) {
336             results[i] = nir_load_deref(b, memory_deref);
337          } else {
338             nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
339             nir_store_deref(b, memory_deref, src, 0x1);
340          }
341       }
342    } else {
343       const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
344       const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
345       const unsigned element_stride = element_bits / 8;
346 
347       pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
348                                      element_stride);
349 
350       for (unsigned i = 0; i < num_components; i++) {
351          nir_def *i_offset = nir_imul_imm(b, i_step, i);
352          nir_def *v[4];
353 
354          for (unsigned j = 0; j < packing_factor; j++) {
355             nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset);
356 
357             nir_deref_instr *memory_deref =
358                nir_build_deref_ptr_as_array(b, pointer,
359                                             nir_i2iN(b,
360                                                      nir_iadd(b,
361                                                               base_offset,
362                                                               offset),
363                                                      pointer->def.bit_size));
364 
365             if (load) {
366                v[j] = nir_load_deref(b, memory_deref);
367             } else {
368                nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
369 
370                nir_def *v =
371                   nir_channel(b, nir_unpack_bits(b, src, element_bits), j);
372 
373                nir_store_deref(b, memory_deref, v, 0x1);
374             }
375          }
376 
377          if (load) {
378             results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
379                                        packing_factor * element_bits);
380          }
381       }
382    }
383 
384    if (load)
385       nir_store_deref(b, slice, nir_vec(b, results, num_components),
386                       nir_component_mask(num_components));
387 }
388 
389 static void
390 lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
391                     struct lower_cmat_state *state)
392 {
393    nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
394    nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
395    nir_def *results[NIR_MAX_VEC_COMPONENTS];
396    const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
397 
398    const struct glsl_type *dst_mat_type =
399       get_coop_type_for_slice(state, dst_slice);
400    const struct glsl_type *src_mat_type =
401       get_coop_type_for_slice(state, src_slice);
402 
403    const struct glsl_cmat_description dst_desc =
404       *glsl_get_cmat_description(dst_mat_type);
405 
406    const struct glsl_cmat_description src_desc =
407       *glsl_get_cmat_description(src_mat_type);
408 
409    const unsigned dst_bits = glsl_base_type_bit_size(dst_desc.element_type);
410    const unsigned src_bits = glsl_base_type_bit_size(src_desc.element_type);
411 
412    /* The type of the returned slice may be different from the type of the
413     * input slice.
414     */
415    const unsigned dst_packing_factor =
416       get_packing_factor(dst_desc, dst_slice->type);
417 
418    const unsigned src_packing_factor =
419       get_packing_factor(src_desc, src_slice->type);
420 
421    const nir_op op = nir_intrinsic_alu_op(intrin);
422 
423    /* With the combinations of formats exposed on all platforms, matrices with
424     * the same dimensions will always have the same data size. The only real
425     * type conversion possible is int32 <-> float32. As a result
426     * dst_packing_factor == src_packing_factor.
427     */
428    assert(dst_packing_factor == src_packing_factor);
429 
430    /* Stores at most dst_packing_factor partial results. */
431    nir_def *v[4];
432    assert(dst_packing_factor <= 4);
433 
434    for (unsigned i = 0; i < num_components; i++) {
435       nir_def *chan = nir_channel(b, nir_load_deref(b, src_slice), i);
436 
437       for (unsigned j = 0; j < dst_packing_factor; j++) {
438          nir_def *src =
439             nir_channel(b, nir_unpack_bits(b, chan, src_bits), j);
440 
441          v[j] = nir_build_alu1(b, op, src);
442       }
443 
444       results[i] =
445          nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
446                        dst_packing_factor * dst_bits);
447    }
448 
449    nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
450                    nir_component_mask(num_components));
451 }
452 
453 static void
454 lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
455                      struct lower_cmat_state *state)
456 {
457    nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
458    nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
459    nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);
460 
461    nir_def *src_a = nir_load_deref(b, src_a_slice);
462    nir_def *src_b = nir_load_deref(b, src_b_slice);
463    nir_def *results[NIR_MAX_VEC_COMPONENTS];
464    const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
465 
466    const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
467    ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice);
468    ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice);
469 
470    const struct glsl_cmat_description desc =
471       *glsl_get_cmat_description(dst_mat_type);
472 
473    assert(dst_mat_type == src_a_mat_type);
474    assert(dst_mat_type == src_b_mat_type);
475 
476    const unsigned bits = glsl_base_type_bit_size(desc.element_type);
477    const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
478 
479    for (unsigned i = 0; i < num_components; i++) {
480       nir_def *val_a = nir_channel(b, src_a, i);
481       nir_def *val_b = nir_channel(b, src_b, i);
482 
483       results[i] =
484          nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
485                                          nir_unpack_bits(b, val_a, bits),
486                                          nir_unpack_bits(b, val_b, bits)),
487                        packing_factor * bits);
488    }
489 
490    nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
491                    nir_component_mask(num_components));
492 }
493 
494 static void
495 lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
496                      struct lower_cmat_state *state)
497 {
498    nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
499    nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
500    nir_def *scalar = intrin->src[2].ssa;
501 
502    nir_def *src = nir_load_deref(b, src_slice);
503    nir_def *results[NIR_MAX_VEC_COMPONENTS];
504    const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
505 
506    ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
507    ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
508    assert(dst_mat_type == src_mat_type);
509 
510    const struct glsl_cmat_description desc =
511       *glsl_get_cmat_description(dst_mat_type);
512 
513    const unsigned bits = glsl_base_type_bit_size(desc.element_type);
514    const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
515 
516    for (unsigned i = 0; i < num_components; i++) {
517       nir_def *val = nir_channel(b, src, i);
518 
519       results[i] =
520          nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
521                                          nir_unpack_bits(b, val, bits),
522                                          scalar),
523                        packing_factor * bits);
524    }
525 
526    nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
527                    nir_component_mask(num_components));
528 }
529 
530 static nir_deref_instr *
531 lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
532                  struct lower_cmat_state *state)
533 {
534    nir_deref_instr *parent = nir_deref_instr_parent(deref);
535    if (parent) {
536       assert(deref->deref_type == nir_deref_type_array);
537       parent = lower_cmat_deref(b, parent, state);
538       return nir_build_deref_array(b, parent, deref->arr.index.ssa);
539    } else {
540       assert(deref->deref_type == nir_deref_type_var);
541       assert(deref->var);
542       assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));
543 
544       struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var);
545       assert(entry);
546       return nir_build_deref_var(b, (nir_variable *)entry->data);
547    }
548 }
549 
550 static nir_def *
551 lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
552 {
553    struct lower_cmat_state *state = _state;
554 
555    if (instr->type == nir_instr_type_deref) {
556       nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state);
557       return &deref->def;
558    }
559 
560    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
561    switch (intrin->intrinsic) {
562    case nir_intrinsic_cmat_load:
563    case nir_intrinsic_cmat_store:
564       lower_cmat_load_store(b, intrin, state);
565       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
566 
567    case nir_intrinsic_cmat_construct: {
568       nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
569       nir_def *src = intrin->src[1].ssa;
570 
571       const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
572       const struct glsl_cmat_description desc =
573          *glsl_get_cmat_description(mat_type);
574       const unsigned packing_factor = get_packing_factor(desc, slice->type);
575 
576       if (packing_factor > 1) {
577          src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
578                              packing_factor * glsl_base_type_get_bit_size(desc.element_type));
579       }
580 
581       const unsigned num_components = glsl_get_vector_elements(slice->type);
582 
583       nir_store_deref(b, slice, nir_replicate(b, src, num_components),
584                       nir_component_mask(num_components));
585       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
586    }
587 
588    case nir_intrinsic_cmat_unary_op:
589       lower_cmat_unary_op(b, intrin, state);
590       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
591 
592    case nir_intrinsic_cmat_binary_op:
593       lower_cmat_binary_op(b, intrin, state);
594       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
595 
596    case nir_intrinsic_cmat_scalar_op:
597       lower_cmat_scalar_op(b, intrin, state);
598       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
599 
600    case nir_intrinsic_cmat_length: {
601       const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
602       const struct glsl_type *mat_type = glsl_cmat_type(&desc);
603       const struct glsl_type *slice_type = get_slice_type(state, mat_type);
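      /* e.g., a uvec4 slice with a packing factor of 2 reports a length of
       * 2 * 4 = 8 elements per invocation (editor's example).
       */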
604       return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
605                                 glsl_get_vector_elements(slice_type)), 32);
606    }
607 
608    case nir_intrinsic_cmat_muladd: {
609       nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
610       nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]);
611       nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]);
612       nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]);
613 
614       const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
615       const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type);
616 
617       const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice);
618       const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type);
619 
620       const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type);
621       const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
622 
623       const nir_cmat_signed cmat_signed_mask =
624          nir_intrinsic_cmat_signed_mask(intrin);
625 
626       assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
627              ((cmat_signed_mask & NIR_CMAT_B_SIGNED) == 0));
628       assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
629              ((cmat_signed_mask & NIR_CMAT_C_SIGNED) == 0));
630       assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
631              ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0));
632 
633       nir_alu_type src_type =
634          nir_get_nir_type_for_glsl_base_type(src_desc.element_type);
635       nir_alu_type dest_type =
636          nir_get_nir_type_for_glsl_base_type(dst_desc.element_type);
637 
638       /* For integer types, the signedness is determined by flags on the
639        * muladd instruction. The types of the sources play no role. Adjust the
640        * types passed to the dpas_intel intrinsic to match.
641        */
642       if (nir_alu_type_get_base_type(src_type) == nir_type_uint ||
643           nir_alu_type_get_base_type(src_type) == nir_type_int) {
644          if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) {
645             src_type = nir_alu_type_get_type_size(src_type) | nir_type_uint;
646             dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_uint;
647          } else {
648             src_type = nir_alu_type_get_type_size(src_type) | nir_type_int;
649             dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_int;
650          }
651       }
652 
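      /* Editor's note: the fixed 8x8 shape below reflects the assumption in
       * the file header that every matrix has either 8 rows or 8 columns,
       * with the other dimension spanning 8 dwords: repeat_count = 8 covers
       * the 8 accumulator rows and systolic_depth = 8 covers the 8 dwords of
       * the K dimension (8 * packing_factor elements).
       */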
653       nir_def *result =
654          nir_dpas_intel(b,
655                         packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type),
656                         nir_load_deref(b, accum_slice),
657                         nir_load_deref(b, A_slice),
658                         nir_load_deref(b, B_slice),
659                         .dest_type = dest_type,
660                         .src_type = src_type,
661                         .saturate = nir_intrinsic_saturate(intrin),
662                         .systolic_depth = 8,
663                         .repeat_count = 8);
664 
665       nir_store_deref(b, dst_slice, result,
666                       nir_component_mask(num_components));
667 
668       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
669    }
670 
671    case nir_intrinsic_cmat_bitcast: {
672       nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
673       nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
674 
675       const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
676 
677       assert(glsl_get_vector_elements(src_slice->type) == num_components);
678 
679       nir_store_deref(b, dst_slice, nir_load_deref(b, src_slice),
680                       nir_component_mask(num_components));
681       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
682    }
683 
684    case nir_intrinsic_cmat_copy:
685       nir_copy_deref(b,
686                      nir_src_as_deref(intrin->src[0]),
687                      nir_src_as_deref(intrin->src[1]));
688       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
689 
690    case nir_intrinsic_cmat_insert: {
691       nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
692       nir_def *scalar = intrin->src[1].ssa;
693       nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]);
694       const nir_src dst_index = intrin->src[3];
695 
696       const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
697       ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
698       assert(dst_mat_type == src_mat_type);
699 
700       const struct glsl_cmat_description desc =
701          *glsl_get_cmat_description(dst_mat_type);
702 
703       const unsigned bits = glsl_base_type_bit_size(desc.element_type);
704       const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
705       const unsigned num_components = glsl_get_vector_elements(dst_slice->type);
706 
707       nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor);
708       nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor);
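      /* For example (editor's illustration): with a packing factor of 2,
       * element index 5 lands in slice component 5 / 2 = 2, at 16-bit half
       * 5 % 2 = 1 of that dword.
       */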
709       nir_def *results[NIR_MAX_VEC_COMPONENTS];
710 
711       const int slice_constant_index = nir_src_is_const(dst_index)
712          ? nir_src_as_uint(dst_index) / packing_factor
713          : -1;
714 
715       for (unsigned i = 0; i < num_components; i++) {
716          nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i);
717          nir_def *insert;
718 
719          if (slice_constant_index < 0 || slice_constant_index == i) {
720             if (packing_factor == 1) {
721                insert = scalar;
722             } else {
723                nir_def *unpacked = nir_unpack_bits(b, val, bits);
724                nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index);
725 
726                insert = nir_pack_bits(b, v, bits * packing_factor);
727             }
728          } else {
729             insert = val;
730          }
731 
732          results[i] = slice_constant_index < 0
733             ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val)
734             : insert;
735       }
736 
737       nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
738                       nir_component_mask(num_components));
739 
740       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
741    }
742 
743    case nir_intrinsic_cmat_extract: {
744       nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
745       const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
746       nir_def *index = intrin->src[1].ssa;
747 
748       const struct glsl_cmat_description desc =
749          *glsl_get_cmat_description(mat_type);
750 
751       const unsigned bits = glsl_base_type_bit_size(desc.element_type);
752       const unsigned packing_factor = get_packing_factor(desc, slice->type);
753 
754       nir_def *src =
755          nir_vector_extract(b, nir_load_deref(b, slice),
756                             nir_udiv_imm(b, index, packing_factor));
757 
758       if (packing_factor == 1) {
759          return src;
760       } else {
761          return nir_vector_extract(b,
762                                    nir_unpack_bits(b, src, bits),
763                                    nir_umod_imm(b, index, packing_factor));
764       }
765 
766       return NIR_LOWER_INSTR_PROGRESS_REPLACE;
767    }
768 
769    default:
770       unreachable("invalid cooperative matrix intrinsic");
771    }
772 }
773 
774 static void
775 create_slice_var(struct lower_cmat_state *state, nir_variable *var,
776                  nir_function_impl *impl)
777 {
778    // TODO: without array
779    const struct glsl_type *mat_type = glsl_without_array(var->type);
780 
781    assert(glsl_type_is_cmat(mat_type));
782    assert((!impl && var->data.mode == nir_var_shader_temp) ||
783           ( impl && var->data.mode == nir_var_function_temp));
784 
785    const struct glsl_type *slice_type = get_slice_type(state, var->type);
786    const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name);
787    nir_variable *slice_var = impl ?
788       nir_local_variable_create(impl, slice_type, slice_name) :
789       nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);
790 
791    _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
792    _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
793 }
794 
795 bool
796 brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
797 {
798    void *temp_ctx = ralloc_context(NULL);
799 
800    struct lower_cmat_state state = {
801       .shader = shader,
802       .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
803       .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
804       .subgroup_size = subgroup_size,
805    };
806 
807    /* Create a slice array for each variable and add a map from the original
808     * variable back to it, so it can be reached during lowering.
809     *
810     * TODO: Cooperative matrix inside struct?
811     */
812    nir_foreach_variable_in_shader(var, shader) {
813       if (glsl_type_is_cmat(glsl_without_array(var->type)))
814          create_slice_var(&state, var, NULL);
815    }
816    nir_foreach_function(func, shader) {
817       nir_foreach_function_temp_variable(var, func->impl) {
818          if (glsl_type_is_cmat(glsl_without_array(var->type)))
819             create_slice_var(&state, var, func->impl);
820       }
821    }
822 
823    bool progress = nir_shader_lower_instructions(shader,
824                                                  lower_cmat_filter,
825                                                  lower_cmat_instr,
826                                                  &state);
827 
828    ralloc_free(temp_ctx);
829 
830    return progress;
831 }
832