xref: /aosp_15_r20/external/mesa3d/src/compiler/spirv/vtn_cmat.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1*61046927SAndroid Build Coastguard Worker /*
2*61046927SAndroid Build Coastguard Worker  * Copyright 2023 Intel Corporation
3*61046927SAndroid Build Coastguard Worker  * SPDX-License-Identifier: MIT
4*61046927SAndroid Build Coastguard Worker  */
5*61046927SAndroid Build Coastguard Worker 
6*61046927SAndroid Build Coastguard Worker #include "glsl_types.h"
7*61046927SAndroid Build Coastguard Worker #include "nir.h"
8*61046927SAndroid Build Coastguard Worker #include "vtn_private.h"
9*61046927SAndroid Build Coastguard Worker 
10*61046927SAndroid Build Coastguard Worker static enum glsl_cmat_use
vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)11*61046927SAndroid Build Coastguard Worker vtn_cooperative_matrix_use_to_glsl(SpvCooperativeMatrixUse use)
12*61046927SAndroid Build Coastguard Worker {
13*61046927SAndroid Build Coastguard Worker    switch (use) {
14*61046927SAndroid Build Coastguard Worker    case SpvCooperativeMatrixUseMatrixAKHR:
15*61046927SAndroid Build Coastguard Worker       return GLSL_CMAT_USE_A;
16*61046927SAndroid Build Coastguard Worker    case SpvCooperativeMatrixUseMatrixBKHR:
17*61046927SAndroid Build Coastguard Worker       return GLSL_CMAT_USE_B;
18*61046927SAndroid Build Coastguard Worker    case SpvCooperativeMatrixUseMatrixAccumulatorKHR:
19*61046927SAndroid Build Coastguard Worker       return GLSL_CMAT_USE_ACCUMULATOR;
20*61046927SAndroid Build Coastguard Worker    default:
21*61046927SAndroid Build Coastguard Worker       unreachable("Unexpected cooperative matrix use");
22*61046927SAndroid Build Coastguard Worker    }
23*61046927SAndroid Build Coastguard Worker }
24*61046927SAndroid Build Coastguard Worker 
25*61046927SAndroid Build Coastguard Worker void
vtn_handle_cooperative_type(struct vtn_builder * b,struct vtn_value * val,SpvOp opcode,const uint32_t * w,unsigned count)26*61046927SAndroid Build Coastguard Worker vtn_handle_cooperative_type(struct vtn_builder *b, struct vtn_value *val,
27*61046927SAndroid Build Coastguard Worker                             SpvOp opcode, const uint32_t *w, unsigned count)
28*61046927SAndroid Build Coastguard Worker {
29*61046927SAndroid Build Coastguard Worker    vtn_assert(opcode == SpvOpTypeCooperativeMatrixKHR);
30*61046927SAndroid Build Coastguard Worker 
31*61046927SAndroid Build Coastguard Worker    b->shader->info.cs.has_cooperative_matrix = true;
32*61046927SAndroid Build Coastguard Worker 
33*61046927SAndroid Build Coastguard Worker    struct vtn_type *component_type = vtn_get_type(b, w[2]);
34*61046927SAndroid Build Coastguard Worker 
35*61046927SAndroid Build Coastguard Worker    const mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[3]));
36*61046927SAndroid Build Coastguard Worker    const uint32_t rows = vtn_constant_uint(b, w[4]);
37*61046927SAndroid Build Coastguard Worker    const uint32_t cols = vtn_constant_uint(b, w[5]);
38*61046927SAndroid Build Coastguard Worker 
39*61046927SAndroid Build Coastguard Worker    vtn_assert(rows < 256);
40*61046927SAndroid Build Coastguard Worker    vtn_assert(cols < 256);
41*61046927SAndroid Build Coastguard Worker 
42*61046927SAndroid Build Coastguard Worker    enum glsl_cmat_use use = vtn_cooperative_matrix_use_to_glsl(vtn_constant_uint(b, w[6]));
43*61046927SAndroid Build Coastguard Worker 
44*61046927SAndroid Build Coastguard Worker    val->type->base_type = vtn_base_type_cooperative_matrix;
45*61046927SAndroid Build Coastguard Worker    vtn_fail_if(!glsl_type_is_numeric(component_type->type),
46*61046927SAndroid Build Coastguard Worker                "OpTypeCooperativeMatrixKHR "
47*61046927SAndroid Build Coastguard Worker                "Component Type must be a scalar numerical type.");
48*61046927SAndroid Build Coastguard Worker 
49*61046927SAndroid Build Coastguard Worker    val->type->desc.element_type = glsl_get_base_type(component_type->type);
50*61046927SAndroid Build Coastguard Worker    val->type->desc.scope = scope;
51*61046927SAndroid Build Coastguard Worker    val->type->desc.rows = rows;
52*61046927SAndroid Build Coastguard Worker    val->type->desc.cols = cols;
53*61046927SAndroid Build Coastguard Worker    val->type->desc.use = use;
54*61046927SAndroid Build Coastguard Worker 
55*61046927SAndroid Build Coastguard Worker    val->type->type = glsl_cmat_type(&val->type->desc);
56*61046927SAndroid Build Coastguard Worker    val->type->component_type = component_type;
57*61046927SAndroid Build Coastguard Worker }
58*61046927SAndroid Build Coastguard Worker 
59*61046927SAndroid Build Coastguard Worker static enum glsl_matrix_layout
vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)60*61046927SAndroid Build Coastguard Worker vtn_matrix_layout_to_glsl(SpvCooperativeMatrixLayout layout)
61*61046927SAndroid Build Coastguard Worker {
62*61046927SAndroid Build Coastguard Worker    switch (layout) {
63*61046927SAndroid Build Coastguard Worker    case SpvCooperativeMatrixLayoutRowMajorKHR:
64*61046927SAndroid Build Coastguard Worker       return GLSL_MATRIX_LAYOUT_ROW_MAJOR;
65*61046927SAndroid Build Coastguard Worker    case SpvCooperativeMatrixLayoutColumnMajorKHR:
66*61046927SAndroid Build Coastguard Worker       return GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
67*61046927SAndroid Build Coastguard Worker    default:
68*61046927SAndroid Build Coastguard Worker       unreachable("Unexpected cooperative matrix layout");
69*61046927SAndroid Build Coastguard Worker    }
70*61046927SAndroid Build Coastguard Worker }
71*61046927SAndroid Build Coastguard Worker 
72*61046927SAndroid Build Coastguard Worker nir_deref_instr *
vtn_create_cmat_temporary(struct vtn_builder * b,const struct glsl_type * t,const char * name)73*61046927SAndroid Build Coastguard Worker vtn_create_cmat_temporary(struct vtn_builder *b, const struct glsl_type *t, const char *name)
74*61046927SAndroid Build Coastguard Worker {
75*61046927SAndroid Build Coastguard Worker    nir_variable *var = nir_local_variable_create(b->nb.impl, t, name);
76*61046927SAndroid Build Coastguard Worker    return nir_build_deref_var(&b->nb, var);
77*61046927SAndroid Build Coastguard Worker }
78*61046927SAndroid Build Coastguard Worker 
79*61046927SAndroid Build Coastguard Worker static nir_deref_instr *
vtn_get_cmat_deref(struct vtn_builder * b,uint32_t value_id)80*61046927SAndroid Build Coastguard Worker vtn_get_cmat_deref(struct vtn_builder *b, uint32_t value_id)
81*61046927SAndroid Build Coastguard Worker {
82*61046927SAndroid Build Coastguard Worker    nir_deref_instr *deref = vtn_get_deref_for_id(b, value_id);
83*61046927SAndroid Build Coastguard Worker    vtn_assert(glsl_type_is_cmat(deref->type));
84*61046927SAndroid Build Coastguard Worker    return deref;
85*61046927SAndroid Build Coastguard Worker }
86*61046927SAndroid Build Coastguard Worker 
87*61046927SAndroid Build Coastguard Worker void
vtn_handle_cooperative_instruction(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)88*61046927SAndroid Build Coastguard Worker vtn_handle_cooperative_instruction(struct vtn_builder *b, SpvOp opcode,
89*61046927SAndroid Build Coastguard Worker                                    const uint32_t *w, unsigned count)
90*61046927SAndroid Build Coastguard Worker {
91*61046927SAndroid Build Coastguard Worker    switch (opcode) {
92*61046927SAndroid Build Coastguard Worker    case SpvOpCooperativeMatrixLoadKHR: {
93*61046927SAndroid Build Coastguard Worker       struct vtn_value *src_val = vtn_value(b, w[3], vtn_value_type_pointer);
94*61046927SAndroid Build Coastguard Worker       struct vtn_pointer *src = vtn_value_to_pointer(b, src_val);
95*61046927SAndroid Build Coastguard Worker       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
96*61046927SAndroid Build Coastguard Worker 
97*61046927SAndroid Build Coastguard Worker       const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[4]);
98*61046927SAndroid Build Coastguard Worker       nir_def *stride = count > 5 ? vtn_get_nir_ssa(b, w[5]) : nir_imm_zero(&b->nb, 1, 32);
99*61046927SAndroid Build Coastguard Worker 
100*61046927SAndroid Build Coastguard Worker       SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
101*61046927SAndroid Build Coastguard Worker       if (count > 6) {
102*61046927SAndroid Build Coastguard Worker          unsigned idx = 6, alignment;
103*61046927SAndroid Build Coastguard Worker          SpvScope scope;
104*61046927SAndroid Build Coastguard Worker          vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, NULL, &scope);
105*61046927SAndroid Build Coastguard Worker          vtn_emit_make_visible_barrier(b, access, scope, src->mode);
106*61046927SAndroid Build Coastguard Worker       }
107*61046927SAndroid Build Coastguard Worker 
108*61046927SAndroid Build Coastguard Worker       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
109*61046927SAndroid Build Coastguard Worker       nir_cmat_load(&b->nb, &dst->def, vtn_pointer_to_ssa(b, src), stride,
110*61046927SAndroid Build Coastguard Worker                     .matrix_layout = vtn_matrix_layout_to_glsl(layout));
111*61046927SAndroid Build Coastguard Worker       vtn_push_var_ssa(b, w[2], dst->var);
112*61046927SAndroid Build Coastguard Worker       break;
113*61046927SAndroid Build Coastguard Worker    }
114*61046927SAndroid Build Coastguard Worker 
115*61046927SAndroid Build Coastguard Worker    case SpvOpCooperativeMatrixStoreKHR: {
116*61046927SAndroid Build Coastguard Worker       struct vtn_value *dest_val = vtn_value(b, w[1], vtn_value_type_pointer);
117*61046927SAndroid Build Coastguard Worker       struct vtn_pointer *dest = vtn_value_to_pointer(b, dest_val);
118*61046927SAndroid Build Coastguard Worker 
119*61046927SAndroid Build Coastguard Worker       const SpvCooperativeMatrixLayout layout = vtn_constant_uint(b, w[3]);
120*61046927SAndroid Build Coastguard Worker       nir_def *stride = count > 4 ? vtn_get_nir_ssa(b, w[4]) : nir_imm_zero(&b->nb, 1, 32);
121*61046927SAndroid Build Coastguard Worker 
122*61046927SAndroid Build Coastguard Worker       SpvMemoryAccessMask access = SpvMemoryAccessMaskNone;
123*61046927SAndroid Build Coastguard Worker       if (count > 5) {
124*61046927SAndroid Build Coastguard Worker          unsigned idx = 5, alignment;
125*61046927SAndroid Build Coastguard Worker          SpvScope scope;
126*61046927SAndroid Build Coastguard Worker          vtn_get_mem_operands(b, w, count, &idx, &access, &alignment, &scope, NULL);
127*61046927SAndroid Build Coastguard Worker          vtn_emit_make_available_barrier(b, access, scope, dest->mode);
128*61046927SAndroid Build Coastguard Worker       }
129*61046927SAndroid Build Coastguard Worker 
130*61046927SAndroid Build Coastguard Worker       nir_deref_instr *src = vtn_get_cmat_deref(b, w[2]);
131*61046927SAndroid Build Coastguard Worker       nir_cmat_store(&b->nb, vtn_pointer_to_ssa(b, dest), &src->def, stride,
132*61046927SAndroid Build Coastguard Worker                      .matrix_layout = vtn_matrix_layout_to_glsl(layout));
133*61046927SAndroid Build Coastguard Worker       break;
134*61046927SAndroid Build Coastguard Worker    }
135*61046927SAndroid Build Coastguard Worker 
136*61046927SAndroid Build Coastguard Worker    case SpvOpCooperativeMatrixLengthKHR: {
137*61046927SAndroid Build Coastguard Worker       struct vtn_type *type = vtn_get_type(b, w[3]);
138*61046927SAndroid Build Coastguard Worker       nir_def *def = nir_cmat_length(&b->nb, .cmat_desc = type->desc);
139*61046927SAndroid Build Coastguard Worker       vtn_push_nir_ssa(b, w[2], def);
140*61046927SAndroid Build Coastguard Worker       break;
141*61046927SAndroid Build Coastguard Worker    }
142*61046927SAndroid Build Coastguard Worker 
143*61046927SAndroid Build Coastguard Worker    case SpvOpCooperativeMatrixMulAddKHR: {
144*61046927SAndroid Build Coastguard Worker       nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
145*61046927SAndroid Build Coastguard Worker       nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
146*61046927SAndroid Build Coastguard Worker       nir_deref_instr *mat_c = vtn_get_cmat_deref(b, w[5]);
147*61046927SAndroid Build Coastguard Worker 
148*61046927SAndroid Build Coastguard Worker       const uint32_t operands = count > 6 ? w[6] : 0;
149*61046927SAndroid Build Coastguard Worker       const bool saturate = operands & SpvCooperativeMatrixOperandsSaturatingAccumulationKHRMask;
150*61046927SAndroid Build Coastguard Worker       const unsigned signed_mask = operands & (SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask |
151*61046927SAndroid Build Coastguard Worker                                                SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask |
152*61046927SAndroid Build Coastguard Worker                                                SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask |
153*61046927SAndroid Build Coastguard Worker                                                SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask);
154*61046927SAndroid Build Coastguard Worker 
155*61046927SAndroid Build Coastguard Worker       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixASignedComponentsKHRMask == NIR_CMAT_A_SIGNED);
156*61046927SAndroid Build Coastguard Worker       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixBSignedComponentsKHRMask == NIR_CMAT_B_SIGNED);
157*61046927SAndroid Build Coastguard Worker       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixCSignedComponentsKHRMask == NIR_CMAT_C_SIGNED);
158*61046927SAndroid Build Coastguard Worker       STATIC_ASSERT((unsigned)SpvCooperativeMatrixOperandsMatrixResultSignedComponentsKHRMask == NIR_CMAT_RESULT_SIGNED);
159*61046927SAndroid Build Coastguard Worker 
160*61046927SAndroid Build Coastguard Worker       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
161*61046927SAndroid Build Coastguard Worker       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_muladd");
162*61046927SAndroid Build Coastguard Worker 
163*61046927SAndroid Build Coastguard Worker       nir_cmat_muladd(&b->nb, &dst->def, &mat_a->def, &mat_b->def, &mat_c->def,
164*61046927SAndroid Build Coastguard Worker                       .saturate = saturate,
165*61046927SAndroid Build Coastguard Worker                       .cmat_signed_mask = signed_mask);
166*61046927SAndroid Build Coastguard Worker 
167*61046927SAndroid Build Coastguard Worker       vtn_push_var_ssa(b, w[2], dst->var);
168*61046927SAndroid Build Coastguard Worker       break;
169*61046927SAndroid Build Coastguard Worker    }
170*61046927SAndroid Build Coastguard Worker 
171*61046927SAndroid Build Coastguard Worker    case SpvOpBitcast: {
172*61046927SAndroid Build Coastguard Worker       struct vtn_type *dst_type = vtn_get_type(b, w[1]);
173*61046927SAndroid Build Coastguard Worker       vtn_assert(dst_type->base_type == vtn_base_type_cooperative_matrix);
174*61046927SAndroid Build Coastguard Worker       nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
175*61046927SAndroid Build Coastguard Worker 
176*61046927SAndroid Build Coastguard Worker       nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_bitcast");
177*61046927SAndroid Build Coastguard Worker       nir_cmat_bitcast(&b->nb, &dst->def, &src->def);
178*61046927SAndroid Build Coastguard Worker       vtn_push_var_ssa(b, w[2], dst->var);
179*61046927SAndroid Build Coastguard Worker       break;
180*61046927SAndroid Build Coastguard Worker    }
181*61046927SAndroid Build Coastguard Worker 
182*61046927SAndroid Build Coastguard Worker    default:
183*61046927SAndroid Build Coastguard Worker       unreachable("Unexpected opcode for cooperative matrix instruction");
184*61046927SAndroid Build Coastguard Worker    }
185*61046927SAndroid Build Coastguard Worker }
186*61046927SAndroid Build Coastguard Worker 
187*61046927SAndroid Build Coastguard Worker void
vtn_handle_cooperative_alu(struct vtn_builder * b,struct vtn_value * dest_val,const struct glsl_type * dest_type,SpvOp opcode,const uint32_t * w,unsigned count)188*61046927SAndroid Build Coastguard Worker vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
189*61046927SAndroid Build Coastguard Worker                            const struct glsl_type *dest_type, SpvOp opcode,
190*61046927SAndroid Build Coastguard Worker                            const uint32_t *w, unsigned count)
191*61046927SAndroid Build Coastguard Worker {
192*61046927SAndroid Build Coastguard Worker       vtn_assert(glsl_type_is_cmat(dest_type));
193*61046927SAndroid Build Coastguard Worker 
194*61046927SAndroid Build Coastguard Worker       switch (opcode) {
195*61046927SAndroid Build Coastguard Worker       case SpvOpConvertFToU:
196*61046927SAndroid Build Coastguard Worker       case SpvOpConvertFToS:
197*61046927SAndroid Build Coastguard Worker       case SpvOpConvertSToF:
198*61046927SAndroid Build Coastguard Worker       case SpvOpConvertUToF:
199*61046927SAndroid Build Coastguard Worker       case SpvOpUConvert:
200*61046927SAndroid Build Coastguard Worker       case SpvOpSConvert:
201*61046927SAndroid Build Coastguard Worker       case SpvOpFConvert:
202*61046927SAndroid Build Coastguard Worker       case SpvOpFNegate:
203*61046927SAndroid Build Coastguard Worker       case SpvOpSNegate: {
204*61046927SAndroid Build Coastguard Worker          struct vtn_type *dst_type = vtn_get_type(b, w[1]);
205*61046927SAndroid Build Coastguard Worker          nir_deref_instr *src = vtn_get_cmat_deref(b, w[3]);
206*61046927SAndroid Build Coastguard Worker 
207*61046927SAndroid Build Coastguard Worker          unsigned src_bit_size = glsl_get_bit_size(glsl_get_cmat_element(src->type));
208*61046927SAndroid Build Coastguard Worker          unsigned dst_bit_size = glsl_get_bit_size(glsl_get_cmat_element(dst_type->type));
209*61046927SAndroid Build Coastguard Worker 
210*61046927SAndroid Build Coastguard Worker          bool ignored = false;
211*61046927SAndroid Build Coastguard Worker          nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored,
212*61046927SAndroid Build Coastguard Worker                                                      src_bit_size, dst_bit_size);
213*61046927SAndroid Build Coastguard Worker 
214*61046927SAndroid Build Coastguard Worker          nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
215*61046927SAndroid Build Coastguard Worker          nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
216*61046927SAndroid Build Coastguard Worker                            .alu_op = op);
217*61046927SAndroid Build Coastguard Worker          vtn_push_var_ssa(b, w[2], dst->var);
218*61046927SAndroid Build Coastguard Worker          break;
219*61046927SAndroid Build Coastguard Worker       }
220*61046927SAndroid Build Coastguard Worker 
221*61046927SAndroid Build Coastguard Worker       case SpvOpFAdd:
222*61046927SAndroid Build Coastguard Worker       case SpvOpFSub:
223*61046927SAndroid Build Coastguard Worker       case SpvOpFMul:
224*61046927SAndroid Build Coastguard Worker       case SpvOpFDiv:
225*61046927SAndroid Build Coastguard Worker       case SpvOpIAdd:
226*61046927SAndroid Build Coastguard Worker       case SpvOpISub:
227*61046927SAndroid Build Coastguard Worker       case SpvOpIMul:
228*61046927SAndroid Build Coastguard Worker       case SpvOpSDiv:
229*61046927SAndroid Build Coastguard Worker       case SpvOpUDiv: {
230*61046927SAndroid Build Coastguard Worker          bool ignored = false;
231*61046927SAndroid Build Coastguard Worker          nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &ignored, &ignored, 0, 0);
232*61046927SAndroid Build Coastguard Worker 
233*61046927SAndroid Build Coastguard Worker          struct vtn_type *dst_type = vtn_get_type(b, w[1]);
234*61046927SAndroid Build Coastguard Worker          nir_deref_instr *mat_a = vtn_get_cmat_deref(b, w[3]);
235*61046927SAndroid Build Coastguard Worker          nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
236*61046927SAndroid Build Coastguard Worker 
237*61046927SAndroid Build Coastguard Worker          nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
238*61046927SAndroid Build Coastguard Worker          nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,
239*61046927SAndroid Build Coastguard Worker                             .alu_op = op);
240*61046927SAndroid Build Coastguard Worker          vtn_push_var_ssa(b, w[2], dst->var);
241*61046927SAndroid Build Coastguard Worker          break;
242*61046927SAndroid Build Coastguard Worker       }
243*61046927SAndroid Build Coastguard Worker 
244*61046927SAndroid Build Coastguard Worker       case SpvOpMatrixTimesScalar: {
245*61046927SAndroid Build Coastguard Worker          struct vtn_type *dst_type = vtn_get_type(b, w[1]);
246*61046927SAndroid Build Coastguard Worker          nir_deref_instr *mat = vtn_get_cmat_deref(b, w[3]);
247*61046927SAndroid Build Coastguard Worker 
248*61046927SAndroid Build Coastguard Worker          struct vtn_ssa_value *scalar_val = vtn_ssa_value(b, w[4]);
249*61046927SAndroid Build Coastguard Worker          vtn_assert(glsl_type_is_scalar(scalar_val->type));
250*61046927SAndroid Build Coastguard Worker          nir_op op = glsl_type_is_integer(scalar_val->type) ? nir_op_imul : nir_op_fmul;
251*61046927SAndroid Build Coastguard Worker 
252*61046927SAndroid Build Coastguard Worker          nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_times_scalar");
253*61046927SAndroid Build Coastguard Worker          nir_cmat_scalar_op(&b->nb, &dst->def, &mat->def, scalar_val->def,
254*61046927SAndroid Build Coastguard Worker                             .alu_op = op);
255*61046927SAndroid Build Coastguard Worker          vtn_push_var_ssa(b, w[2], dst->var);
256*61046927SAndroid Build Coastguard Worker          break;
257*61046927SAndroid Build Coastguard Worker       }
258*61046927SAndroid Build Coastguard Worker 
259*61046927SAndroid Build Coastguard Worker       default:
260*61046927SAndroid Build Coastguard Worker          unreachable("invalid cooperative matrix alu instruction");
261*61046927SAndroid Build Coastguard Worker       }
262*61046927SAndroid Build Coastguard Worker }
263*61046927SAndroid Build Coastguard Worker 
264*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *
vtn_cooperative_matrix_extract(struct vtn_builder * b,struct vtn_ssa_value * mat,const uint32_t * indices,unsigned num_indices)265*61046927SAndroid Build Coastguard Worker vtn_cooperative_matrix_extract(struct vtn_builder *b, struct vtn_ssa_value *mat,
266*61046927SAndroid Build Coastguard Worker                                const uint32_t *indices, unsigned num_indices)
267*61046927SAndroid Build Coastguard Worker {
268*61046927SAndroid Build Coastguard Worker    vtn_assert(glsl_type_is_cmat(mat->type));
269*61046927SAndroid Build Coastguard Worker    nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
270*61046927SAndroid Build Coastguard Worker 
271*61046927SAndroid Build Coastguard Worker    vtn_assert(num_indices == 1);
272*61046927SAndroid Build Coastguard Worker    nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
273*61046927SAndroid Build Coastguard Worker 
274*61046927SAndroid Build Coastguard Worker    const struct glsl_type *element_type = glsl_get_cmat_element(mat->type);
275*61046927SAndroid Build Coastguard Worker    struct vtn_ssa_value *ret = vtn_create_ssa_value(b, element_type);
276*61046927SAndroid Build Coastguard Worker    ret->def = nir_cmat_extract(&b->nb, glsl_get_bit_size(element_type), &mat_deref->def, index);
277*61046927SAndroid Build Coastguard Worker    return ret;
278*61046927SAndroid Build Coastguard Worker }
279*61046927SAndroid Build Coastguard Worker 
280*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *
vtn_cooperative_matrix_insert(struct vtn_builder * b,struct vtn_ssa_value * mat,struct vtn_ssa_value * insert,const uint32_t * indices,unsigned num_indices)281*61046927SAndroid Build Coastguard Worker vtn_cooperative_matrix_insert(struct vtn_builder *b, struct vtn_ssa_value *mat,
282*61046927SAndroid Build Coastguard Worker                               struct vtn_ssa_value *insert, const uint32_t *indices,
283*61046927SAndroid Build Coastguard Worker                               unsigned num_indices)
284*61046927SAndroid Build Coastguard Worker {
285*61046927SAndroid Build Coastguard Worker    vtn_assert(glsl_type_is_cmat(mat->type));
286*61046927SAndroid Build Coastguard Worker    nir_deref_instr *mat_deref = vtn_get_deref_for_ssa_value(b, mat);
287*61046927SAndroid Build Coastguard Worker 
288*61046927SAndroid Build Coastguard Worker    vtn_assert(num_indices == 1);
289*61046927SAndroid Build Coastguard Worker    nir_def *index = nir_imm_intN_t(&b->nb, indices[0], 32);
290*61046927SAndroid Build Coastguard Worker 
291*61046927SAndroid Build Coastguard Worker    nir_deref_instr *dst = vtn_create_cmat_temporary(b, mat_deref->type, "cmat_insert");
292*61046927SAndroid Build Coastguard Worker    nir_cmat_insert(&b->nb, &dst->def, insert->def, &mat_deref->def, index);
293*61046927SAndroid Build Coastguard Worker 
294*61046927SAndroid Build Coastguard Worker    struct vtn_ssa_value *ret = vtn_create_ssa_value(b, dst->type);
295*61046927SAndroid Build Coastguard Worker    vtn_set_ssa_value_var(b, ret, dst->var);
296*61046927SAndroid Build Coastguard Worker    return ret;
297*61046927SAndroid Build Coastguard Worker }
298