xref: /aosp_15_r20/external/mesa3d/src/compiler/spirv/vtn_cmat.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Intel Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "glsl_types.h"
7 #include "nir.h"
8 #include "vtn_private.h"
9 
10 static enum glsl_cmat_use
vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)11 vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
12 {
13    switch (use) {
14    case SpvCooperativeMatrixUseMatrixAKHR:
15       return GLSL_CMAT_USE_A;
16    case SpvCooperativeMatrixUseMatrixBKHR:
17       return GLSL_CMAT_USE_B;
18    case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
19       return GLSL_CMAT_USE_ACCUMULATOR;
20    default:
21       unreachable("Unexpected cooperative matrix use");
22    }
23 }
24 
25 void
vtn_handle_cooperative_type(struct vtn_builder * b,struct vtn_value * val,SpvOp opcode,const uint32_t * w,unsigned count)26 vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
27                             SpvOp opcode, const uint32_t *w, unsigned count)
28 {
29    vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
30 
31    b->shader->info.cs.has_cooperative_matrix = true;
32 
33    struct vtn_type *component_type = vtn_get_type(b, w[2]);
34 
35    const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
36    const uint32_t rows = vtn_constant_uint(b, w[4]);
37    const uint32_t cols = vtn_constant_uint(b, w[5]);
38 
39    vtn_assert(rows < 256);
40    vtn_assert(cols < 256);
41 
42    enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
43 
44    val->type->base_type = vtn_base_type_cooperative_matrix;
45    vtn_fail_if(!glsl_type_is_numeric(component_type->type),
46                "OpTypeCooperativeMatrixKHR "
47                "Component Type must be a scalar numerical type.");
48 
49    val->type->desc.element_type = glsl_get_base_type(component_type->type);
50    val->type->desc.scope = scope;
51    val->type->desc.rows = rows;
52    val->type->desc.cols = cols;
53    val->type->desc.use = use;
54 
55    val->type->type = glsl_cmat_type(&val->type->desc);
56    val->type->component_type = component_type;
57 }
58 
59 static enum glsl_matrix_layout
vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)60 vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
61 {
62    switch (layout) {
63    case SpvCooperativeMatrixLayoutRowMajorKHR:
64       return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
65    case SpvCooperativeMatrixLayoutColumnMajorKHR:
66       return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
67    default:
68       unreachable("Unexpected cooperative matrix layout");
69    }
70 }
71 
72 nir_deref_instr *
vtn_create_cmat_temporary(struct vtn_builder * b,const struct glsl_type * t,const char * name)73 vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
74 {
75    nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
76    return nir_build_deref_var(&b->nb, var);
77 }
78 
79 static nir_deref_instr *
vtn_get_cmat_deref(struct vtn_builder * b,uint32_t value_id)80 vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
81 {
82    nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
83    vtn_assert(glsl_type_is_cmat(deref->type));
84    return deref;
85 }
86 
87 void
vtn_handle_cooperative_instruction(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)88 vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
89                                    const uint32_t *w, unsigned count)
90 {
91    switch (opcode) {
92    case SpvOpCooperativeMatrixLoadKHR: {
93       struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
94       struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
95       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
96 
97       const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
98       nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
99 
100       SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
101       if (count > 6) {
102          unsigned idx = 6, alignment;
103          SpvScope scope;
104          vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
105          vtn_emit_make_visible_barrier(b, access, scope, src->mode);
106       }
107 
108       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
109       nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
110                     .matrix_layout = vtn_matrix_layout_to_glsl(layout));
111       vtn_push_var_ssa(b, w[2], dst->var);
112       break;
113    }
114 
115    case SpvOpCooperativeMatrixStoreKHR: {
116       struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
117       struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
118 
119       const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
120       nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
121 
122       SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
123       if (count > 5) {
124          unsigned idx = 5, alignment;
125          SpvScope scope;
126          vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
127          vtn_emit_make_available_barrier(b, access, scope, dest->mode);
128       }
129 
130       nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
131       nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
132                      .matrix_layout = vtn_matrix_layout_to_glsl(layout));
133       break;
134    }
135 
136    case SpvOpCooperativeMatrixLengthKHR: {
137       struct vtn_type *type = vtn_get_type(b, w[3]);
138       nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
139       vtn_push_nir_ssa(b, w[2], def);
140       break;
141    }
142 
143    case SpvOpCooperativeMatrixMulAddKHR: {
144       nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
145       nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
146       nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
147 
148       const uint32_t operands = count > 6 ? w[6] : 0;
149       const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
150       const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
151                                                SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
152                                                SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
153                                                SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
154 
155       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
156       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
157       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
158       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
159 
160       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
161       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
162 
163       nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
164                       .saturate = saturate,
165                       .cmat_signed_mask = signed_mask);
166 
167       vtn_push_var_ssa(b, w[2], dst->var);
168       break;
169    }
170 
171    case SpvOpBitcast: {
172       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
173       vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
174       nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
175 
176       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
177       nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
178       vtn_push_var_ssa(b, w[2], dst->var);
179       break;
180    }
181 
182    default:
183       unreachable("Unexpected opcode for cooperative matrix instruction");
184    }
185 }
186 
187 void
/* Translate the ALU-style SPIR-V opcodes that operate element-wise on
 * cooperative matrices: conversions/negation (unary), the arithmetic
 * binaries, and OpMatrixTimesScalar.  Results are produced into fresh
 * local temporaries and pushed as variable-backed SSA values.
 */
