/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

#include "glsl_types.h"
#include "nir.h"
#include "vtn_private.h"

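/* Translation of SPV_KHR_cooperative_matrix types and instructions into
 * NIR cooperative matrix (cmat) intrinsics.
 */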
static enum glsl_cmat_use
vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
{
   switch (use) {
   case SpvCooperativeMatrixUseMatrixAKHR:
      return GLSL_CMAT_USE_A;
   case SpvCooperativeMatrixUseMatrixBKHR:
      return GLSL_CMAT_USE_B;
   case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
      return GLSL_CMAT_USE_ACCUMULATOR;
   default:
      unreachable("Unexpected cooperative matrix use");
   }
}

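/* Handles OpTypeCooperativeMatrixKHR: translates the component type, scope,
 * dimensions, and use into a glsl_cmat_description and builds the matching
 * GLSL cooperative matrix type for the value.
 */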
void
vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
                            SpvOp opcode, const uint32_t *w, unsigned count)
{
   vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);

   b->shader->info.cs.has_cooperative_matrix = true;

   struct vtn_type *component_type = vtn_get_type(b, w[2]);

   const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
   const uint32_t rows = vtn_constant_uint(b, w[4]);
   const uint32_t cols = vtn_constant_uint(b, w[5]);

   vtn_assert(rows < 256);
   vtn_assert(cols < 256);

   enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));

   val->type->base_type = vtn_base_type_cooperative_matrix;
   vtn_fail_if(!glsl_type_is_numeric(component_type->type),
               "OpTypeCooperativeMatrixKHR "
               "Component Type must be a scalar numerical type.");

   val->type->desc.element_type = glsl_get_base_type(component_type->type);
   val->type->desc.scope = scope;
   val->type->desc.rows = rows;
   val->type->desc.cols = cols;
   val->type->desc.use = use;

   val->type->type = glsl_cmat_type(&val->type->desc);
   val->type->component_type = component_type;
}

static enum glsl_matrix_layout
vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
{
   switch (layout) {
   case SpvCooperativeMatrixLayoutRowMajorKHR:
      return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
   case SpvCooperativeMatrixLayoutColumnMajorKHR:
      return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
   default:
      unreachable("Unexpected cooperative matrix layout");
   }
}

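/* Cooperative matrix values are carried through NIR as derefs of local
 * variables rather than as plain SSA defs, since their layout is
 * implementation-defined.  This creates such a temporary and returns a
 * deref to it.
 */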
nir_deref_instr *
vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
{
   nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
   return nir_build_deref_var(&b->nb, var);
}

static nir_deref_instr *
vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
{
   nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
   vtn_assert(glsl_type_is_cmat(deref->type));
   return deref;
}

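/* Dispatches the SPV_KHR_cooperative_matrix instructions that operate on
 * matrix values: load, store, length query, multiply-add, and bitcast.
 */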
void
vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
                                   const uint32_t *w, unsigned count)
{
   switch (opcode) {
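   /* OpCooperativeMatrixLoadKHR: w[1] is the result type, w[2] the result,
    * w[3] the source pointer, w[4] the layout, w[5] an optional stride, and
    * w[6] onward optional memory operands.
    */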
   case SpvOpCooperativeMatrixLoadKHR: {
      struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
      struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);

      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
      nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);

      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
      if (count > 6) {
         unsigned idx = 6, alignment;
         SpvScope scope;
         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
         vtn_emit_make_visible_barrier(b, access, scope, src->mode);
      }

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_load");
      nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
                    .matrix_layout = vtn_matrix_layout_to_glsl(layout));
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

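   /* OpCooperativeMatrixStoreKHR: w[1] is the destination pointer, w[2] the
    * matrix object, w[3] the layout, w[4] an optional stride, and w[5]
    * onward optional memory operands.
    */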
   case SpvOpCooperativeMatrixStoreKHR: {
      struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
      struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);

      const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
      nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);

      SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
      if (count > 5) {
         unsigned idx = 5, alignment;
         SpvScope scope;
         vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
         vtn_emit_make_available_barrier(b, access, scope, dest->mode);
      }

      nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
      nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
                     .matrix_layout = vtn_matrix_layout_to_glsl(layout));
      break;
   }

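   /* OpCooperativeMatrixLengthKHR takes a cooperative matrix type operand
    * rather than a value, so the matrix description is passed to NIR as an
    * intrinsic index.
    */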
   case SpvOpCooperativeMatrixLengthKHR: {
      struct vtn_type *type = vtn_get_type(b, w[3]);
      nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
      vtn_push_nir_ssa(b, w[2], def);
      break;
   }

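   /* OpCooperativeMatrixMulAddKHR computes A * B + C.  The optional
    * operands word carries saturation and per-matrix signedness flags; the
    * SPIR-V signedness bits are expected to match NIR's cmat_signed mask
    * bit-for-bit, which the STATIC_ASSERTs below verify.
    */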
   case SpvOpCooperativeMatrixMulAddKHR: {
      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
      nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);

      const uint32_t operands = count > 6 ? w[6] : 0;
      const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
      const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
                                               SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
                                               SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
                                               SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);

      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
      STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);

      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");

      nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
                      .saturate = saturate,
                      .cmat_signed_mask = signed_mask);

      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

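   /* OpBitcast between cooperative matrix types cannot be expressed as a
    * plain ALU bitcast on an SSA value, so it is handled here and lowered
    * to a dedicated cmat_bitcast intrinsic.
    */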
   case SpvOpBitcast: {
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
      nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

   default:
      unreachable("Unexpected opcode for cooperative matrix instruction");
   }
}

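/* Handles the standard ALU opcodes when their operands are cooperative
 * matrices.  Each maps to an elementwise cmat intrinsic carrying the
 * equivalent nir_op.
 */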
void
vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
                           const struct glsl_type *dest_type, SpvOp opcode,
                           const uint32_t *w, unsigned count)
{
   vtn_assert(glsl_type_is_cmat(dest_type));

   switch (opcode) {
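   /* Conversions and negations: unary elementwise operations.  The bit
    * sizes of the source and destination element types select the exact
    * nir_op for the conversion.
    */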
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpUConvert:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpFNegate:
   case SpvOpSNegate: {
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);

      unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
      unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));

      bool ignored = false;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
                                                  src_bit_size, dst_bit_size);

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
      nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
                        .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

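   /* Elementwise binary arithmetic between two matrices of the same type. */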
   case SpvOpFAdd:
   case SpvOpFSub:
   case SpvOpFMul:
   case SpvOpFDiv:
   case SpvOpIAdd:
   case SpvOpISub:
   case SpvOpIMul:
   case SpvOpSDiv:
   case SpvOpUDiv: {
      bool ignored = false;
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);

      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
      nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
      nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
                         .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

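   /* OpMatrixTimesScalar scales every element by a scalar; the scalar's
    * type picks the integer or float multiply.
    */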
   case SpvOpMatrixTimesScalar: {
      struct vtn_type *dst_type = vtn_get_type(b, w[1]);
      nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);

      struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
      vtn_assert(glsl_type_is_scalar(scalar_val->type));
      nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;

      nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
      nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
                         .alu_op = op);
      vtn_push_var_ssa(b, w[2], dst->var);
      break;
   }

   default:
      unreachable("invalid cooperative matrix alu instruction");
   }
}

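/* Extracts a single matrix element at the given index, returning it as a
 * scalar SSA value.
 */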
struct vtn_ssa_value *
vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
                               const uint32_t *indices, unsigned num_indices)
{
   vtn_assert(glsl_type_is_cmat(mat->type));
   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);

   vtn_assert(num_indices == 1);
   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);

   const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
   ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
   return ret;
}

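/* Returns a new matrix value equal to mat with the element at the given
 * index replaced by insert.  The source matrix is left untouched: the
 * result is written to a fresh temporary.
 */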
struct vtn_ssa_value *
vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
                              struct vtn_ssa_value *insert, const uint32_t *indices,
                              unsigned num_indices)
{
   vtn_assert(glsl_type_is_cmat(mat->type));
   nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);

   vtn_assert(num_indices == 1);
   nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);

   nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
   nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);

   struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
   vtn_set_ssa_value_var(b, ret, dst->var);
   return ret;
}