/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

/**
 * \file brw_nir_lower_cooperative_matrix.c
 * Lower cooperative matrix to subgroup operations.
 *
 * All supported matrix types are assumed to have either 8 rows or 8
 * columns. The other dimension of the matrix is typically 8 times the number
 * of data elements that can be stored in a 32-bit dword. Matrix data is
 * indexed by a combination of an array element and a subgroup invocation ID.
 *
 * Two layouts for matrix data are used. In the first layout,
 * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be
 * called row-major hereafter. In the other layout,
 * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will
 * be called column-major hereafter. In cases where a single 32-bit value is
 * stored in each entry, these layouts are identical.
 *
 * The subtle difference arises when multiple values are packed into a single
 * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in
 * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[0][1] and
 * m[1][1] (in m[row][column] notation). In row-major, that same shuffle holds
 * m[0][2] and m[0][3].
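 *
 * More concretely, with two values packed per dword, slice[N] in invocation
 * I holds m[N][2*I] and m[N][2*I+1] in row-major and m[2*N][I] and
 * m[2*N+1][I] in column-major, which follows directly from the definitions
 * above.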
 *
 * There is an alternate way to think about the matrix layouts. Every matrix
 * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B
 * matrix) or Sx8T (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
 * layouts are such that a single 8 dword register holds an entire row of the
 * matrix.
 *
 * Consider a matrix stored starting in register g32. In an A matrix, the
 * packed dwords of g32 contain only the data for a single row of the
 * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords
 * of g(32+N).X contain only the data for a single column of the
 * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc.
 *
 * This leads to some shenanigans in \c lower_cmat_load_store.
 *
 * In the common case, A, C, and result matrices are stored row major while B
 * matrices are stored column major. This arrangement facilitates efficient
 * dot product operations using DPAS or DP4A instructions.
 *
 * Future optimizations are possible when row and column major are
 * flipped. That is, efficient dot products are also possible when A, C, and
 * result matrices are column major while B is row major.
 */

#include "brw_nir.h"

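/* State for the lowering.  vars_to_slice maps each original cooperative
 * matrix variable to the plain "slice" variable that replaces it, and
 * slice_coop_types maps a slice variable back to the cooperative matrix
 * type it was created for (see create_slice_var and
 * get_coop_type_for_slice).
 */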
struct lower_cmat_state {
   nir_shader *shader;

   struct hash_table *slice_coop_types;

   struct hash_table *vars_to_slice;

   unsigned subgroup_size;
};

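/* Debug helper: dump the slice variable -> cooperative matrix type table. */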
static void
print_coop_types(struct lower_cmat_state *state)
{
   fprintf(stderr, "--- Slices to Cooperative Matrix type table\n");
   hash_table_foreach(state->slice_coop_types, e) {
      nir_variable *var = (void *)e->key;
      const struct glsl_type *t = e->data;
      fprintf(stderr, "%p: %s -> %s\n", var, var->name, glsl_get_type_name(t));
   }
   fprintf(stderr, "\n\n");
}

static const struct glsl_type *
get_coop_type_for_slice(struct lower_cmat_state *state, nir_deref_instr *deref)
{
   nir_variable *var = nir_deref_instr_get_variable(deref);
   struct hash_entry *entry = _mesa_hash_table_search(state->slice_coop_types, var);

   assert(entry != NULL);

   return entry->data;
}

static bool
lower_cmat_filter(const nir_instr *instr, const void *_state)
{
   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = nir_instr_as_deref(instr);
      return glsl_type_is_cmat(deref->type);
   }

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_construct:
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
   case nir_intrinsic_cmat_length:
   case nir_intrinsic_cmat_muladd:
   case nir_intrinsic_cmat_unary_op:
   case nir_intrinsic_cmat_binary_op:
   case nir_intrinsic_cmat_scalar_op:
   case nir_intrinsic_cmat_bitcast:
   case nir_intrinsic_cmat_insert:
   case nir_intrinsic_cmat_extract:
   case nir_intrinsic_cmat_copy:
      return true;

   default:
      return false;
   }
}

/**
 * Get number of matrix elements packed in each component of the slice.
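 *
 * For example, a float16 matrix stored in 32-bit slice components has a
 * packing factor of 2; an int8 or uint8 matrix has a packing factor of 4.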
 */
static unsigned
get_packing_factor(const struct glsl_cmat_description desc,
                   const struct glsl_type *slice_type)
{
   const struct glsl_type *slice_element_type = glsl_without_array(slice_type);

   assert(!glsl_type_is_cmat(slice_type));

   assert(glsl_get_bit_size(slice_element_type) >= glsl_base_type_get_bit_size(desc.element_type));
   assert(glsl_get_bit_size(slice_element_type) % glsl_base_type_get_bit_size(desc.element_type) == 0);

   return glsl_get_bit_size(slice_element_type) / glsl_base_type_get_bit_size(desc.element_type);
}

static const struct glsl_type *
get_slice_type_from_desc(const struct lower_cmat_state *state,
                         const struct glsl_cmat_description desc)
{
   enum glsl_base_type base_type;

   /* Number of matrix elements stored by each subgroup invocation. If the
    * data is packed, the slice size will be less than this.
    */
   const unsigned elements_per_invocation =
      (desc.rows * desc.cols) / state->subgroup_size;

   assert(elements_per_invocation > 0);

   const unsigned element_bits = 32;
   const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);

   /* Each invocation must have at least one dword of data, and that dword
    * must be tightly packed with values. No matter the matrix dimensions, a
    * matrix of uint8_t data must pack 4 values in each entry.
    */
   const unsigned packing_factor = element_bits / bits;

   assert(elements_per_invocation >= packing_factor);

   switch (desc.element_type) {
   case GLSL_TYPE_FLOAT:
      base_type = GLSL_TYPE_FLOAT;
      break;
   case GLSL_TYPE_UINT:
   case GLSL_TYPE_FLOAT16:
   case GLSL_TYPE_UINT8:
   case GLSL_TYPE_UINT16:
      base_type = GLSL_TYPE_UINT;
      break;
   case GLSL_TYPE_INT:
   case GLSL_TYPE_INT8:
   case GLSL_TYPE_INT16:
      base_type = GLSL_TYPE_INT;
      break;
   default:
      unreachable("Invalid cooperative matrix element type.");
   }

   unsigned len = elements_per_invocation / packing_factor;

   /* Supported matrix sizes are designed to fill either 4 or 8 SIMD8
    * registers on DG2. That means:
    *
    *               4 registers    8 registers
    *    SIMD32       len = 1        len = 2
    *    SIMD16       len = 2        len = 4
    *    SIMD8        len = 4        len = 8
    *
    * On Xe2, supported matrix sizes are still designed to fill 4 registers
    * (e.g., 8x32 uint8_t) or 8 registers (e.g., 16x16 float16). However, the
    * 16x16 float16 matrix will assign 16 elements per channel at SIMD16.
    */
   assert(len == 1 || len == 2 || len == 4 || len == 8 || len == 16);
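
   /* For example, at subgroup size 16 a 16x8 float16 B matrix has
    * 128 / 16 = 8 elements per invocation and a packing factor of 2, so
    * len = 4; an 8x32 uint8 A matrix has 256 / 16 = 16 elements per
    * invocation and a packing factor of 4, also giving len = 4.
    */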

   const struct glsl_type *slice_type = glsl_vector_type(base_type, len);

   assert(packing_factor == get_packing_factor(desc, slice_type));

   return slice_type;
}

static const struct glsl_type *
get_slice_type(const struct lower_cmat_state *state,
               const struct glsl_type *type)
{
   if (glsl_type_is_array(type)) {
      const struct glsl_type *slice_type =
         get_slice_type(state, glsl_get_array_element(type));

      return glsl_array_type(slice_type, glsl_array_size(type), 0);
   }

   assert(glsl_type_is_cmat(type));

   return get_slice_type_from_desc(state,
                                   *glsl_get_cmat_description(type));
}

static nir_deref_instr *
create_local_slice(struct lower_cmat_state *state, nir_builder *b,
                   const struct glsl_type *mat_type, const char *name)
{
   const struct glsl_type *slice_type = get_slice_type(state, mat_type);
   nir_variable *slice_var = nir_local_variable_create(b->impl, slice_type, name);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
   return nir_build_deref_var(b, slice_var);
}

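/* Lower nir_intrinsic_cmat_load and nir_intrinsic_cmat_store.
 *
 * When the memory layout matches the register layout, each slice component
 * is transferred as a single packed 32-bit value.  Otherwise the matrix
 * elements are transferred one at a time and packed into (or unpacked from)
 * the 32-bit slice components.
 */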
static void
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin,
                      struct lower_cmat_state *state)
{
   const bool load = intrin->intrinsic == nir_intrinsic_cmat_load;
   const unsigned mat_src = load ? 0 : 1;
   const unsigned ptr_src = load ? 1 : 0;

   nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]);
   const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
   const struct glsl_cmat_description *desc = glsl_get_cmat_description(mat_type);

   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(slice->type);
   const unsigned packing_factor = get_packing_factor(*desc, slice->type);

   nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]);
   const unsigned ptr_comp_width = glsl_get_bit_size(pointer->type);
   const unsigned ptr_num_comps = glsl_get_vector_elements(pointer->type);

   /* The stride is given in number of elements of the pointed type, which
    * doesn't necessarily match the matrix element type, so we need to adjust
    * it considering it may be a vector and have a different bit-width.
    */
   nir_def *stride = nir_udiv_imm(b,
                                  nir_imul_imm(b,
                                               intrin->src[2].ssa,
                                               ptr_comp_width * ptr_num_comps),
                                  glsl_base_type_get_bit_size(desc->element_type));

   /* The data that will be packed is in successive columns for A and
    * accumulator matrices. The data that will be packed for B matrices is in
    * successive rows.
    */
   const unsigned cols =
      desc->use != GLSL_CMAT_USE_B ? desc->cols / packing_factor : desc->cols;

   nir_def *invocation = nir_load_subgroup_invocation(b);
   nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols);
   nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols);
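
   /* Decompose the invocation ID: invocation_mod_cols is the invocation's
    * position within a span of "cols" packed entries, and
    * invocation_div_cols selects which span it starts in.  Together with
    * i_stride and i_step below, these determine which matrix entries each
    * component of this invocation's slice touches.
    */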

   nir_def *i_stride;

   const bool memory_layout_matches_register_layout =
      (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) ==
      (desc->use != GLSL_CMAT_USE_B);

   if (memory_layout_matches_register_layout) {
      /* In the row-major arrangement, data is loaded a dword at a time
       * instead of a single element at a time. For this reason the stride is
       * divided by the packing factor.
       */
      i_stride = nir_udiv_imm(b, stride, packing_factor);
   } else {
      /* In the column-major arrangement, data is loaded a single element at a
       * time. Because the data elements are transposed, the step direction
       * that moves a single (packed) element in the row-major arrangement has
       * to explicitly step over the packing factor count of elements. For
       * this reason the stride is multiplied by the packing factor.
       *
       * NOTE: The unscaled stride is also still needed when stepping from one
       * packed element to the next. This occurs in the for-j loop below.
       */
      i_stride = nir_imul_imm(b, stride, packing_factor);
   }

   nir_def *base_offset;
   nir_def *i_step;

   if (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
      base_offset = nir_iadd(b,
                             nir_imul(b,
                                      invocation_div_cols,
                                      i_stride),
                             invocation_mod_cols);

      i_step = nir_imul_imm(b, i_stride, state->subgroup_size / cols);
   } else {
      base_offset = nir_iadd(b,
                             nir_imul(b,
                                      invocation_mod_cols,
                                      i_stride),
                             invocation_div_cols);

      i_step = nir_imm_int(b, state->subgroup_size / cols);
   }

   if (memory_layout_matches_register_layout) {
      const struct glsl_type *element_type =
         glsl_scalar_type(glsl_get_base_type(slice->type));

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes,
                                     element_type,
                                     glsl_get_bit_size(element_type) / 8);

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *offset = nir_imul_imm(b, i_step, i);

         nir_deref_instr *memory_deref =
            nir_build_deref_ptr_as_array(b, pointer,
                                         nir_i2iN(b,
                                                  nir_iadd(b,
                                                           base_offset,
                                                           offset),
                                                  pointer->def.bit_size));

         if (load) {
            results[i] = nir_load_deref(b, memory_deref);
         } else {
            nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);
            nir_store_deref(b, memory_deref, src, 0x1);
         }
      }
   } else {
      const struct glsl_type *element_type = glsl_scalar_type(desc->element_type);
      const unsigned element_bits = glsl_base_type_get_bit_size(desc->element_type);
      const unsigned element_stride = element_bits / 8;

      pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type,
                                     element_stride);

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *i_offset = nir_imul_imm(b, i_step, i);
         nir_def *v[4];

         for (unsigned j = 0; j < packing_factor; j++) {
            nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset);

            nir_deref_instr *memory_deref =
               nir_build_deref_ptr_as_array(b, pointer,
                                            nir_i2iN(b,
                                                     nir_iadd(b,
                                                              base_offset,
                                                              offset),
                                                     pointer->def.bit_size));

            if (load) {
               v[j] = nir_load_deref(b, memory_deref);
            } else {
               nir_def *src = nir_channel(b, nir_load_deref(b, slice), i);

               nir_def *v =
                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);

               nir_store_deref(b, memory_deref, v, 0x1);
            }
         }

         if (load) {
            results[i] = nir_pack_bits(b, nir_vec(b, v, packing_factor),
                                       packing_factor * element_bits);
         }
      }
   }

   if (load)
      nir_store_deref(b, slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));
}

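/* Lower nir_intrinsic_cmat_unary_op (such as type conversions) by applying
 * the ALU op to every matrix element unpacked from the source slice and
 * repacking the results into the destination slice.
 */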
static void
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                    struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type =
      get_coop_type_for_slice(state, dst_slice);
   const struct glsl_type *src_mat_type =
      get_coop_type_for_slice(state, src_slice);

   const struct glsl_cmat_description dst_desc =
      *glsl_get_cmat_description(dst_mat_type);

   const struct glsl_cmat_description src_desc =
      *glsl_get_cmat_description(src_mat_type);

   const unsigned dst_bits = glsl_base_type_get_bit_size(dst_desc.element_type);
   const unsigned src_bits = glsl_base_type_get_bit_size(src_desc.element_type);

   /* The type of the returned slice may be different from the type of the
    * input slice.
    */
   const unsigned dst_packing_factor =
      get_packing_factor(dst_desc, dst_slice->type);

   const unsigned src_packing_factor =
      get_packing_factor(src_desc, src_slice->type);

   const nir_op op = nir_intrinsic_alu_op(intrin);

   /* With the combinations of formats exposed on all platforms, matrices with
    * the same dimensions will always have the same data size. The only real
    * type conversion possible is int32 <-> float32. As a result
    * dst_packing_factor == src_packing_factor.
    */
   assert(dst_packing_factor == src_packing_factor);

   /* Stores at most dst_packing_factor partial results. */
   nir_def *v[4];
   assert(dst_packing_factor <= 4);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *chan = nir_channel(b, nir_load_deref(b, src_slice), i);

      for (unsigned j = 0; j < dst_packing_factor; j++) {
         nir_def *src =
            nir_channel(b, nir_unpack_bits(b, chan, src_bits), j);

         v[j] = nir_build_alu1(b, op, src);
      }

      results[i] =
         nir_pack_bits(b, nir_vec(b, v, dst_packing_factor),
                       dst_packing_factor * dst_bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

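/* Lower nir_intrinsic_cmat_binary_op.  Both sources and the destination
 * share the same cooperative matrix type, so the ALU op can be applied
 * directly to the vector unpacked from each slice component.
 */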
static void
lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]);
   nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]);

   nir_def *src_a = nir_load_deref(b, src_a_slice);
   nir_def *src_b = nir_load_deref(b, src_b_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_a_mat_type = get_coop_type_for_slice(state, src_a_slice);
   ASSERTED const struct glsl_type *src_b_mat_type = get_coop_type_for_slice(state, src_b_slice);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   assert(dst_mat_type == src_a_mat_type);
   assert(dst_mat_type == src_b_mat_type);

   const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val_a = nir_channel(b, src_a, i);
      nir_def *val_b = nir_channel(b, src_b, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val_a, bits),
                                         nir_unpack_bits(b, val_b, bits)),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

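/* Lower nir_intrinsic_cmat_scalar_op: apply the ALU op to each slice
 * component (unpacked to the matrix element bit size) and the scalar
 * operand.
 */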
static void
lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   nir_def *scalar = intrin->src[2].ssa;

   nir_def *src = nir_load_deref(b, src_slice);
   nir_def *results[NIR_MAX_VEC_COMPONENTS];
   const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

   ASSERTED const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
   ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
   assert(dst_mat_type == src_mat_type);

   const struct glsl_cmat_description desc =
      *glsl_get_cmat_description(dst_mat_type);

   const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
   const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);

   for (unsigned i = 0; i < num_components; i++) {
      nir_def *val = nir_channel(b, src, i);

      results[i] =
         nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin),
                                         nir_unpack_bits(b, val, bits),
                                         scalar),
                       packing_factor * bits);
   }

   nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                   nir_component_mask(num_components));
}

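/* Rebuild a deref chain rooted at a cooperative matrix variable so that it
 * is rooted at the corresponding slice variable instead.
 */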
static nir_deref_instr *
lower_cmat_deref(nir_builder *b, nir_deref_instr *deref,
                 struct lower_cmat_state *state)
{
   nir_deref_instr *parent = nir_deref_instr_parent(deref);
   if (parent) {
      assert(deref->deref_type == nir_deref_type_array);
      parent = lower_cmat_deref(b, parent, state);
      return nir_build_deref_array(b, parent, deref->arr.index.ssa);
   } else {
      assert(deref->deref_type == nir_deref_type_var);
      assert(deref->var);
      assert(glsl_type_is_cmat(glsl_without_array(deref->var->type)));

      struct hash_entry *entry = _mesa_hash_table_search(state->vars_to_slice, deref->var);
      assert(entry);
      return nir_build_deref_var(b, (nir_variable *)entry->data);
   }
}

static nir_def *
lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   struct lower_cmat_state *state = _state;

   if (instr->type == nir_instr_type_deref) {
      nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state);
      return &deref->def;
   }

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_cmat_load:
   case nir_intrinsic_cmat_store:
      lower_cmat_load_store(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_construct: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      nir_def *src = intrin->src[1].ssa;

      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

      if (packing_factor > 1) {
         src = nir_pack_bits(b, nir_replicate(b, src, packing_factor),
                             packing_factor * glsl_base_type_get_bit_size(desc.element_type));
      }

      const unsigned num_components = glsl_get_vector_elements(slice->type);

      nir_store_deref(b, slice, nir_replicate(b, src, num_components),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_unary_op:
      lower_cmat_unary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_binary_op:
      lower_cmat_binary_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_scalar_op:
      lower_cmat_scalar_op(b, intrin, state);
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_length: {
      const struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intrin);
      const struct glsl_type *mat_type = glsl_cmat_type(&desc);
      const struct glsl_type *slice_type = get_slice_type(state, mat_type);
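      /* The "length" of a cooperative matrix is the number of elements owned
       * by each invocation: the number of slice components times the number
       * of elements packed into each component.
       */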
      return nir_imm_intN_t(b, (get_packing_factor(desc, slice_type) *
                                glsl_get_vector_elements(slice_type)), 32);
   }

   case nir_intrinsic_cmat_muladd: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]);
      nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]);
      nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]);

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      const struct glsl_cmat_description dst_desc = *glsl_get_cmat_description(dst_mat_type);

      const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, A_slice);
      const struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_mat_type);

      const unsigned packing_factor = get_packing_factor(dst_desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      const nir_cmat_signed cmat_signed_mask =
         nir_intrinsic_cmat_signed_mask(intrin);

      assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
             ((cmat_signed_mask & NIR_CMAT_B_SIGNED) == 0));
      assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
             ((cmat_signed_mask & NIR_CMAT_C_SIGNED) == 0));
      assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) ==
             ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0));

      nir_alu_type src_type =
         nir_get_nir_type_for_glsl_base_type(src_desc.element_type);
      nir_alu_type dest_type =
         nir_get_nir_type_for_glsl_base_type(dst_desc.element_type);

      /* For integer types, the signedness is determined by flags on the
       * muladd instruction. The types of the sources play no role. Adjust the
       * types passed to the dpas_intel intrinsic to match.
       */
      if (nir_alu_type_get_base_type(src_type) == nir_type_uint ||
          nir_alu_type_get_base_type(src_type) == nir_type_int) {
         if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) {
            src_type = nir_alu_type_get_type_size(src_type) | nir_type_uint;
            dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_uint;
         } else {
            src_type = nir_alu_type_get_type_size(src_type) | nir_type_int;
            dest_type = nir_alu_type_get_type_size(dest_type) | nir_type_int;
         }
      }

      nir_def *result =
         nir_dpas_intel(b,
                        packing_factor * glsl_base_type_get_bit_size(dst_desc.element_type),
                        nir_load_deref(b, accum_slice),
                        nir_load_deref(b, A_slice),
                        nir_load_deref(b, B_slice),
                        .dest_type = dest_type,
                        .src_type = src_type,
                        .saturate = nir_intrinsic_saturate(intrin),
                        .systolic_depth = 8,
                        .repeat_count = 8);

      nir_store_deref(b, dst_slice, result,
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_bitcast: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);

      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      assert(glsl_get_vector_elements(src_slice->type) == num_components);

      nir_store_deref(b, dst_slice, nir_load_deref(b, src_slice),
                      nir_component_mask(num_components));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_copy:
      nir_copy_deref(b,
                     nir_src_as_deref(intrin->src[0]),
                     nir_src_as_deref(intrin->src[1]));
      return NIR_LOWER_INSTR_PROGRESS_REPLACE;

   case nir_intrinsic_cmat_insert: {
      nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
      nir_def *scalar = intrin->src[1].ssa;
      nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]);
      const nir_src dst_index = intrin->src[3];

      const struct glsl_type *dst_mat_type = get_coop_type_for_slice(state, dst_slice);
      ASSERTED const struct glsl_type *src_mat_type = get_coop_type_for_slice(state, src_slice);
      assert(dst_mat_type == src_mat_type);

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(dst_mat_type);

      const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, dst_slice->type);
      const unsigned num_components = glsl_get_vector_elements(dst_slice->type);

      nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, packing_factor);
      nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, packing_factor);
      nir_def *results[NIR_MAX_VEC_COMPONENTS];

      const int slice_constant_index = nir_src_is_const(dst_index)
         ? nir_src_as_uint(dst_index) / packing_factor
         : -1;
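
      /* dst_index addresses a single matrix element: slice_index selects the
       * slice component that holds it and vector_index the position within
       * that packed component.  If the index is known at compile time, only
       * the matching component needs the insert; otherwise a bcsel per
       * component chooses between the inserted and the original value.
       */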

      for (unsigned i = 0; i < num_components; i++) {
         nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i);
         nir_def *insert;

         if (slice_constant_index < 0 || slice_constant_index == i) {
            if (packing_factor == 1) {
               insert = scalar;
            } else {
               nir_def *unpacked = nir_unpack_bits(b, val, bits);
               nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index);

               insert = nir_pack_bits(b, v, bits * packing_factor);
            }
         } else {
            insert = val;
         }

         results[i] = slice_constant_index < 0
            ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val)
            : insert;
      }

      nir_store_deref(b, dst_slice, nir_vec(b, results, num_components),
                      nir_component_mask(num_components));

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   case nir_intrinsic_cmat_extract: {
      nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]);
      const struct glsl_type *mat_type = get_coop_type_for_slice(state, slice);
      nir_def *index = intrin->src[1].ssa;

      const struct glsl_cmat_description desc =
         *glsl_get_cmat_description(mat_type);

      const unsigned bits = glsl_base_type_get_bit_size(desc.element_type);
      const unsigned packing_factor = get_packing_factor(desc, slice->type);

      nir_def *src =
         nir_vector_extract(b, nir_load_deref(b, slice),
                            nir_udiv_imm(b, index, packing_factor));

      if (packing_factor == 1) {
         return src;
      } else {
         return nir_vector_extract(b,
                                   nir_unpack_bits(b, src, bits),
                                   nir_umod_imm(b, index, packing_factor));
      }

      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
   }

   default:
      unreachable("invalid cooperative matrix intrinsic");
   }
}

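/* Create the plain "slice" variable that will replace a cooperative matrix
 * variable (either a shader temporary or a function temporary) and record
 * the mapping in the state tables.
 */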
static void
create_slice_var(struct lower_cmat_state *state, nir_variable *var,
                 nir_function_impl *impl)
{
   /* TODO: without array */
   const struct glsl_type *mat_type = glsl_without_array(var->type);

   assert(glsl_type_is_cmat(mat_type));
   assert((!impl && var->data.mode == nir_var_shader_temp) ||
          ( impl && var->data.mode == nir_var_function_temp));

   const struct glsl_type *slice_type = get_slice_type(state, var->type);
   const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name);
   nir_variable *slice_var = impl ?
      nir_local_variable_create(impl, slice_type, slice_name) :
      nir_variable_create(state->shader, var->data.mode, slice_type, slice_name);

   _mesa_hash_table_insert(state->vars_to_slice, var, slice_var);
   _mesa_hash_table_insert(state->slice_coop_types, slice_var, (void *)mat_type);
}

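/* Entry point: create a slice variable for every cooperative matrix variable
 * in the shader, then lower all cooperative matrix derefs and intrinsics to
 * operations on those slice variables.
 */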
bool
brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size)
{
   void *temp_ctx = ralloc_context(NULL);

   struct lower_cmat_state state = {
      .shader = shader,
      .slice_coop_types = _mesa_pointer_hash_table_create(temp_ctx),
      .vars_to_slice = _mesa_pointer_hash_table_create(temp_ctx),
      .subgroup_size = subgroup_size,
   };

   /* Create a slice array for each variable and add a map from the original
    * variable back to it, so it can be reached during lowering.
    *
    * TODO: Cooperative matrix inside struct?
    */
   nir_foreach_variable_in_shader(var, shader) {
      if (glsl_type_is_cmat(glsl_without_array(var->type)))
         create_slice_var(&state, var, NULL);
   }
   nir_foreach_function(func, shader) {
      nir_foreach_function_temp_variable(var, func->impl) {
         if (glsl_type_is_cmat(glsl_without_array(var->type)))
            create_slice_var(&state, var, func->impl);
      }
   }

   bool progress = nir_shader_lower_instructions(shader,
                                                 lower_cmat_filter,
                                                 lower_cmat_instr,
                                                 &state);

   ralloc_free(temp_ctx);

   return progress;
}