void
vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
                           const struct glsl_type *dest_type, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   vtn_assert(glsl_type_is_cmat(dest_type));

   switch (opcode) {
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpUConvert:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpFNegate:
   case SpvOpSNegate: {
      struct vtn_type *result_type = vtn_get_type(b, w[1]);
      nir_deref_instr *operand = vtn_get_cmat_deref(b, w[3]);

      /* Bit sizes of the element types pick the exact conversion op. */
      const unsigned src_bits =
         glsl_get_bit_size(glsl_get_cmat_element(operand->type));
      const unsigned dst_bits =
         glsl_get_bit_size(glsl_get_cmat_element(result_type->type));

      bool unused = false;
      const nir_op alu = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &unused, &unused,
                                                         src_bits, dst_bits);

      nir_deref_instr *result =
         vtn_create_cmat_temporary(b, result_type->type, "cmat_unary");
      nir_cmat_unary_op(&b->nb, &result->def, &operand->def,
                        .alu_op = alu);
      vtn_push_var_ssa(b, w[2], result->var);
      break;
   }

   case SpvOpFAdd:
   case SpvOpFSub:
   case SpvOpFMul:
   case SpvOpFDiv:
   case SpvOpIAdd:
   case SpvOpISub:
   case SpvOpIMul:
   case SpvOpSDiv:
   case SpvOpUDiv: {
      bool unused = false;
      const nir_op alu = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &unused, &unused, 0, 0);

      struct vtn_type *result_type = vtn_get_type(b, w[1]);
      nir_deref_instr *lhs = vtn_get_cmat_deref(b, w[3]);
      nir_deref_instr *rhs = vtn_get_cmat_deref(b, w[4]);

      nir_deref_instr *result =
         vtn_create_cmat_temporary(b, result_type->type, "cmat_binary");
      nir_cmat_binary_op(&b->nb, &result->def, &lhs->def, &rhs->def,
                         .alu_op = alu);
      vtn_push_var_ssa(b, w[2], result->var);
      break;
   }

   case SpvOpMatrixTimesScalar: {
      struct vtn_type *result_type = vtn_get_type(b, w[1]);
      nir_deref_instr *operand = vtn_get_cmat_deref(b, w[3]);

      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
      vtn_assert(glsl_type_is_scalar(scalar->type));
      /* Pick integer vs. float multiply from the scalar's type. */
      const nir_op alu =
         glsl_type_is_integer(scalar->type) ? nir_op_imul : nir_op_fmul;

      nir_deref_instr *result =
         vtn_create_cmat_temporary(b, result_type->type, "cmat_times_scalar");
      nir_cmat_scalar_op(&b->nb, &result->def, &operand->def, scalar->def,
                         .alu_op = alu);
      vtn_push_var_ssa(b, w[2], result->var);
      break;
   }

   default:
      unreachable("invalid cooperative matrix alu instruction");
   }
}
263 
264 struct vtn_ssa_value *
vtn_cooperative_matrix_extract(struct vtn_builder * b,struct vtn_ssa_value * mat,const uint32_t * indices,unsigned num_indices)265 vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
266                                const uint32_t *indices, unsigned num_indices)
267 {
268    vtn_assert(glsl_type_is_cmat(mat->type));
269    nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
270 
271    vtn_assert(num_indices == 1);
272    nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
273 
274    const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
275    struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
276    ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
277    return ret;
278 }
279 
280 struct vtn_ssa_value *
vtn_cooperative_matrix_insert(struct vtn_builder * b,struct vtn_ssa_value * mat,struct vtn_ssa_value * insert,const uint32_t * indices,unsigned num_indices)281 vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
282                               struct vtn_ssa_value *insert, const uint32_t *indices,
283                               unsigned num_indices)
284 {
285    vtn_assert(glsl_type_is_cmat(mat->type));
286    nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
287 
288    vtn_assert(num_indices == 1);
289    nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
290 
291    nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
292    nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
293 
294    struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
295    vtn_set_ssa_value_var(b, ret, dst->var);
296    return ret;
297 }
298