1*61046927SAndroid Build Coastguard Worker /*
2*61046927SAndroid Build Coastguard Worker * Copyright © 2016 Intel Corporation
3*61046927SAndroid Build Coastguard Worker *
4*61046927SAndroid Build Coastguard Worker * Permission is hereby granted, free of charge, to any person obtaining a
5*61046927SAndroid Build Coastguard Worker * copy of this software and associated documentation files (the "Software"),
6*61046927SAndroid Build Coastguard Worker * to deal in the Software without restriction, including without limitation
7*61046927SAndroid Build Coastguard Worker * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8*61046927SAndroid Build Coastguard Worker * and/or sell copies of the Software, and to permit persons to whom the
9*61046927SAndroid Build Coastguard Worker * Software is furnished to do so, subject to the following conditions:
10*61046927SAndroid Build Coastguard Worker *
11*61046927SAndroid Build Coastguard Worker * The above copyright notice and this permission notice (including the next
12*61046927SAndroid Build Coastguard Worker * paragraph) shall be included in all copies or substantial portions of the
13*61046927SAndroid Build Coastguard Worker * Software.
14*61046927SAndroid Build Coastguard Worker *
15*61046927SAndroid Build Coastguard Worker * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16*61046927SAndroid Build Coastguard Worker * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17*61046927SAndroid Build Coastguard Worker * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18*61046927SAndroid Build Coastguard Worker * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19*61046927SAndroid Build Coastguard Worker * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20*61046927SAndroid Build Coastguard Worker * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21*61046927SAndroid Build Coastguard Worker * IN THE SOFTWARE.
22*61046927SAndroid Build Coastguard Worker */
23*61046927SAndroid Build Coastguard Worker
24*61046927SAndroid Build Coastguard Worker #include <math.h>
25*61046927SAndroid Build Coastguard Worker #include "vtn_private.h"
26*61046927SAndroid Build Coastguard Worker #include "spirv_info.h"
27*61046927SAndroid Build Coastguard Worker
28*61046927SAndroid Build Coastguard Worker /*
29*61046927SAndroid Build Coastguard Worker * Normally, column vectors in SPIR-V correspond to a single NIR SSA
30*61046927SAndroid Build Coastguard Worker * definition. But for matrix multiplies, we want to do one routine for
31*61046927SAndroid Build Coastguard Worker * multiplying a matrix by a matrix and then pretend that vectors are matrices
32*61046927SAndroid Build Coastguard Worker * with one column. So we "wrap" these things, and unwrap the result before we
33*61046927SAndroid Build Coastguard Worker * send it off.
34*61046927SAndroid Build Coastguard Worker */
35*61046927SAndroid Build Coastguard Worker
36*61046927SAndroid Build Coastguard Worker static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder * b,struct vtn_ssa_value * val)37*61046927SAndroid Build Coastguard Worker wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
38*61046927SAndroid Build Coastguard Worker {
39*61046927SAndroid Build Coastguard Worker if (val == NULL)
40*61046927SAndroid Build Coastguard Worker return NULL;
41*61046927SAndroid Build Coastguard Worker
42*61046927SAndroid Build Coastguard Worker if (glsl_type_is_matrix(val->type))
43*61046927SAndroid Build Coastguard Worker return val;
44*61046927SAndroid Build Coastguard Worker
45*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_zalloc(b, struct vtn_ssa_value);
46*61046927SAndroid Build Coastguard Worker dest->type = glsl_get_bare_type(val->type);
47*61046927SAndroid Build Coastguard Worker dest->elems = vtn_alloc_array(b, struct vtn_ssa_value *, 1);
48*61046927SAndroid Build Coastguard Worker dest->elems[0] = val;
49*61046927SAndroid Build Coastguard Worker
50*61046927SAndroid Build Coastguard Worker return dest;
51*61046927SAndroid Build Coastguard Worker }
52*61046927SAndroid Build Coastguard Worker
53*61046927SAndroid Build Coastguard Worker static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value * val)54*61046927SAndroid Build Coastguard Worker unwrap_matrix(struct vtn_ssa_value *val)
55*61046927SAndroid Build Coastguard Worker {
56*61046927SAndroid Build Coastguard Worker if (glsl_type_is_matrix(val->type))
57*61046927SAndroid Build Coastguard Worker return val;
58*61046927SAndroid Build Coastguard Worker
59*61046927SAndroid Build Coastguard Worker return val->elems[0];
60*61046927SAndroid Build Coastguard Worker }
61*61046927SAndroid Build Coastguard Worker
62*61046927SAndroid Build Coastguard Worker static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder * b,struct vtn_ssa_value * _src0,struct vtn_ssa_value * _src1)63*61046927SAndroid Build Coastguard Worker matrix_multiply(struct vtn_builder *b,
64*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
65*61046927SAndroid Build Coastguard Worker {
66*61046927SAndroid Build Coastguard Worker
67*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
68*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
69*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
70*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
71*61046927SAndroid Build Coastguard Worker
72*61046927SAndroid Build Coastguard Worker unsigned src0_rows = glsl_get_vector_elements(src0->type);
73*61046927SAndroid Build Coastguard Worker unsigned src0_columns = glsl_get_matrix_columns(src0->type);
74*61046927SAndroid Build Coastguard Worker unsigned src1_columns = glsl_get_matrix_columns(src1->type);
75*61046927SAndroid Build Coastguard Worker
76*61046927SAndroid Build Coastguard Worker const struct glsl_type *dest_type;
77*61046927SAndroid Build Coastguard Worker if (src1_columns > 1) {
78*61046927SAndroid Build Coastguard Worker dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
79*61046927SAndroid Build Coastguard Worker src0_rows, src1_columns);
80*61046927SAndroid Build Coastguard Worker } else {
81*61046927SAndroid Build Coastguard Worker dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
82*61046927SAndroid Build Coastguard Worker }
83*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
84*61046927SAndroid Build Coastguard Worker
85*61046927SAndroid Build Coastguard Worker dest = wrap_matrix(b, dest);
86*61046927SAndroid Build Coastguard Worker
87*61046927SAndroid Build Coastguard Worker bool transpose_result = false;
88*61046927SAndroid Build Coastguard Worker if (src0_transpose && src1_transpose) {
89*61046927SAndroid Build Coastguard Worker /* transpose(A) * transpose(B) = transpose(B * A) */
90*61046927SAndroid Build Coastguard Worker src1 = src0_transpose;
91*61046927SAndroid Build Coastguard Worker src0 = src1_transpose;
92*61046927SAndroid Build Coastguard Worker src0_transpose = NULL;
93*61046927SAndroid Build Coastguard Worker src1_transpose = NULL;
94*61046927SAndroid Build Coastguard Worker transpose_result = true;
95*61046927SAndroid Build Coastguard Worker }
96*61046927SAndroid Build Coastguard Worker
97*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < src1_columns; i++) {
98*61046927SAndroid Build Coastguard Worker /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
99*61046927SAndroid Build Coastguard Worker dest->elems[i]->def =
100*61046927SAndroid Build Coastguard Worker nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
101*61046927SAndroid Build Coastguard Worker nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
102*61046927SAndroid Build Coastguard Worker for (int j = src0_columns - 2; j >= 0; j--) {
103*61046927SAndroid Build Coastguard Worker dest->elems[i]->def =
104*61046927SAndroid Build Coastguard Worker nir_ffma(&b->nb, src0->elems[j]->def,
105*61046927SAndroid Build Coastguard Worker nir_channel(&b->nb, src1->elems[i]->def, j),
106*61046927SAndroid Build Coastguard Worker dest->elems[i]->def);
107*61046927SAndroid Build Coastguard Worker }
108*61046927SAndroid Build Coastguard Worker }
109*61046927SAndroid Build Coastguard Worker
110*61046927SAndroid Build Coastguard Worker dest = unwrap_matrix(dest);
111*61046927SAndroid Build Coastguard Worker
112*61046927SAndroid Build Coastguard Worker if (transpose_result)
113*61046927SAndroid Build Coastguard Worker dest = vtn_ssa_transpose(b, dest);
114*61046927SAndroid Build Coastguard Worker
115*61046927SAndroid Build Coastguard Worker return dest;
116*61046927SAndroid Build Coastguard Worker }
117*61046927SAndroid Build Coastguard Worker
118*61046927SAndroid Build Coastguard Worker static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder * b,struct vtn_ssa_value * mat,nir_def * scalar)119*61046927SAndroid Build Coastguard Worker mat_times_scalar(struct vtn_builder *b,
120*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *mat,
121*61046927SAndroid Build Coastguard Worker nir_def *scalar)
122*61046927SAndroid Build Coastguard Worker {
123*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
124*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
125*61046927SAndroid Build Coastguard Worker if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
126*61046927SAndroid Build Coastguard Worker dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
127*61046927SAndroid Build Coastguard Worker else
128*61046927SAndroid Build Coastguard Worker dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
129*61046927SAndroid Build Coastguard Worker }
130*61046927SAndroid Build Coastguard Worker
131*61046927SAndroid Build Coastguard Worker return dest;
132*61046927SAndroid Build Coastguard Worker }
133*61046927SAndroid Build Coastguard Worker
134*61046927SAndroid Build Coastguard Worker nir_def *
vtn_mediump_downconvert(struct vtn_builder * b,enum glsl_base_type base_type,nir_def * def)135*61046927SAndroid Build Coastguard Worker vtn_mediump_downconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
136*61046927SAndroid Build Coastguard Worker {
137*61046927SAndroid Build Coastguard Worker if (def->bit_size == 16)
138*61046927SAndroid Build Coastguard Worker return def;
139*61046927SAndroid Build Coastguard Worker
140*61046927SAndroid Build Coastguard Worker switch (base_type) {
141*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_FLOAT:
142*61046927SAndroid Build Coastguard Worker return nir_f2fmp(&b->nb, def);
143*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_INT:
144*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_UINT:
145*61046927SAndroid Build Coastguard Worker return nir_i2imp(&b->nb, def);
146*61046927SAndroid Build Coastguard Worker /* Workaround for 3DMark Wild Life which has RelaxedPrecision on
147*61046927SAndroid Build Coastguard Worker * OpLogical* operations (which is forbidden by spec).
148*61046927SAndroid Build Coastguard Worker */
149*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_BOOL:
150*61046927SAndroid Build Coastguard Worker return def;
151*61046927SAndroid Build Coastguard Worker default:
152*61046927SAndroid Build Coastguard Worker unreachable("bad relaxed precision input type");
153*61046927SAndroid Build Coastguard Worker }
154*61046927SAndroid Build Coastguard Worker }
155*61046927SAndroid Build Coastguard Worker
156*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *
vtn_mediump_downconvert_value(struct vtn_builder * b,struct vtn_ssa_value * src)157*61046927SAndroid Build Coastguard Worker vtn_mediump_downconvert_value(struct vtn_builder *b, struct vtn_ssa_value *src)
158*61046927SAndroid Build Coastguard Worker {
159*61046927SAndroid Build Coastguard Worker if (!src)
160*61046927SAndroid Build Coastguard Worker return src;
161*61046927SAndroid Build Coastguard Worker
162*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *srcmp = vtn_create_ssa_value(b, src->type);
163*61046927SAndroid Build Coastguard Worker
164*61046927SAndroid Build Coastguard Worker if (src->transposed) {
165*61046927SAndroid Build Coastguard Worker srcmp->transposed = vtn_mediump_downconvert_value(b, src->transposed);
166*61046927SAndroid Build Coastguard Worker } else {
167*61046927SAndroid Build Coastguard Worker enum glsl_base_type base_type = glsl_get_base_type(src->type);
168*61046927SAndroid Build Coastguard Worker
169*61046927SAndroid Build Coastguard Worker if (glsl_type_is_vector_or_scalar(src->type)) {
170*61046927SAndroid Build Coastguard Worker srcmp->def = vtn_mediump_downconvert(b, base_type, src->def);
171*61046927SAndroid Build Coastguard Worker } else {
172*61046927SAndroid Build Coastguard Worker assert(glsl_get_base_type(src->type) == GLSL_TYPE_FLOAT);
173*61046927SAndroid Build Coastguard Worker for (int i = 0; i < glsl_get_matrix_columns(src->type); i++)
174*61046927SAndroid Build Coastguard Worker srcmp->elems[i]->def = vtn_mediump_downconvert(b, base_type, src->elems[i]->def);
175*61046927SAndroid Build Coastguard Worker }
176*61046927SAndroid Build Coastguard Worker }
177*61046927SAndroid Build Coastguard Worker
178*61046927SAndroid Build Coastguard Worker return srcmp;
179*61046927SAndroid Build Coastguard Worker }
180*61046927SAndroid Build Coastguard Worker
181*61046927SAndroid Build Coastguard Worker static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder * b,SpvOp opcode,struct vtn_ssa_value * src0,struct vtn_ssa_value * src1)182*61046927SAndroid Build Coastguard Worker vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
183*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
184*61046927SAndroid Build Coastguard Worker {
185*61046927SAndroid Build Coastguard Worker switch (opcode) {
186*61046927SAndroid Build Coastguard Worker case SpvOpFNegate: {
187*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
188*61046927SAndroid Build Coastguard Worker unsigned cols = glsl_get_matrix_columns(src0->type);
189*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < cols; i++)
190*61046927SAndroid Build Coastguard Worker dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
191*61046927SAndroid Build Coastguard Worker return dest;
192*61046927SAndroid Build Coastguard Worker }
193*61046927SAndroid Build Coastguard Worker
194*61046927SAndroid Build Coastguard Worker case SpvOpFAdd: {
195*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
196*61046927SAndroid Build Coastguard Worker unsigned cols = glsl_get_matrix_columns(src0->type);
197*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < cols; i++)
198*61046927SAndroid Build Coastguard Worker dest->elems[i]->def =
199*61046927SAndroid Build Coastguard Worker nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
200*61046927SAndroid Build Coastguard Worker return dest;
201*61046927SAndroid Build Coastguard Worker }
202*61046927SAndroid Build Coastguard Worker
203*61046927SAndroid Build Coastguard Worker case SpvOpFSub: {
204*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
205*61046927SAndroid Build Coastguard Worker unsigned cols = glsl_get_matrix_columns(src0->type);
206*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < cols; i++)
207*61046927SAndroid Build Coastguard Worker dest->elems[i]->def =
208*61046927SAndroid Build Coastguard Worker nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
209*61046927SAndroid Build Coastguard Worker return dest;
210*61046927SAndroid Build Coastguard Worker }
211*61046927SAndroid Build Coastguard Worker
212*61046927SAndroid Build Coastguard Worker case SpvOpTranspose:
213*61046927SAndroid Build Coastguard Worker return vtn_ssa_transpose(b, src0);
214*61046927SAndroid Build Coastguard Worker
215*61046927SAndroid Build Coastguard Worker case SpvOpMatrixTimesScalar:
216*61046927SAndroid Build Coastguard Worker if (src0->transposed) {
217*61046927SAndroid Build Coastguard Worker return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
218*61046927SAndroid Build Coastguard Worker src1->def));
219*61046927SAndroid Build Coastguard Worker } else {
220*61046927SAndroid Build Coastguard Worker return mat_times_scalar(b, src0, src1->def);
221*61046927SAndroid Build Coastguard Worker }
222*61046927SAndroid Build Coastguard Worker break;
223*61046927SAndroid Build Coastguard Worker
224*61046927SAndroid Build Coastguard Worker case SpvOpVectorTimesMatrix:
225*61046927SAndroid Build Coastguard Worker case SpvOpMatrixTimesVector:
226*61046927SAndroid Build Coastguard Worker case SpvOpMatrixTimesMatrix:
227*61046927SAndroid Build Coastguard Worker if (opcode == SpvOpVectorTimesMatrix) {
228*61046927SAndroid Build Coastguard Worker return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
229*61046927SAndroid Build Coastguard Worker } else {
230*61046927SAndroid Build Coastguard Worker return matrix_multiply(b, src0, src1);
231*61046927SAndroid Build Coastguard Worker }
232*61046927SAndroid Build Coastguard Worker break;
233*61046927SAndroid Build Coastguard Worker
234*61046927SAndroid Build Coastguard Worker default: vtn_fail_with_opcode("unknown matrix opcode", opcode);
235*61046927SAndroid Build Coastguard Worker }
236*61046927SAndroid Build Coastguard Worker }
237*61046927SAndroid Build Coastguard Worker
238*61046927SAndroid Build Coastguard Worker static nir_alu_type
convert_op_src_type(SpvOp opcode)239*61046927SAndroid Build Coastguard Worker convert_op_src_type(SpvOp opcode)
240*61046927SAndroid Build Coastguard Worker {
241*61046927SAndroid Build Coastguard Worker switch (opcode) {
242*61046927SAndroid Build Coastguard Worker case SpvOpFConvert:
243*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToS:
244*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToU:
245*61046927SAndroid Build Coastguard Worker return nir_type_float;
246*61046927SAndroid Build Coastguard Worker case SpvOpSConvert:
247*61046927SAndroid Build Coastguard Worker case SpvOpConvertSToF:
248*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertSToU:
249*61046927SAndroid Build Coastguard Worker return nir_type_int;
250*61046927SAndroid Build Coastguard Worker case SpvOpUConvert:
251*61046927SAndroid Build Coastguard Worker case SpvOpConvertUToF:
252*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertUToS:
253*61046927SAndroid Build Coastguard Worker return nir_type_uint;
254*61046927SAndroid Build Coastguard Worker default:
255*61046927SAndroid Build Coastguard Worker unreachable("Unhandled conversion op");
256*61046927SAndroid Build Coastguard Worker }
257*61046927SAndroid Build Coastguard Worker }
258*61046927SAndroid Build Coastguard Worker
259*61046927SAndroid Build Coastguard Worker static nir_alu_type
convert_op_dst_type(SpvOp opcode)260*61046927SAndroid Build Coastguard Worker convert_op_dst_type(SpvOp opcode)
261*61046927SAndroid Build Coastguard Worker {
262*61046927SAndroid Build Coastguard Worker switch (opcode) {
263*61046927SAndroid Build Coastguard Worker case SpvOpFConvert:
264*61046927SAndroid Build Coastguard Worker case SpvOpConvertSToF:
265*61046927SAndroid Build Coastguard Worker case SpvOpConvertUToF:
266*61046927SAndroid Build Coastguard Worker return nir_type_float;
267*61046927SAndroid Build Coastguard Worker case SpvOpSConvert:
268*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToS:
269*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertUToS:
270*61046927SAndroid Build Coastguard Worker return nir_type_int;
271*61046927SAndroid Build Coastguard Worker case SpvOpUConvert:
272*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToU:
273*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertSToU:
274*61046927SAndroid Build Coastguard Worker return nir_type_uint;
275*61046927SAndroid Build Coastguard Worker default:
276*61046927SAndroid Build Coastguard Worker unreachable("Unhandled conversion op");
277*61046927SAndroid Build Coastguard Worker }
278*61046927SAndroid Build Coastguard Worker }
279*61046927SAndroid Build Coastguard Worker
280*61046927SAndroid Build Coastguard Worker nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder * b,SpvOp opcode,bool * swap,bool * exact,unsigned src_bit_size,unsigned dst_bit_size)281*61046927SAndroid Build Coastguard Worker vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
282*61046927SAndroid Build Coastguard Worker SpvOp opcode, bool *swap, bool *exact,
283*61046927SAndroid Build Coastguard Worker unsigned src_bit_size, unsigned dst_bit_size)
284*61046927SAndroid Build Coastguard Worker {
285*61046927SAndroid Build Coastguard Worker /* Indicates that the first two arguments should be swapped. This is
286*61046927SAndroid Build Coastguard Worker * used for implementing greater-than and less-than-or-equal.
287*61046927SAndroid Build Coastguard Worker */
288*61046927SAndroid Build Coastguard Worker *swap = false;
289*61046927SAndroid Build Coastguard Worker
290*61046927SAndroid Build Coastguard Worker *exact = false;
291*61046927SAndroid Build Coastguard Worker
292*61046927SAndroid Build Coastguard Worker switch (opcode) {
293*61046927SAndroid Build Coastguard Worker case SpvOpSNegate: return nir_op_ineg;
294*61046927SAndroid Build Coastguard Worker case SpvOpFNegate: return nir_op_fneg;
295*61046927SAndroid Build Coastguard Worker case SpvOpNot: return nir_op_inot;
296*61046927SAndroid Build Coastguard Worker case SpvOpIAdd: return nir_op_iadd;
297*61046927SAndroid Build Coastguard Worker case SpvOpFAdd: return nir_op_fadd;
298*61046927SAndroid Build Coastguard Worker case SpvOpISub: return nir_op_isub;
299*61046927SAndroid Build Coastguard Worker case SpvOpFSub: return nir_op_fsub;
300*61046927SAndroid Build Coastguard Worker case SpvOpIMul: return nir_op_imul;
301*61046927SAndroid Build Coastguard Worker case SpvOpFMul: return nir_op_fmul;
302*61046927SAndroid Build Coastguard Worker case SpvOpUDiv: return nir_op_udiv;
303*61046927SAndroid Build Coastguard Worker case SpvOpSDiv: return nir_op_idiv;
304*61046927SAndroid Build Coastguard Worker case SpvOpFDiv: return nir_op_fdiv;
305*61046927SAndroid Build Coastguard Worker case SpvOpUMod: return nir_op_umod;
306*61046927SAndroid Build Coastguard Worker case SpvOpSMod: return nir_op_imod;
307*61046927SAndroid Build Coastguard Worker case SpvOpFMod: return nir_op_fmod;
308*61046927SAndroid Build Coastguard Worker case SpvOpSRem: return nir_op_irem;
309*61046927SAndroid Build Coastguard Worker case SpvOpFRem: return nir_op_frem;
310*61046927SAndroid Build Coastguard Worker
311*61046927SAndroid Build Coastguard Worker case SpvOpShiftRightLogical: return nir_op_ushr;
312*61046927SAndroid Build Coastguard Worker case SpvOpShiftRightArithmetic: return nir_op_ishr;
313*61046927SAndroid Build Coastguard Worker case SpvOpShiftLeftLogical: return nir_op_ishl;
314*61046927SAndroid Build Coastguard Worker case SpvOpLogicalOr: return nir_op_ior;
315*61046927SAndroid Build Coastguard Worker case SpvOpLogicalEqual: return nir_op_ieq;
316*61046927SAndroid Build Coastguard Worker case SpvOpLogicalNotEqual: return nir_op_ine;
317*61046927SAndroid Build Coastguard Worker case SpvOpLogicalAnd: return nir_op_iand;
318*61046927SAndroid Build Coastguard Worker case SpvOpLogicalNot: return nir_op_inot;
319*61046927SAndroid Build Coastguard Worker case SpvOpBitwiseOr: return nir_op_ior;
320*61046927SAndroid Build Coastguard Worker case SpvOpBitwiseXor: return nir_op_ixor;
321*61046927SAndroid Build Coastguard Worker case SpvOpBitwiseAnd: return nir_op_iand;
322*61046927SAndroid Build Coastguard Worker case SpvOpSelect: return nir_op_bcsel;
323*61046927SAndroid Build Coastguard Worker case SpvOpIEqual: return nir_op_ieq;
324*61046927SAndroid Build Coastguard Worker
325*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldInsert: return nir_op_bitfield_insert;
326*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldSExtract: return nir_op_ibitfield_extract;
327*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldUExtract: return nir_op_ubitfield_extract;
328*61046927SAndroid Build Coastguard Worker case SpvOpBitReverse: return nir_op_bitfield_reverse;
329*61046927SAndroid Build Coastguard Worker
330*61046927SAndroid Build Coastguard Worker case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
331*61046927SAndroid Build Coastguard Worker /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
332*61046927SAndroid Build Coastguard Worker case SpvOpAbsISubINTEL: return nir_op_uabs_isub;
333*61046927SAndroid Build Coastguard Worker case SpvOpAbsUSubINTEL: return nir_op_uabs_usub;
334*61046927SAndroid Build Coastguard Worker case SpvOpIAddSatINTEL: return nir_op_iadd_sat;
335*61046927SAndroid Build Coastguard Worker case SpvOpUAddSatINTEL: return nir_op_uadd_sat;
336*61046927SAndroid Build Coastguard Worker case SpvOpIAverageINTEL: return nir_op_ihadd;
337*61046927SAndroid Build Coastguard Worker case SpvOpUAverageINTEL: return nir_op_uhadd;
338*61046927SAndroid Build Coastguard Worker case SpvOpIAverageRoundedINTEL: return nir_op_irhadd;
339*61046927SAndroid Build Coastguard Worker case SpvOpUAverageRoundedINTEL: return nir_op_urhadd;
340*61046927SAndroid Build Coastguard Worker case SpvOpISubSatINTEL: return nir_op_isub_sat;
341*61046927SAndroid Build Coastguard Worker case SpvOpUSubSatINTEL: return nir_op_usub_sat;
342*61046927SAndroid Build Coastguard Worker case SpvOpIMul32x16INTEL: return nir_op_imul_32x16;
343*61046927SAndroid Build Coastguard Worker case SpvOpUMul32x16INTEL: return nir_op_umul_32x16;
344*61046927SAndroid Build Coastguard Worker
345*61046927SAndroid Build Coastguard Worker /* The ordered / unordered operators need special implementation besides
346*61046927SAndroid Build Coastguard Worker * the logical operator to use since they also need to check if operands are
347*61046927SAndroid Build Coastguard Worker * ordered.
348*61046927SAndroid Build Coastguard Worker */
349*61046927SAndroid Build Coastguard Worker case SpvOpFOrdEqual: *exact = true; return nir_op_feq;
350*61046927SAndroid Build Coastguard Worker case SpvOpFUnordEqual: *exact = true; return nir_op_feq;
351*61046927SAndroid Build Coastguard Worker case SpvOpINotEqual: return nir_op_ine;
352*61046927SAndroid Build Coastguard Worker case SpvOpLessOrGreater: /* Deprecated, use OrdNotEqual */
353*61046927SAndroid Build Coastguard Worker case SpvOpFOrdNotEqual: *exact = true; return nir_op_fneu;
354*61046927SAndroid Build Coastguard Worker case SpvOpFUnordNotEqual: *exact = true; return nir_op_fneu;
355*61046927SAndroid Build Coastguard Worker case SpvOpULessThan: return nir_op_ult;
356*61046927SAndroid Build Coastguard Worker case SpvOpSLessThan: return nir_op_ilt;
357*61046927SAndroid Build Coastguard Worker case SpvOpFOrdLessThan: *exact = true; return nir_op_flt;
358*61046927SAndroid Build Coastguard Worker case SpvOpFUnordLessThan: *exact = true; return nir_op_flt;
359*61046927SAndroid Build Coastguard Worker case SpvOpUGreaterThan: *swap = true; return nir_op_ult;
360*61046927SAndroid Build Coastguard Worker case SpvOpSGreaterThan: *swap = true; return nir_op_ilt;
361*61046927SAndroid Build Coastguard Worker case SpvOpFOrdGreaterThan: *swap = true; *exact = true; return nir_op_flt;
362*61046927SAndroid Build Coastguard Worker case SpvOpFUnordGreaterThan: *swap = true; *exact = true; return nir_op_flt;
363*61046927SAndroid Build Coastguard Worker case SpvOpULessThanEqual: *swap = true; return nir_op_uge;
364*61046927SAndroid Build Coastguard Worker case SpvOpSLessThanEqual: *swap = true; return nir_op_ige;
365*61046927SAndroid Build Coastguard Worker case SpvOpFOrdLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
366*61046927SAndroid Build Coastguard Worker case SpvOpFUnordLessThanEqual: *swap = true; *exact = true; return nir_op_fge;
367*61046927SAndroid Build Coastguard Worker case SpvOpUGreaterThanEqual: return nir_op_uge;
368*61046927SAndroid Build Coastguard Worker case SpvOpSGreaterThanEqual: return nir_op_ige;
369*61046927SAndroid Build Coastguard Worker case SpvOpFOrdGreaterThanEqual: *exact = true; return nir_op_fge;
370*61046927SAndroid Build Coastguard Worker case SpvOpFUnordGreaterThanEqual: *exact = true; return nir_op_fge;
371*61046927SAndroid Build Coastguard Worker
372*61046927SAndroid Build Coastguard Worker /* Conversions: */
373*61046927SAndroid Build Coastguard Worker case SpvOpQuantizeToF16: return nir_op_fquantize2f16;
374*61046927SAndroid Build Coastguard Worker case SpvOpUConvert:
375*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToU:
376*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToS:
377*61046927SAndroid Build Coastguard Worker case SpvOpConvertSToF:
378*61046927SAndroid Build Coastguard Worker case SpvOpConvertUToF:
379*61046927SAndroid Build Coastguard Worker case SpvOpSConvert:
380*61046927SAndroid Build Coastguard Worker case SpvOpFConvert: {
381*61046927SAndroid Build Coastguard Worker nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
382*61046927SAndroid Build Coastguard Worker nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
383*61046927SAndroid Build Coastguard Worker return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
384*61046927SAndroid Build Coastguard Worker }
385*61046927SAndroid Build Coastguard Worker
386*61046927SAndroid Build Coastguard Worker case SpvOpPtrCastToGeneric: return nir_op_mov;
387*61046927SAndroid Build Coastguard Worker case SpvOpGenericCastToPtr: return nir_op_mov;
388*61046927SAndroid Build Coastguard Worker
389*61046927SAndroid Build Coastguard Worker case SpvOpIsNormal: return nir_op_fisnormal;
390*61046927SAndroid Build Coastguard Worker case SpvOpIsFinite: return nir_op_fisfinite;
391*61046927SAndroid Build Coastguard Worker
392*61046927SAndroid Build Coastguard Worker default:
393*61046927SAndroid Build Coastguard Worker vtn_fail("No NIR equivalent: %u", opcode);
394*61046927SAndroid Build Coastguard Worker }
395*61046927SAndroid Build Coastguard Worker }
396*61046927SAndroid Build Coastguard Worker
397*61046927SAndroid Build Coastguard Worker static void
handle_fp_fast_math(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,UNUSED void * _void)398*61046927SAndroid Build Coastguard Worker handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
399*61046927SAndroid Build Coastguard Worker UNUSED int member, const struct vtn_decoration *dec,
400*61046927SAndroid Build Coastguard Worker UNUSED void *_void)
401*61046927SAndroid Build Coastguard Worker {
402*61046927SAndroid Build Coastguard Worker vtn_assert(dec->scope == VTN_DEC_DECORATION);
403*61046927SAndroid Build Coastguard Worker if (dec->decoration != SpvDecorationFPFastMathMode)
404*61046927SAndroid Build Coastguard Worker return;
405*61046927SAndroid Build Coastguard Worker
406*61046927SAndroid Build Coastguard Worker SpvFPFastMathModeMask can_fast_math =
407*61046927SAndroid Build Coastguard Worker SpvFPFastMathModeAllowRecipMask |
408*61046927SAndroid Build Coastguard Worker SpvFPFastMathModeAllowContractMask |
409*61046927SAndroid Build Coastguard Worker SpvFPFastMathModeAllowReassocMask |
410*61046927SAndroid Build Coastguard Worker SpvFPFastMathModeAllowTransformMask;
411*61046927SAndroid Build Coastguard Worker
412*61046927SAndroid Build Coastguard Worker if ((dec->operands[0] & can_fast_math) != can_fast_math)
413*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
414*61046927SAndroid Build Coastguard Worker
415*61046927SAndroid Build Coastguard Worker /* Decoration overrides defaults */
416*61046927SAndroid Build Coastguard Worker b->nb.fp_fast_math = 0;
417*61046927SAndroid Build Coastguard Worker if (!(dec->operands[0] & SpvFPFastMathModeNSZMask))
418*61046927SAndroid Build Coastguard Worker b->nb.fp_fast_math |=
419*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 |
420*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 |
421*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64;
422*61046927SAndroid Build Coastguard Worker if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask))
423*61046927SAndroid Build Coastguard Worker b->nb.fp_fast_math |=
424*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_NAN_PRESERVE_FP16 |
425*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_NAN_PRESERVE_FP32 |
426*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_NAN_PRESERVE_FP64;
427*61046927SAndroid Build Coastguard Worker if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask))
428*61046927SAndroid Build Coastguard Worker b->nb.fp_fast_math |=
429*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_INF_PRESERVE_FP16 |
430*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_INF_PRESERVE_FP32 |
431*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_INF_PRESERVE_FP64;
432*61046927SAndroid Build Coastguard Worker }
433*61046927SAndroid Build Coastguard Worker
434*61046927SAndroid Build Coastguard Worker void
vtn_handle_fp_fast_math(struct vtn_builder * b,struct vtn_value * val)435*61046927SAndroid Build Coastguard Worker vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
436*61046927SAndroid Build Coastguard Worker {
437*61046927SAndroid Build Coastguard Worker /* Take the NaN/Inf/SZ preserve bits from the execution mode and set them
438*61046927SAndroid Build Coastguard Worker * on the builder, so the generated instructions can take it from it.
439*61046927SAndroid Build Coastguard Worker * We only care about some of them, check nir_alu_instr for details.
440*61046927SAndroid Build Coastguard Worker * We also copy all bit widths, because we can't easily get the correct one
441*61046927SAndroid Build Coastguard Worker * here.
442*61046927SAndroid Build Coastguard Worker */
443*61046927SAndroid Build Coastguard Worker #define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
444*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
445*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
446*61046927SAndroid Build Coastguard Worker static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
447*61046927SAndroid Build Coastguard Worker "enum float_controls and fp_fast_math out of sync!");
448*61046927SAndroid Build Coastguard Worker b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
449*61046927SAndroid Build Coastguard Worker FLOAT_CONTROLS2_BITS;
450*61046927SAndroid Build Coastguard Worker vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);
451*61046927SAndroid Build Coastguard Worker #undef FLOAT_CONTROLS2_BITS
452*61046927SAndroid Build Coastguard Worker }
453*61046927SAndroid Build Coastguard Worker
454*61046927SAndroid Build Coastguard Worker static void
handle_no_contraction(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,UNUSED void * _void)455*61046927SAndroid Build Coastguard Worker handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
456*61046927SAndroid Build Coastguard Worker UNUSED int member, const struct vtn_decoration *dec,
457*61046927SAndroid Build Coastguard Worker UNUSED void *_void)
458*61046927SAndroid Build Coastguard Worker {
459*61046927SAndroid Build Coastguard Worker vtn_assert(dec->scope == VTN_DEC_DECORATION);
460*61046927SAndroid Build Coastguard Worker if (dec->decoration != SpvDecorationNoContraction)
461*61046927SAndroid Build Coastguard Worker return;
462*61046927SAndroid Build Coastguard Worker
463*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
464*61046927SAndroid Build Coastguard Worker }
465*61046927SAndroid Build Coastguard Worker
466*61046927SAndroid Build Coastguard Worker void
vtn_handle_no_contraction(struct vtn_builder * b,struct vtn_value * val)467*61046927SAndroid Build Coastguard Worker vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
468*61046927SAndroid Build Coastguard Worker {
469*61046927SAndroid Build Coastguard Worker vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
470*61046927SAndroid Build Coastguard Worker }
471*61046927SAndroid Build Coastguard Worker
472*61046927SAndroid Build Coastguard Worker nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder * b,SpvFPRoundingMode mode)473*61046927SAndroid Build Coastguard Worker vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
474*61046927SAndroid Build Coastguard Worker {
475*61046927SAndroid Build Coastguard Worker switch (mode) {
476*61046927SAndroid Build Coastguard Worker case SpvFPRoundingModeRTE:
477*61046927SAndroid Build Coastguard Worker return nir_rounding_mode_rtne;
478*61046927SAndroid Build Coastguard Worker case SpvFPRoundingModeRTZ:
479*61046927SAndroid Build Coastguard Worker return nir_rounding_mode_rtz;
480*61046927SAndroid Build Coastguard Worker case SpvFPRoundingModeRTP:
481*61046927SAndroid Build Coastguard Worker vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
482*61046927SAndroid Build Coastguard Worker "FPRoundingModeRTP is only supported in kernels");
483*61046927SAndroid Build Coastguard Worker return nir_rounding_mode_ru;
484*61046927SAndroid Build Coastguard Worker case SpvFPRoundingModeRTN:
485*61046927SAndroid Build Coastguard Worker vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
486*61046927SAndroid Build Coastguard Worker "FPRoundingModeRTN is only supported in kernels");
487*61046927SAndroid Build Coastguard Worker return nir_rounding_mode_rd;
488*61046927SAndroid Build Coastguard Worker default:
489*61046927SAndroid Build Coastguard Worker vtn_fail("Unsupported rounding mode: %s",
490*61046927SAndroid Build Coastguard Worker spirv_fproundingmode_to_string(mode));
491*61046927SAndroid Build Coastguard Worker break;
492*61046927SAndroid Build Coastguard Worker }
493*61046927SAndroid Build Coastguard Worker }
494*61046927SAndroid Build Coastguard Worker
495*61046927SAndroid Build Coastguard Worker struct conversion_opts {
496*61046927SAndroid Build Coastguard Worker nir_rounding_mode rounding_mode;
497*61046927SAndroid Build Coastguard Worker bool saturate;
498*61046927SAndroid Build Coastguard Worker };
499*61046927SAndroid Build Coastguard Worker
500*61046927SAndroid Build Coastguard Worker static void
handle_conversion_opts(struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _opts)501*61046927SAndroid Build Coastguard Worker handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
502*61046927SAndroid Build Coastguard Worker UNUSED int member,
503*61046927SAndroid Build Coastguard Worker const struct vtn_decoration *dec, void *_opts)
504*61046927SAndroid Build Coastguard Worker {
505*61046927SAndroid Build Coastguard Worker struct conversion_opts *opts = _opts;
506*61046927SAndroid Build Coastguard Worker
507*61046927SAndroid Build Coastguard Worker switch (dec->decoration) {
508*61046927SAndroid Build Coastguard Worker case SpvDecorationFPRoundingMode:
509*61046927SAndroid Build Coastguard Worker opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
510*61046927SAndroid Build Coastguard Worker break;
511*61046927SAndroid Build Coastguard Worker
512*61046927SAndroid Build Coastguard Worker case SpvDecorationSaturatedConversion:
513*61046927SAndroid Build Coastguard Worker vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
514*61046927SAndroid Build Coastguard Worker "Saturated conversions are only allowed in kernels");
515*61046927SAndroid Build Coastguard Worker opts->saturate = true;
516*61046927SAndroid Build Coastguard Worker break;
517*61046927SAndroid Build Coastguard Worker
518*61046927SAndroid Build Coastguard Worker default:
519*61046927SAndroid Build Coastguard Worker break;
520*61046927SAndroid Build Coastguard Worker }
521*61046927SAndroid Build Coastguard Worker }
522*61046927SAndroid Build Coastguard Worker
523*61046927SAndroid Build Coastguard Worker static void
handle_no_wrap(UNUSED struct vtn_builder * b,UNUSED struct vtn_value * val,UNUSED int member,const struct vtn_decoration * dec,void * _alu)524*61046927SAndroid Build Coastguard Worker handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
525*61046927SAndroid Build Coastguard Worker UNUSED int member,
526*61046927SAndroid Build Coastguard Worker const struct vtn_decoration *dec, void *_alu)
527*61046927SAndroid Build Coastguard Worker {
528*61046927SAndroid Build Coastguard Worker nir_alu_instr *alu = _alu;
529*61046927SAndroid Build Coastguard Worker switch (dec->decoration) {
530*61046927SAndroid Build Coastguard Worker case SpvDecorationNoSignedWrap:
531*61046927SAndroid Build Coastguard Worker alu->no_signed_wrap = true;
532*61046927SAndroid Build Coastguard Worker break;
533*61046927SAndroid Build Coastguard Worker case SpvDecorationNoUnsignedWrap:
534*61046927SAndroid Build Coastguard Worker alu->no_unsigned_wrap = true;
535*61046927SAndroid Build Coastguard Worker break;
536*61046927SAndroid Build Coastguard Worker default:
537*61046927SAndroid Build Coastguard Worker /* Do nothing. */
538*61046927SAndroid Build Coastguard Worker break;
539*61046927SAndroid Build Coastguard Worker }
540*61046927SAndroid Build Coastguard Worker }
541*61046927SAndroid Build Coastguard Worker
542*61046927SAndroid Build Coastguard Worker static void
vtn_value_is_relaxed_precision_cb(struct vtn_builder * b,struct vtn_value * val,int member,const struct vtn_decoration * dec,void * void_ctx)543*61046927SAndroid Build Coastguard Worker vtn_value_is_relaxed_precision_cb(struct vtn_builder *b,
544*61046927SAndroid Build Coastguard Worker struct vtn_value *val, int member,
545*61046927SAndroid Build Coastguard Worker const struct vtn_decoration *dec, void *void_ctx)
546*61046927SAndroid Build Coastguard Worker {
547*61046927SAndroid Build Coastguard Worker bool *relaxed_precision = void_ctx;
548*61046927SAndroid Build Coastguard Worker switch (dec->decoration) {
549*61046927SAndroid Build Coastguard Worker case SpvDecorationRelaxedPrecision:
550*61046927SAndroid Build Coastguard Worker *relaxed_precision = true;
551*61046927SAndroid Build Coastguard Worker break;
552*61046927SAndroid Build Coastguard Worker
553*61046927SAndroid Build Coastguard Worker default:
554*61046927SAndroid Build Coastguard Worker break;
555*61046927SAndroid Build Coastguard Worker }
556*61046927SAndroid Build Coastguard Worker }
557*61046927SAndroid Build Coastguard Worker
558*61046927SAndroid Build Coastguard Worker bool
vtn_value_is_relaxed_precision(struct vtn_builder * b,struct vtn_value * val)559*61046927SAndroid Build Coastguard Worker vtn_value_is_relaxed_precision(struct vtn_builder *b, struct vtn_value *val)
560*61046927SAndroid Build Coastguard Worker {
561*61046927SAndroid Build Coastguard Worker bool result = false;
562*61046927SAndroid Build Coastguard Worker vtn_foreach_decoration(b, val,
563*61046927SAndroid Build Coastguard Worker vtn_value_is_relaxed_precision_cb, &result);
564*61046927SAndroid Build Coastguard Worker return result;
565*61046927SAndroid Build Coastguard Worker }
566*61046927SAndroid Build Coastguard Worker
567*61046927SAndroid Build Coastguard Worker static bool
vtn_alu_op_mediump_16bit(struct vtn_builder * b,SpvOp opcode,struct vtn_value * dest_val)568*61046927SAndroid Build Coastguard Worker vtn_alu_op_mediump_16bit(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val)
569*61046927SAndroid Build Coastguard Worker {
570*61046927SAndroid Build Coastguard Worker if (!b->options->mediump_16bit_alu || !vtn_value_is_relaxed_precision(b, dest_val))
571*61046927SAndroid Build Coastguard Worker return false;
572*61046927SAndroid Build Coastguard Worker
573*61046927SAndroid Build Coastguard Worker switch (opcode) {
574*61046927SAndroid Build Coastguard Worker case SpvOpDPdx:
575*61046927SAndroid Build Coastguard Worker case SpvOpDPdy:
576*61046927SAndroid Build Coastguard Worker case SpvOpDPdxFine:
577*61046927SAndroid Build Coastguard Worker case SpvOpDPdyFine:
578*61046927SAndroid Build Coastguard Worker case SpvOpDPdxCoarse:
579*61046927SAndroid Build Coastguard Worker case SpvOpDPdyCoarse:
580*61046927SAndroid Build Coastguard Worker case SpvOpFwidth:
581*61046927SAndroid Build Coastguard Worker case SpvOpFwidthFine:
582*61046927SAndroid Build Coastguard Worker case SpvOpFwidthCoarse:
583*61046927SAndroid Build Coastguard Worker return b->options->mediump_16bit_derivatives;
584*61046927SAndroid Build Coastguard Worker default:
585*61046927SAndroid Build Coastguard Worker return true;
586*61046927SAndroid Build Coastguard Worker }
587*61046927SAndroid Build Coastguard Worker }
588*61046927SAndroid Build Coastguard Worker
589*61046927SAndroid Build Coastguard Worker static nir_def *
vtn_mediump_upconvert(struct vtn_builder * b,enum glsl_base_type base_type,nir_def * def)590*61046927SAndroid Build Coastguard Worker vtn_mediump_upconvert(struct vtn_builder *b, enum glsl_base_type base_type, nir_def *def)
591*61046927SAndroid Build Coastguard Worker {
592*61046927SAndroid Build Coastguard Worker if (def->bit_size != 16)
593*61046927SAndroid Build Coastguard Worker return def;
594*61046927SAndroid Build Coastguard Worker
595*61046927SAndroid Build Coastguard Worker switch (base_type) {
596*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_FLOAT:
597*61046927SAndroid Build Coastguard Worker return nir_f2f32(&b->nb, def);
598*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_INT:
599*61046927SAndroid Build Coastguard Worker return nir_i2i32(&b->nb, def);
600*61046927SAndroid Build Coastguard Worker case GLSL_TYPE_UINT:
601*61046927SAndroid Build Coastguard Worker return nir_u2u32(&b->nb, def);
602*61046927SAndroid Build Coastguard Worker default:
603*61046927SAndroid Build Coastguard Worker unreachable("bad relaxed precision output type");
604*61046927SAndroid Build Coastguard Worker }
605*61046927SAndroid Build Coastguard Worker }
606*61046927SAndroid Build Coastguard Worker
607*61046927SAndroid Build Coastguard Worker void
vtn_mediump_upconvert_value(struct vtn_builder * b,struct vtn_ssa_value * value)608*61046927SAndroid Build Coastguard Worker vtn_mediump_upconvert_value(struct vtn_builder *b, struct vtn_ssa_value *value)
609*61046927SAndroid Build Coastguard Worker {
610*61046927SAndroid Build Coastguard Worker enum glsl_base_type base_type = glsl_get_base_type(value->type);
611*61046927SAndroid Build Coastguard Worker
612*61046927SAndroid Build Coastguard Worker if (glsl_type_is_vector_or_scalar(value->type)) {
613*61046927SAndroid Build Coastguard Worker value->def = vtn_mediump_upconvert(b, base_type, value->def);
614*61046927SAndroid Build Coastguard Worker } else {
615*61046927SAndroid Build Coastguard Worker for (int i = 0; i < glsl_get_matrix_columns(value->type); i++)
616*61046927SAndroid Build Coastguard Worker value->elems[i]->def = vtn_mediump_upconvert(b, base_type, value->elems[i]->def);
617*61046927SAndroid Build Coastguard Worker }
618*61046927SAndroid Build Coastguard Worker }
619*61046927SAndroid Build Coastguard Worker
620*61046927SAndroid Build Coastguard Worker void
vtn_handle_alu(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)621*61046927SAndroid Build Coastguard Worker vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
622*61046927SAndroid Build Coastguard Worker const uint32_t *w, unsigned count)
623*61046927SAndroid Build Coastguard Worker {
624*61046927SAndroid Build Coastguard Worker struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
625*61046927SAndroid Build Coastguard Worker const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
626*61046927SAndroid Build Coastguard Worker
627*61046927SAndroid Build Coastguard Worker if (glsl_type_is_cmat(dest_type)) {
628*61046927SAndroid Build Coastguard Worker vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
629*61046927SAndroid Build Coastguard Worker return;
630*61046927SAndroid Build Coastguard Worker }
631*61046927SAndroid Build Coastguard Worker
632*61046927SAndroid Build Coastguard Worker vtn_handle_no_contraction(b, dest_val);
633*61046927SAndroid Build Coastguard Worker vtn_handle_fp_fast_math(b, dest_val);
634*61046927SAndroid Build Coastguard Worker bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
635*61046927SAndroid Build Coastguard Worker
636*61046927SAndroid Build Coastguard Worker /* Collect the various SSA sources */
637*61046927SAndroid Build Coastguard Worker const unsigned num_inputs = count - 3;
638*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *vtn_src[4] = { NULL, };
639*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < num_inputs; i++) {
640*61046927SAndroid Build Coastguard Worker vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
641*61046927SAndroid Build Coastguard Worker if (mediump_16bit)
642*61046927SAndroid Build Coastguard Worker vtn_src[i] = vtn_mediump_downconvert_value(b, vtn_src[i]);
643*61046927SAndroid Build Coastguard Worker }
644*61046927SAndroid Build Coastguard Worker
645*61046927SAndroid Build Coastguard Worker if (glsl_type_is_matrix(vtn_src[0]->type) ||
646*61046927SAndroid Build Coastguard Worker (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
647*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]);
648*61046927SAndroid Build Coastguard Worker
649*61046927SAndroid Build Coastguard Worker if (mediump_16bit)
650*61046927SAndroid Build Coastguard Worker vtn_mediump_upconvert_value(b, dest);
651*61046927SAndroid Build Coastguard Worker
652*61046927SAndroid Build Coastguard Worker vtn_push_ssa_value(b, w[2], dest);
653*61046927SAndroid Build Coastguard Worker b->nb.exact = b->exact;
654*61046927SAndroid Build Coastguard Worker return;
655*61046927SAndroid Build Coastguard Worker }
656*61046927SAndroid Build Coastguard Worker
657*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
658*61046927SAndroid Build Coastguard Worker nir_def *src[4] = { NULL, };
659*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < num_inputs; i++) {
660*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
661*61046927SAndroid Build Coastguard Worker src[i] = vtn_src[i]->def;
662*61046927SAndroid Build Coastguard Worker }
663*61046927SAndroid Build Coastguard Worker
664*61046927SAndroid Build Coastguard Worker switch (opcode) {
665*61046927SAndroid Build Coastguard Worker case SpvOpAny:
666*61046927SAndroid Build Coastguard Worker dest->def = nir_bany(&b->nb, src[0]);
667*61046927SAndroid Build Coastguard Worker break;
668*61046927SAndroid Build Coastguard Worker
669*61046927SAndroid Build Coastguard Worker case SpvOpAll:
670*61046927SAndroid Build Coastguard Worker dest->def = nir_ball(&b->nb, src[0]);
671*61046927SAndroid Build Coastguard Worker break;
672*61046927SAndroid Build Coastguard Worker
673*61046927SAndroid Build Coastguard Worker case SpvOpOuterProduct: {
674*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < src[1]->num_components; i++) {
675*61046927SAndroid Build Coastguard Worker dest->elems[i]->def =
676*61046927SAndroid Build Coastguard Worker nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
677*61046927SAndroid Build Coastguard Worker }
678*61046927SAndroid Build Coastguard Worker break;
679*61046927SAndroid Build Coastguard Worker }
680*61046927SAndroid Build Coastguard Worker
681*61046927SAndroid Build Coastguard Worker case SpvOpDot:
682*61046927SAndroid Build Coastguard Worker dest->def = nir_fdot(&b->nb, src[0], src[1]);
683*61046927SAndroid Build Coastguard Worker break;
684*61046927SAndroid Build Coastguard Worker
685*61046927SAndroid Build Coastguard Worker case SpvOpIAddCarry:
686*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
687*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
688*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
689*61046927SAndroid Build Coastguard Worker break;
690*61046927SAndroid Build Coastguard Worker
691*61046927SAndroid Build Coastguard Worker case SpvOpISubBorrow:
692*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
693*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
694*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
695*61046927SAndroid Build Coastguard Worker break;
696*61046927SAndroid Build Coastguard Worker
697*61046927SAndroid Build Coastguard Worker case SpvOpUMulExtended: {
698*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
699*61046927SAndroid Build Coastguard Worker if (src[0]->bit_size == 32) {
700*61046927SAndroid Build Coastguard Worker nir_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
701*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
702*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
703*61046927SAndroid Build Coastguard Worker } else {
704*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
705*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
706*61046927SAndroid Build Coastguard Worker }
707*61046927SAndroid Build Coastguard Worker break;
708*61046927SAndroid Build Coastguard Worker }
709*61046927SAndroid Build Coastguard Worker
710*61046927SAndroid Build Coastguard Worker case SpvOpSMulExtended: {
711*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
712*61046927SAndroid Build Coastguard Worker if (src[0]->bit_size == 32) {
713*61046927SAndroid Build Coastguard Worker nir_def *umul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
714*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
715*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
716*61046927SAndroid Build Coastguard Worker } else {
717*61046927SAndroid Build Coastguard Worker dest->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
718*61046927SAndroid Build Coastguard Worker dest->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
719*61046927SAndroid Build Coastguard Worker }
720*61046927SAndroid Build Coastguard Worker break;
721*61046927SAndroid Build Coastguard Worker }
722*61046927SAndroid Build Coastguard Worker
723*61046927SAndroid Build Coastguard Worker case SpvOpDPdx:
724*61046927SAndroid Build Coastguard Worker dest->def = nir_ddx(&b->nb, src[0]);
725*61046927SAndroid Build Coastguard Worker break;
726*61046927SAndroid Build Coastguard Worker case SpvOpDPdxFine:
727*61046927SAndroid Build Coastguard Worker dest->def = nir_ddx_fine(&b->nb, src[0]);
728*61046927SAndroid Build Coastguard Worker break;
729*61046927SAndroid Build Coastguard Worker case SpvOpDPdxCoarse:
730*61046927SAndroid Build Coastguard Worker dest->def = nir_ddx_coarse(&b->nb, src[0]);
731*61046927SAndroid Build Coastguard Worker break;
732*61046927SAndroid Build Coastguard Worker case SpvOpDPdy:
733*61046927SAndroid Build Coastguard Worker dest->def = nir_ddy(&b->nb, src[0]);
734*61046927SAndroid Build Coastguard Worker break;
735*61046927SAndroid Build Coastguard Worker case SpvOpDPdyFine:
736*61046927SAndroid Build Coastguard Worker dest->def = nir_ddy_fine(&b->nb, src[0]);
737*61046927SAndroid Build Coastguard Worker break;
738*61046927SAndroid Build Coastguard Worker case SpvOpDPdyCoarse:
739*61046927SAndroid Build Coastguard Worker dest->def = nir_ddy_coarse(&b->nb, src[0]);
740*61046927SAndroid Build Coastguard Worker break;
741*61046927SAndroid Build Coastguard Worker
742*61046927SAndroid Build Coastguard Worker case SpvOpFwidth:
743*61046927SAndroid Build Coastguard Worker dest->def = nir_fadd(&b->nb,
744*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddx(&b->nb, src[0])),
745*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddy(&b->nb, src[0])));
746*61046927SAndroid Build Coastguard Worker break;
747*61046927SAndroid Build Coastguard Worker case SpvOpFwidthFine:
748*61046927SAndroid Build Coastguard Worker dest->def = nir_fadd(&b->nb,
749*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddx_fine(&b->nb, src[0])),
750*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddy_fine(&b->nb, src[0])));
751*61046927SAndroid Build Coastguard Worker break;
752*61046927SAndroid Build Coastguard Worker case SpvOpFwidthCoarse:
753*61046927SAndroid Build Coastguard Worker dest->def = nir_fadd(&b->nb,
754*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddx_coarse(&b->nb, src[0])),
755*61046927SAndroid Build Coastguard Worker nir_fabs(&b->nb, nir_ddy_coarse(&b->nb, src[0])));
756*61046927SAndroid Build Coastguard Worker break;
757*61046927SAndroid Build Coastguard Worker
758*61046927SAndroid Build Coastguard Worker case SpvOpVectorTimesScalar:
759*61046927SAndroid Build Coastguard Worker /* The builder will take care of splatting for us. */
760*61046927SAndroid Build Coastguard Worker dest->def = nir_fmul(&b->nb, src[0], src[1]);
761*61046927SAndroid Build Coastguard Worker break;
762*61046927SAndroid Build Coastguard Worker
763*61046927SAndroid Build Coastguard Worker case SpvOpIsNan: {
764*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
765*61046927SAndroid Build Coastguard Worker
766*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
767*61046927SAndroid Build Coastguard Worker dest->def = nir_fneu(&b->nb, src[0], src[0]);
768*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
769*61046927SAndroid Build Coastguard Worker break;
770*61046927SAndroid Build Coastguard Worker }
771*61046927SAndroid Build Coastguard Worker
772*61046927SAndroid Build Coastguard Worker case SpvOpOrdered: {
773*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
774*61046927SAndroid Build Coastguard Worker
775*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
776*61046927SAndroid Build Coastguard Worker dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
777*61046927SAndroid Build Coastguard Worker nir_feq(&b->nb, src[1], src[1]));
778*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
779*61046927SAndroid Build Coastguard Worker break;
780*61046927SAndroid Build Coastguard Worker }
781*61046927SAndroid Build Coastguard Worker
782*61046927SAndroid Build Coastguard Worker case SpvOpUnordered: {
783*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
784*61046927SAndroid Build Coastguard Worker
785*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
786*61046927SAndroid Build Coastguard Worker dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
787*61046927SAndroid Build Coastguard Worker nir_fneu(&b->nb, src[1], src[1]));
788*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
789*61046927SAndroid Build Coastguard Worker break;
790*61046927SAndroid Build Coastguard Worker }
791*61046927SAndroid Build Coastguard Worker
792*61046927SAndroid Build Coastguard Worker case SpvOpIsInf: {
793*61046927SAndroid Build Coastguard Worker nir_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
794*61046927SAndroid Build Coastguard Worker dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
795*61046927SAndroid Build Coastguard Worker break;
796*61046927SAndroid Build Coastguard Worker }
797*61046927SAndroid Build Coastguard Worker
798*61046927SAndroid Build Coastguard Worker case SpvOpFUnordEqual: {
799*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
800*61046927SAndroid Build Coastguard Worker
801*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
802*61046927SAndroid Build Coastguard Worker
803*61046927SAndroid Build Coastguard Worker /* This could also be implemented as !(a < b || b < a). If one or both
804*61046927SAndroid Build Coastguard Worker * of the source are numbers, later optimization passes can easily
805*61046927SAndroid Build Coastguard Worker * eliminate the isnan() checks. This may trim the sequence down to a
806*61046927SAndroid Build Coastguard Worker * single (a == b) operation. Otherwise, the optimizer can transform
807*61046927SAndroid Build Coastguard Worker * whatever is left to !(a < b || b < a). Since some applications will
808*61046927SAndroid Build Coastguard Worker * open-code this sequence, these optimizations are needed anyway.
809*61046927SAndroid Build Coastguard Worker */
810*61046927SAndroid Build Coastguard Worker dest->def =
811*61046927SAndroid Build Coastguard Worker nir_ior(&b->nb,
812*61046927SAndroid Build Coastguard Worker nir_feq(&b->nb, src[0], src[1]),
813*61046927SAndroid Build Coastguard Worker nir_ior(&b->nb,
814*61046927SAndroid Build Coastguard Worker nir_fneu(&b->nb, src[0], src[0]),
815*61046927SAndroid Build Coastguard Worker nir_fneu(&b->nb, src[1], src[1])));
816*61046927SAndroid Build Coastguard Worker
817*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
818*61046927SAndroid Build Coastguard Worker break;
819*61046927SAndroid Build Coastguard Worker }
820*61046927SAndroid Build Coastguard Worker
821*61046927SAndroid Build Coastguard Worker case SpvOpFUnordLessThan:
822*61046927SAndroid Build Coastguard Worker case SpvOpFUnordGreaterThan:
823*61046927SAndroid Build Coastguard Worker case SpvOpFUnordLessThanEqual:
824*61046927SAndroid Build Coastguard Worker case SpvOpFUnordGreaterThanEqual: {
825*61046927SAndroid Build Coastguard Worker bool swap;
826*61046927SAndroid Build Coastguard Worker bool unused_exact;
827*61046927SAndroid Build Coastguard Worker unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
828*61046927SAndroid Build Coastguard Worker unsigned dst_bit_size = glsl_get_bit_size(dest_type);
829*61046927SAndroid Build Coastguard Worker nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
830*61046927SAndroid Build Coastguard Worker &unused_exact,
831*61046927SAndroid Build Coastguard Worker src_bit_size, dst_bit_size);
832*61046927SAndroid Build Coastguard Worker
833*61046927SAndroid Build Coastguard Worker if (swap) {
834*61046927SAndroid Build Coastguard Worker nir_def *tmp = src[0];
835*61046927SAndroid Build Coastguard Worker src[0] = src[1];
836*61046927SAndroid Build Coastguard Worker src[1] = tmp;
837*61046927SAndroid Build Coastguard Worker }
838*61046927SAndroid Build Coastguard Worker
839*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
840*61046927SAndroid Build Coastguard Worker
841*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
842*61046927SAndroid Build Coastguard Worker
843*61046927SAndroid Build Coastguard Worker /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b). */
844*61046927SAndroid Build Coastguard Worker switch (op) {
845*61046927SAndroid Build Coastguard Worker case nir_op_fge: op = nir_op_flt; break;
846*61046927SAndroid Build Coastguard Worker case nir_op_flt: op = nir_op_fge; break;
847*61046927SAndroid Build Coastguard Worker default: unreachable("Impossible opcode.");
848*61046927SAndroid Build Coastguard Worker }
849*61046927SAndroid Build Coastguard Worker
850*61046927SAndroid Build Coastguard Worker dest->def =
851*61046927SAndroid Build Coastguard Worker nir_inot(&b->nb,
852*61046927SAndroid Build Coastguard Worker nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));
853*61046927SAndroid Build Coastguard Worker
854*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
855*61046927SAndroid Build Coastguard Worker break;
856*61046927SAndroid Build Coastguard Worker }
857*61046927SAndroid Build Coastguard Worker
858*61046927SAndroid Build Coastguard Worker case SpvOpLessOrGreater:
859*61046927SAndroid Build Coastguard Worker case SpvOpFOrdNotEqual: {
860*61046927SAndroid Build Coastguard Worker /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
861*61046927SAndroid Build Coastguard Worker * from the ALU will probably already be false if the operands are not
862*61046927SAndroid Build Coastguard Worker * ordered so we don’t need to handle it specially.
863*61046927SAndroid Build Coastguard Worker */
864*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
865*61046927SAndroid Build Coastguard Worker
866*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
867*61046927SAndroid Build Coastguard Worker
868*61046927SAndroid Build Coastguard Worker /* This could also be implemented as (a < b || b < a). If one or both
869*61046927SAndroid Build Coastguard Worker * of the source are numbers, later optimization passes can easily
870*61046927SAndroid Build Coastguard Worker * eliminate the isnan() checks. This may trim the sequence down to a
871*61046927SAndroid Build Coastguard Worker * single (a != b) operation. Otherwise, the optimizer can transform
872*61046927SAndroid Build Coastguard Worker * whatever is left to (a < b || b < a). Since some applications will
873*61046927SAndroid Build Coastguard Worker * open-code this sequence, these optimizations are needed anyway.
874*61046927SAndroid Build Coastguard Worker */
875*61046927SAndroid Build Coastguard Worker dest->def =
876*61046927SAndroid Build Coastguard Worker nir_iand(&b->nb,
877*61046927SAndroid Build Coastguard Worker nir_fneu(&b->nb, src[0], src[1]),
878*61046927SAndroid Build Coastguard Worker nir_iand(&b->nb,
879*61046927SAndroid Build Coastguard Worker nir_feq(&b->nb, src[0], src[0]),
880*61046927SAndroid Build Coastguard Worker nir_feq(&b->nb, src[1], src[1])));
881*61046927SAndroid Build Coastguard Worker
882*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
883*61046927SAndroid Build Coastguard Worker break;
884*61046927SAndroid Build Coastguard Worker }
885*61046927SAndroid Build Coastguard Worker
886*61046927SAndroid Build Coastguard Worker case SpvOpUConvert:
887*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToU:
888*61046927SAndroid Build Coastguard Worker case SpvOpConvertFToS:
889*61046927SAndroid Build Coastguard Worker case SpvOpConvertSToF:
890*61046927SAndroid Build Coastguard Worker case SpvOpConvertUToF:
891*61046927SAndroid Build Coastguard Worker case SpvOpSConvert:
892*61046927SAndroid Build Coastguard Worker case SpvOpFConvert:
893*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertSToU:
894*61046927SAndroid Build Coastguard Worker case SpvOpSatConvertUToS: {
895*61046927SAndroid Build Coastguard Worker unsigned src_bit_size = src[0]->bit_size;
896*61046927SAndroid Build Coastguard Worker unsigned dst_bit_size = glsl_get_bit_size(dest_type);
897*61046927SAndroid Build Coastguard Worker nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
898*61046927SAndroid Build Coastguard Worker nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
899*61046927SAndroid Build Coastguard Worker
900*61046927SAndroid Build Coastguard Worker struct conversion_opts opts = {
901*61046927SAndroid Build Coastguard Worker .rounding_mode = nir_rounding_mode_undef,
902*61046927SAndroid Build Coastguard Worker .saturate = false,
903*61046927SAndroid Build Coastguard Worker };
904*61046927SAndroid Build Coastguard Worker vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);
905*61046927SAndroid Build Coastguard Worker
906*61046927SAndroid Build Coastguard Worker if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
907*61046927SAndroid Build Coastguard Worker opts.saturate = true;
908*61046927SAndroid Build Coastguard Worker
909*61046927SAndroid Build Coastguard Worker if (b->shader->info.stage == MESA_SHADER_KERNEL) {
910*61046927SAndroid Build Coastguard Worker if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
911*61046927SAndroid Build Coastguard Worker dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
912*61046927SAndroid Build Coastguard Worker nir_rounding_mode_undef);
913*61046927SAndroid Build Coastguard Worker } else {
914*61046927SAndroid Build Coastguard Worker dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
915*61046927SAndroid Build Coastguard Worker src_type, dst_type,
916*61046927SAndroid Build Coastguard Worker opts.rounding_mode, opts.saturate);
917*61046927SAndroid Build Coastguard Worker }
918*61046927SAndroid Build Coastguard Worker } else {
919*61046927SAndroid Build Coastguard Worker vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
920*61046927SAndroid Build Coastguard Worker dst_type != nir_type_float16,
921*61046927SAndroid Build Coastguard Worker "Rounding modes are only allowed on conversions to "
922*61046927SAndroid Build Coastguard Worker "16-bit float types");
923*61046927SAndroid Build Coastguard Worker dest->def = nir_type_convert(&b->nb, src[0], src_type, dst_type,
924*61046927SAndroid Build Coastguard Worker opts.rounding_mode);
925*61046927SAndroid Build Coastguard Worker }
926*61046927SAndroid Build Coastguard Worker break;
927*61046927SAndroid Build Coastguard Worker }
928*61046927SAndroid Build Coastguard Worker
929*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldInsert:
930*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldSExtract:
931*61046927SAndroid Build Coastguard Worker case SpvOpBitFieldUExtract:
932*61046927SAndroid Build Coastguard Worker case SpvOpShiftLeftLogical:
933*61046927SAndroid Build Coastguard Worker case SpvOpShiftRightArithmetic:
934*61046927SAndroid Build Coastguard Worker case SpvOpShiftRightLogical: {
935*61046927SAndroid Build Coastguard Worker bool swap;
936*61046927SAndroid Build Coastguard Worker bool exact;
937*61046927SAndroid Build Coastguard Worker unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
938*61046927SAndroid Build Coastguard Worker unsigned dst_bit_size = glsl_get_bit_size(dest_type);
939*61046927SAndroid Build Coastguard Worker nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
940*61046927SAndroid Build Coastguard Worker src0_bit_size, dst_bit_size);
941*61046927SAndroid Build Coastguard Worker
942*61046927SAndroid Build Coastguard Worker assert(!exact);
943*61046927SAndroid Build Coastguard Worker
944*61046927SAndroid Build Coastguard Worker assert (op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
945*61046927SAndroid Build Coastguard Worker op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
946*61046927SAndroid Build Coastguard Worker op == nir_op_ibitfield_extract);
947*61046927SAndroid Build Coastguard Worker
948*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
949*61046927SAndroid Build Coastguard Worker unsigned src_bit_size =
950*61046927SAndroid Build Coastguard Worker nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
951*61046927SAndroid Build Coastguard Worker if (src_bit_size == 0)
952*61046927SAndroid Build Coastguard Worker continue;
953*61046927SAndroid Build Coastguard Worker if (src_bit_size != src[i]->bit_size) {
954*61046927SAndroid Build Coastguard Worker assert(src_bit_size == 32);
955*61046927SAndroid Build Coastguard Worker /* Convert the Shift, Offset and Count operands to 32 bits, which is the bitsize
956*61046927SAndroid Build Coastguard Worker * supported by the NIR instructions. See discussion here:
957*61046927SAndroid Build Coastguard Worker *
958*61046927SAndroid Build Coastguard Worker * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
959*61046927SAndroid Build Coastguard Worker */
960*61046927SAndroid Build Coastguard Worker src[i] = nir_u2u32(&b->nb, src[i]);
961*61046927SAndroid Build Coastguard Worker }
962*61046927SAndroid Build Coastguard Worker }
963*61046927SAndroid Build Coastguard Worker dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
964*61046927SAndroid Build Coastguard Worker break;
965*61046927SAndroid Build Coastguard Worker }
966*61046927SAndroid Build Coastguard Worker
967*61046927SAndroid Build Coastguard Worker case SpvOpSignBitSet:
968*61046927SAndroid Build Coastguard Worker dest->def = nir_i2b(&b->nb,
969*61046927SAndroid Build Coastguard Worker nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
970*61046927SAndroid Build Coastguard Worker break;
971*61046927SAndroid Build Coastguard Worker
972*61046927SAndroid Build Coastguard Worker case SpvOpUCountTrailingZerosINTEL:
973*61046927SAndroid Build Coastguard Worker dest->def = nir_umin(&b->nb,
974*61046927SAndroid Build Coastguard Worker nir_find_lsb(&b->nb, src[0]),
975*61046927SAndroid Build Coastguard Worker nir_imm_int(&b->nb, 32u));
976*61046927SAndroid Build Coastguard Worker break;
977*61046927SAndroid Build Coastguard Worker
978*61046927SAndroid Build Coastguard Worker case SpvOpBitCount: {
979*61046927SAndroid Build Coastguard Worker /* bit_count always returns int32, but the SPIR-V opcode just says the return
980*61046927SAndroid Build Coastguard Worker * value needs to be big enough to store the number of bits.
981*61046927SAndroid Build Coastguard Worker */
982*61046927SAndroid Build Coastguard Worker dest->def = nir_u2uN(&b->nb, nir_bit_count(&b->nb, src[0]), glsl_get_bit_size(dest_type));
983*61046927SAndroid Build Coastguard Worker break;
984*61046927SAndroid Build Coastguard Worker }
985*61046927SAndroid Build Coastguard Worker
986*61046927SAndroid Build Coastguard Worker case SpvOpSDotKHR:
987*61046927SAndroid Build Coastguard Worker case SpvOpUDotKHR:
988*61046927SAndroid Build Coastguard Worker case SpvOpSUDotKHR:
989*61046927SAndroid Build Coastguard Worker case SpvOpSDotAccSatKHR:
990*61046927SAndroid Build Coastguard Worker case SpvOpUDotAccSatKHR:
991*61046927SAndroid Build Coastguard Worker case SpvOpSUDotAccSatKHR:
992*61046927SAndroid Build Coastguard Worker unreachable("Should have called vtn_handle_integer_dot instead.");
993*61046927SAndroid Build Coastguard Worker
994*61046927SAndroid Build Coastguard Worker default: {
995*61046927SAndroid Build Coastguard Worker bool swap;
996*61046927SAndroid Build Coastguard Worker bool exact;
997*61046927SAndroid Build Coastguard Worker unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
998*61046927SAndroid Build Coastguard Worker unsigned dst_bit_size = glsl_get_bit_size(dest_type);
999*61046927SAndroid Build Coastguard Worker nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
1000*61046927SAndroid Build Coastguard Worker &exact,
1001*61046927SAndroid Build Coastguard Worker src_bit_size, dst_bit_size);
1002*61046927SAndroid Build Coastguard Worker
1003*61046927SAndroid Build Coastguard Worker if (swap) {
1004*61046927SAndroid Build Coastguard Worker nir_def *tmp = src[0];
1005*61046927SAndroid Build Coastguard Worker src[0] = src[1];
1006*61046927SAndroid Build Coastguard Worker src[1] = tmp;
1007*61046927SAndroid Build Coastguard Worker }
1008*61046927SAndroid Build Coastguard Worker
1009*61046927SAndroid Build Coastguard Worker switch (op) {
1010*61046927SAndroid Build Coastguard Worker case nir_op_ishl:
1011*61046927SAndroid Build Coastguard Worker case nir_op_ishr:
1012*61046927SAndroid Build Coastguard Worker case nir_op_ushr:
1013*61046927SAndroid Build Coastguard Worker if (src[1]->bit_size != 32)
1014*61046927SAndroid Build Coastguard Worker src[1] = nir_u2u32(&b->nb, src[1]);
1015*61046927SAndroid Build Coastguard Worker break;
1016*61046927SAndroid Build Coastguard Worker default:
1017*61046927SAndroid Build Coastguard Worker break;
1018*61046927SAndroid Build Coastguard Worker }
1019*61046927SAndroid Build Coastguard Worker
1020*61046927SAndroid Build Coastguard Worker const bool save_exact = b->nb.exact;
1021*61046927SAndroid Build Coastguard Worker
1022*61046927SAndroid Build Coastguard Worker if (exact)
1023*61046927SAndroid Build Coastguard Worker b->nb.exact = true;
1024*61046927SAndroid Build Coastguard Worker
1025*61046927SAndroid Build Coastguard Worker dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
1026*61046927SAndroid Build Coastguard Worker
1027*61046927SAndroid Build Coastguard Worker b->nb.exact = save_exact;
1028*61046927SAndroid Build Coastguard Worker break;
1029*61046927SAndroid Build Coastguard Worker } /* default */
1030*61046927SAndroid Build Coastguard Worker }
1031*61046927SAndroid Build Coastguard Worker
1032*61046927SAndroid Build Coastguard Worker switch (opcode) {
1033*61046927SAndroid Build Coastguard Worker case SpvOpIAdd:
1034*61046927SAndroid Build Coastguard Worker case SpvOpIMul:
1035*61046927SAndroid Build Coastguard Worker case SpvOpISub:
1036*61046927SAndroid Build Coastguard Worker case SpvOpShiftLeftLogical:
1037*61046927SAndroid Build Coastguard Worker case SpvOpSNegate: {
1038*61046927SAndroid Build Coastguard Worker nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
1039*61046927SAndroid Build Coastguard Worker vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
1040*61046927SAndroid Build Coastguard Worker break;
1041*61046927SAndroid Build Coastguard Worker }
1042*61046927SAndroid Build Coastguard Worker default:
1043*61046927SAndroid Build Coastguard Worker /* Do nothing. */
1044*61046927SAndroid Build Coastguard Worker break;
1045*61046927SAndroid Build Coastguard Worker }
1046*61046927SAndroid Build Coastguard Worker
1047*61046927SAndroid Build Coastguard Worker if (mediump_16bit)
1048*61046927SAndroid Build Coastguard Worker vtn_mediump_upconvert_value(b, dest);
1049*61046927SAndroid Build Coastguard Worker vtn_push_ssa_value(b, w[2], dest);
1050*61046927SAndroid Build Coastguard Worker
1051*61046927SAndroid Build Coastguard Worker b->nb.exact = b->exact;
1052*61046927SAndroid Build Coastguard Worker }
1053*61046927SAndroid Build Coastguard Worker
1054*61046927SAndroid Build Coastguard Worker void
vtn_handle_integer_dot(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)1055*61046927SAndroid Build Coastguard Worker vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
1056*61046927SAndroid Build Coastguard Worker const uint32_t *w, unsigned count)
1057*61046927SAndroid Build Coastguard Worker {
1058*61046927SAndroid Build Coastguard Worker struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
1059*61046927SAndroid Build Coastguard Worker const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
1060*61046927SAndroid Build Coastguard Worker const unsigned dest_size = glsl_get_bit_size(dest_type);
1061*61046927SAndroid Build Coastguard Worker
1062*61046927SAndroid Build Coastguard Worker vtn_handle_no_contraction(b, dest_val);
1063*61046927SAndroid Build Coastguard Worker
1064*61046927SAndroid Build Coastguard Worker /* Collect the various SSA sources.
1065*61046927SAndroid Build Coastguard Worker *
1066*61046927SAndroid Build Coastguard Worker * Due to the optional "Packed Vector Format" field, determine number of
1067*61046927SAndroid Build Coastguard Worker * inputs from the opcode. This differs from vtn_handle_alu.
1068*61046927SAndroid Build Coastguard Worker */
1069*61046927SAndroid Build Coastguard Worker const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
1070*61046927SAndroid Build Coastguard Worker opcode == SpvOpUDotAccSatKHR ||
1071*61046927SAndroid Build Coastguard Worker opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;
1072*61046927SAndroid Build Coastguard Worker
1073*61046927SAndroid Build Coastguard Worker vtn_assert(count >= num_inputs + 3);
1074*61046927SAndroid Build Coastguard Worker
1075*61046927SAndroid Build Coastguard Worker struct vtn_ssa_value *vtn_src[3] = { NULL, };
1076*61046927SAndroid Build Coastguard Worker nir_def *src[3] = { NULL, };
1077*61046927SAndroid Build Coastguard Worker
1078*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < num_inputs; i++) {
1079*61046927SAndroid Build Coastguard Worker vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
1080*61046927SAndroid Build Coastguard Worker src[i] = vtn_src[i]->def;
1081*61046927SAndroid Build Coastguard Worker
1082*61046927SAndroid Build Coastguard Worker vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
1083*61046927SAndroid Build Coastguard Worker }
1084*61046927SAndroid Build Coastguard Worker
1085*61046927SAndroid Build Coastguard Worker /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
1086*61046927SAndroid Build Coastguard Worker * the SPV_KHR_integer_dot_product spec says:
1087*61046927SAndroid Build Coastguard Worker *
1088*61046927SAndroid Build Coastguard Worker * _Vector 1_ and _Vector 2_ must have the same type.
1089*61046927SAndroid Build Coastguard Worker *
1090*61046927SAndroid Build Coastguard Worker * The practical requirement is the same bit-size and the same number of
1091*61046927SAndroid Build Coastguard Worker * components.
1092*61046927SAndroid Build Coastguard Worker */
1093*61046927SAndroid Build Coastguard Worker vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
1094*61046927SAndroid Build Coastguard Worker glsl_get_bit_size(vtn_src[1]->type) ||
1095*61046927SAndroid Build Coastguard Worker glsl_get_vector_elements(vtn_src[0]->type) !=
1096*61046927SAndroid Build Coastguard Worker glsl_get_vector_elements(vtn_src[1]->type),
1097*61046927SAndroid Build Coastguard Worker "Vector 1 and vector 2 source of opcode %s must have the same "
1098*61046927SAndroid Build Coastguard Worker "type",
1099*61046927SAndroid Build Coastguard Worker spirv_op_to_string(opcode));
1100*61046927SAndroid Build Coastguard Worker
1101*61046927SAndroid Build Coastguard Worker if (num_inputs == 3) {
1102*61046927SAndroid Build Coastguard Worker /* The SPV_KHR_integer_dot_product spec says:
1103*61046927SAndroid Build Coastguard Worker *
1104*61046927SAndroid Build Coastguard Worker * The type of Accumulator must be the same as Result Type.
1105*61046927SAndroid Build Coastguard Worker *
1106*61046927SAndroid Build Coastguard Worker * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
1107*61046927SAndroid Build Coastguard Worker * types (far below) assumes these types have the same size.
1108*61046927SAndroid Build Coastguard Worker */
1109*61046927SAndroid Build Coastguard Worker vtn_fail_if(dest_type != vtn_src[2]->type,
1110*61046927SAndroid Build Coastguard Worker "Accumulator type must be the same as Result Type for "
1111*61046927SAndroid Build Coastguard Worker "opcode %s",
1112*61046927SAndroid Build Coastguard Worker spirv_op_to_string(opcode));
1113*61046927SAndroid Build Coastguard Worker }
1114*61046927SAndroid Build Coastguard Worker
1115*61046927SAndroid Build Coastguard Worker unsigned packed_bit_size = 8;
1116*61046927SAndroid Build Coastguard Worker if (glsl_type_is_vector(vtn_src[0]->type)) {
1117*61046927SAndroid Build Coastguard Worker /* FINISHME: Is this actually as good or better for platforms that don't
1118*61046927SAndroid Build Coastguard Worker * have the special instructions (i.e., one or both of has_dot_4x8 or
1119*61046927SAndroid Build Coastguard Worker * has_sudot_4x8 is false)?
1120*61046927SAndroid Build Coastguard Worker */
1121*61046927SAndroid Build Coastguard Worker if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
1122*61046927SAndroid Build Coastguard Worker glsl_get_bit_size(vtn_src[0]->type) == 8 &&
1123*61046927SAndroid Build Coastguard Worker glsl_get_bit_size(dest_type) <= 32) {
1124*61046927SAndroid Build Coastguard Worker src[0] = nir_pack_32_4x8(&b->nb, src[0]);
1125*61046927SAndroid Build Coastguard Worker src[1] = nir_pack_32_4x8(&b->nb, src[1]);
1126*61046927SAndroid Build Coastguard Worker } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
1127*61046927SAndroid Build Coastguard Worker glsl_get_bit_size(vtn_src[0]->type) == 16 &&
1128*61046927SAndroid Build Coastguard Worker glsl_get_bit_size(dest_type) <= 32 &&
1129*61046927SAndroid Build Coastguard Worker opcode != SpvOpSUDotKHR &&
1130*61046927SAndroid Build Coastguard Worker opcode != SpvOpSUDotAccSatKHR) {
1131*61046927SAndroid Build Coastguard Worker src[0] = nir_pack_32_2x16(&b->nb, src[0]);
1132*61046927SAndroid Build Coastguard Worker src[1] = nir_pack_32_2x16(&b->nb, src[1]);
1133*61046927SAndroid Build Coastguard Worker packed_bit_size = 16;
1134*61046927SAndroid Build Coastguard Worker }
1135*61046927SAndroid Build Coastguard Worker } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
1136*61046927SAndroid Build Coastguard Worker glsl_type_is_32bit(vtn_src[0]->type)) {
1137*61046927SAndroid Build Coastguard Worker /* The SPV_KHR_integer_dot_product spec says:
1138*61046927SAndroid Build Coastguard Worker *
1139*61046927SAndroid Build Coastguard Worker * When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
1140*61046927SAndroid Build Coastguard Worker * Vector Format_ must be specified to select how the integers are to
1141*61046927SAndroid Build Coastguard Worker * be interpreted as vectors.
1142*61046927SAndroid Build Coastguard Worker *
1143*61046927SAndroid Build Coastguard Worker * The "Packed Vector Format" value follows the last input.
1144*61046927SAndroid Build Coastguard Worker */
1145*61046927SAndroid Build Coastguard Worker vtn_assert(count == (num_inputs + 4));
1146*61046927SAndroid Build Coastguard Worker const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
1147*61046927SAndroid Build Coastguard Worker vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
1148*61046927SAndroid Build Coastguard Worker "Unsupported vector packing format %d for opcode %s",
1149*61046927SAndroid Build Coastguard Worker pack_format, spirv_op_to_string(opcode));
1150*61046927SAndroid Build Coastguard Worker } else {
1151*61046927SAndroid Build Coastguard Worker vtn_fail_with_opcode("Invalid source types.", opcode);
1152*61046927SAndroid Build Coastguard Worker }
1153*61046927SAndroid Build Coastguard Worker
1154*61046927SAndroid Build Coastguard Worker nir_def *dest = NULL;
1155*61046927SAndroid Build Coastguard Worker
1156*61046927SAndroid Build Coastguard Worker if (src[0]->num_components > 1) {
1157*61046927SAndroid Build Coastguard Worker nir_def *(*src0_conversion)(nir_builder *, nir_def *, unsigned);
1158*61046927SAndroid Build Coastguard Worker nir_def *(*src1_conversion)(nir_builder *, nir_def *, unsigned);
1159*61046927SAndroid Build Coastguard Worker
1160*61046927SAndroid Build Coastguard Worker switch (opcode) {
1161*61046927SAndroid Build Coastguard Worker case SpvOpSDotKHR:
1162*61046927SAndroid Build Coastguard Worker case SpvOpSDotAccSatKHR:
1163*61046927SAndroid Build Coastguard Worker src0_conversion = nir_i2iN;
1164*61046927SAndroid Build Coastguard Worker src1_conversion = nir_i2iN;
1165*61046927SAndroid Build Coastguard Worker break;
1166*61046927SAndroid Build Coastguard Worker
1167*61046927SAndroid Build Coastguard Worker case SpvOpUDotKHR:
1168*61046927SAndroid Build Coastguard Worker case SpvOpUDotAccSatKHR:
1169*61046927SAndroid Build Coastguard Worker src0_conversion = nir_u2uN;
1170*61046927SAndroid Build Coastguard Worker src1_conversion = nir_u2uN;
1171*61046927SAndroid Build Coastguard Worker break;
1172*61046927SAndroid Build Coastguard Worker
1173*61046927SAndroid Build Coastguard Worker case SpvOpSUDotKHR:
1174*61046927SAndroid Build Coastguard Worker case SpvOpSUDotAccSatKHR:
1175*61046927SAndroid Build Coastguard Worker src0_conversion = nir_i2iN;
1176*61046927SAndroid Build Coastguard Worker src1_conversion = nir_u2uN;
1177*61046927SAndroid Build Coastguard Worker break;
1178*61046927SAndroid Build Coastguard Worker
1179*61046927SAndroid Build Coastguard Worker default:
1180*61046927SAndroid Build Coastguard Worker unreachable("Invalid opcode.");
1181*61046927SAndroid Build Coastguard Worker }
1182*61046927SAndroid Build Coastguard Worker
1183*61046927SAndroid Build Coastguard Worker /* The SPV_KHR_integer_dot_product spec says:
1184*61046927SAndroid Build Coastguard Worker *
1185*61046927SAndroid Build Coastguard Worker * All components of the input vectors are sign-extended to the bit
1186*61046927SAndroid Build Coastguard Worker * width of the result's type. The sign-extended input vectors are
1187*61046927SAndroid Build Coastguard Worker * then multiplied component-wise and all components of the vector
1188*61046927SAndroid Build Coastguard Worker * resulting from the component-wise multiplication are added
1189*61046927SAndroid Build Coastguard Worker * together. The resulting value will equal the low-order N bits of
1190*61046927SAndroid Build Coastguard Worker * the correct result R, where N is the result width and R is
1191*61046927SAndroid Build Coastguard Worker * computed with enough precision to avoid overflow and underflow.
1192*61046927SAndroid Build Coastguard Worker */
1193*61046927SAndroid Build Coastguard Worker const unsigned vector_components =
1194*61046927SAndroid Build Coastguard Worker glsl_get_vector_elements(vtn_src[0]->type);
1195*61046927SAndroid Build Coastguard Worker
1196*61046927SAndroid Build Coastguard Worker for (unsigned i = 0; i < vector_components; i++) {
1197*61046927SAndroid Build Coastguard Worker nir_def *const src0 =
1198*61046927SAndroid Build Coastguard Worker src0_conversion(&b->nb, nir_channel(&b->nb, src[0], i), dest_size);
1199*61046927SAndroid Build Coastguard Worker
1200*61046927SAndroid Build Coastguard Worker nir_def *const src1 =
1201*61046927SAndroid Build Coastguard Worker src1_conversion(&b->nb, nir_channel(&b->nb, src[1], i), dest_size);
1202*61046927SAndroid Build Coastguard Worker
1203*61046927SAndroid Build Coastguard Worker nir_def *const mul_result = nir_imul(&b->nb, src0, src1);
1204*61046927SAndroid Build Coastguard Worker
1205*61046927SAndroid Build Coastguard Worker dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
1206*61046927SAndroid Build Coastguard Worker }
1207*61046927SAndroid Build Coastguard Worker
1208*61046927SAndroid Build Coastguard Worker if (num_inputs == 3) {
1209*61046927SAndroid Build Coastguard Worker /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1210*61046927SAndroid Build Coastguard Worker *
1211*61046927SAndroid Build Coastguard Worker * Signed integer dot product of _Vector 1_ and _Vector 2_ and
1212*61046927SAndroid Build Coastguard Worker * signed saturating addition of the result with _Accumulator_.
1213*61046927SAndroid Build Coastguard Worker *
1214*61046927SAndroid Build Coastguard Worker * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1215*61046927SAndroid Build Coastguard Worker *
1216*61046927SAndroid Build Coastguard Worker * Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
1217*61046927SAndroid Build Coastguard Worker * unsigned saturating addition of the result with _Accumulator_.
1218*61046927SAndroid Build Coastguard Worker *
1219*61046927SAndroid Build Coastguard Worker * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
1220*61046927SAndroid Build Coastguard Worker *
1221*61046927SAndroid Build Coastguard Worker * Mixed-signedness integer dot product of _Vector 1_ and _Vector
1222*61046927SAndroid Build Coastguard Worker * 2_ and signed saturating addition of the result with
1223*61046927SAndroid Build Coastguard Worker * _Accumulator_.
1224*61046927SAndroid Build Coastguard Worker */
1225*61046927SAndroid Build Coastguard Worker dest = (opcode == SpvOpUDotAccSatKHR)
1226*61046927SAndroid Build Coastguard Worker ? nir_uadd_sat(&b->nb, dest, src[2])
1227*61046927SAndroid Build Coastguard Worker : nir_iadd_sat(&b->nb, dest, src[2]);
1228*61046927SAndroid Build Coastguard Worker }
1229*61046927SAndroid Build Coastguard Worker } else {
1230*61046927SAndroid Build Coastguard Worker assert(src[0]->num_components == 1 && src[1]->num_components == 1);
1231*61046927SAndroid Build Coastguard Worker assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);
1232*61046927SAndroid Build Coastguard Worker
1233*61046927SAndroid Build Coastguard Worker nir_def *const zero = nir_imm_zero(&b->nb, 1, 32);
1234*61046927SAndroid Build Coastguard Worker bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
1235*61046927SAndroid Build Coastguard Worker opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;
1236*61046927SAndroid Build Coastguard Worker
1237*61046927SAndroid Build Coastguard Worker if (packed_bit_size == 16) {
1238*61046927SAndroid Build Coastguard Worker switch (opcode) {
1239*61046927SAndroid Build Coastguard Worker case SpvOpSDotKHR:
1240*61046927SAndroid Build Coastguard Worker dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1241*61046927SAndroid Build Coastguard Worker break;
1242*61046927SAndroid Build Coastguard Worker case SpvOpUDotKHR:
1243*61046927SAndroid Build Coastguard Worker dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1244*61046927SAndroid Build Coastguard Worker break;
1245*61046927SAndroid Build Coastguard Worker case SpvOpSDotAccSatKHR:
1246*61046927SAndroid Build Coastguard Worker if (dest_size == 32)
1247*61046927SAndroid Build Coastguard Worker dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
1248*61046927SAndroid Build Coastguard Worker else
1249*61046927SAndroid Build Coastguard Worker dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
1250*61046927SAndroid Build Coastguard Worker break;
1251*61046927SAndroid Build Coastguard Worker case SpvOpUDotAccSatKHR:
1252*61046927SAndroid Build Coastguard Worker if (dest_size == 32)
1253*61046927SAndroid Build Coastguard Worker dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
1254*61046927SAndroid Build Coastguard Worker else
1255*61046927SAndroid Build Coastguard Worker dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
1256*61046927SAndroid Build Coastguard Worker break;
1257*61046927SAndroid Build Coastguard Worker default:
1258*61046927SAndroid Build Coastguard Worker unreachable("Invalid opcode.");
1259*61046927SAndroid Build Coastguard Worker }
1260*61046927SAndroid Build Coastguard Worker } else {
1261*61046927SAndroid Build Coastguard Worker switch (opcode) {
1262*61046927SAndroid Build Coastguard Worker case SpvOpSDotKHR:
1263*61046927SAndroid Build Coastguard Worker dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1264*61046927SAndroid Build Coastguard Worker break;
1265*61046927SAndroid Build Coastguard Worker case SpvOpUDotKHR:
1266*61046927SAndroid Build Coastguard Worker dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1267*61046927SAndroid Build Coastguard Worker break;
1268*61046927SAndroid Build Coastguard Worker case SpvOpSUDotKHR:
1269*61046927SAndroid Build Coastguard Worker dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1270*61046927SAndroid Build Coastguard Worker break;
1271*61046927SAndroid Build Coastguard Worker case SpvOpSDotAccSatKHR:
1272*61046927SAndroid Build Coastguard Worker if (dest_size == 32)
1273*61046927SAndroid Build Coastguard Worker dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1274*61046927SAndroid Build Coastguard Worker else
1275*61046927SAndroid Build Coastguard Worker dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
1276*61046927SAndroid Build Coastguard Worker break;
1277*61046927SAndroid Build Coastguard Worker case SpvOpUDotAccSatKHR:
1278*61046927SAndroid Build Coastguard Worker if (dest_size == 32)
1279*61046927SAndroid Build Coastguard Worker dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
1280*61046927SAndroid Build Coastguard Worker else
1281*61046927SAndroid Build Coastguard Worker dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
1282*61046927SAndroid Build Coastguard Worker break;
1283*61046927SAndroid Build Coastguard Worker case SpvOpSUDotAccSatKHR:
1284*61046927SAndroid Build Coastguard Worker if (dest_size == 32)
1285*61046927SAndroid Build Coastguard Worker dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
1286*61046927SAndroid Build Coastguard Worker else
1287*61046927SAndroid Build Coastguard Worker dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
1288*61046927SAndroid Build Coastguard Worker break;
1289*61046927SAndroid Build Coastguard Worker default:
1290*61046927SAndroid Build Coastguard Worker unreachable("Invalid opcode.");
1291*61046927SAndroid Build Coastguard Worker }
1292*61046927SAndroid Build Coastguard Worker }
1293*61046927SAndroid Build Coastguard Worker
1294*61046927SAndroid Build Coastguard Worker if (dest_size != 32) {
1295*61046927SAndroid Build Coastguard Worker /* When the accumulator is 32-bits, a NIR dot-product with saturate
1296*61046927SAndroid Build Coastguard Worker * is generated above. In all other cases a regular dot-product is
1297*61046927SAndroid Build Coastguard Worker * generated above, and separate addition with saturate is generated
1298*61046927SAndroid Build Coastguard Worker * here.
1299*61046927SAndroid Build Coastguard Worker *
1300*61046927SAndroid Build Coastguard Worker * The SPV_KHR_integer_dot_product spec says:
1301*61046927SAndroid Build Coastguard Worker *
1302*61046927SAndroid Build Coastguard Worker * If any of the multiplications or additions, with the exception
1303*61046927SAndroid Build Coastguard Worker * of the final accumulation, overflow or underflow, the result of
1304*61046927SAndroid Build Coastguard Worker * the instruction is undefined.
1305*61046927SAndroid Build Coastguard Worker *
1306*61046927SAndroid Build Coastguard Worker * Therefore it is safe to cast the dot-product result down to the
1307*61046927SAndroid Build Coastguard Worker * size of the accumulator before doing the addition. Since the
1308*61046927SAndroid Build Coastguard Worker * result of the dot-product cannot overflow 32-bits, this is also
1309*61046927SAndroid Build Coastguard Worker * safe to cast up.
1310*61046927SAndroid Build Coastguard Worker */
1311*61046927SAndroid Build Coastguard Worker if (num_inputs == 3) {
1312*61046927SAndroid Build Coastguard Worker dest = is_signed
1313*61046927SAndroid Build Coastguard Worker ? nir_iadd_sat(&b->nb, nir_i2iN(&b->nb, dest, dest_size), src[2])
1314*61046927SAndroid Build Coastguard Worker : nir_uadd_sat(&b->nb, nir_u2uN(&b->nb, dest, dest_size), src[2]);
1315*61046927SAndroid Build Coastguard Worker } else {
1316*61046927SAndroid Build Coastguard Worker dest = is_signed
1317*61046927SAndroid Build Coastguard Worker ? nir_i2iN(&b->nb, dest, dest_size)
1318*61046927SAndroid Build Coastguard Worker : nir_u2uN(&b->nb, dest, dest_size);
1319*61046927SAndroid Build Coastguard Worker }
1320*61046927SAndroid Build Coastguard Worker }
1321*61046927SAndroid Build Coastguard Worker }
1322*61046927SAndroid Build Coastguard Worker
1323*61046927SAndroid Build Coastguard Worker vtn_push_nir_ssa(b, w[2], dest);
1324*61046927SAndroid Build Coastguard Worker
1325*61046927SAndroid Build Coastguard Worker b->nb.exact = b->exact;
1326*61046927SAndroid Build Coastguard Worker }
1327*61046927SAndroid Build Coastguard Worker
1328*61046927SAndroid Build Coastguard Worker void
vtn_handle_bitcast(struct vtn_builder * b,const uint32_t * w,unsigned count)1329*61046927SAndroid Build Coastguard Worker vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
1330*61046927SAndroid Build Coastguard Worker {
1331*61046927SAndroid Build Coastguard Worker vtn_assert(count == 4);
1332*61046927SAndroid Build Coastguard Worker /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
1333*61046927SAndroid Build Coastguard Worker *
1334*61046927SAndroid Build Coastguard Worker * "If Result Type has the same number of components as Operand, they
1335*61046927SAndroid Build Coastguard Worker * must also have the same component width, and results are computed per
1336*61046927SAndroid Build Coastguard Worker * component.
1337*61046927SAndroid Build Coastguard Worker *
1338*61046927SAndroid Build Coastguard Worker * If Result Type has a different number of components than Operand, the
1339*61046927SAndroid Build Coastguard Worker * total number of bits in Result Type must equal the total number of
1340*61046927SAndroid Build Coastguard Worker * bits in Operand. Let L be the type, either Result Type or Operand’s
1341*61046927SAndroid Build Coastguard Worker * type, that has the larger number of components. Let S be the other
1342*61046927SAndroid Build Coastguard Worker * type, with the smaller number of components. The number of components
1343*61046927SAndroid Build Coastguard Worker * in L must be an integer multiple of the number of components in S.
1344*61046927SAndroid Build Coastguard Worker * The first component (that is, the only or lowest-numbered component)
1345*61046927SAndroid Build Coastguard Worker * of S maps to the first components of L, and so on, up to the last
1346*61046927SAndroid Build Coastguard Worker * component of S mapping to the last components of L. Within this
1347*61046927SAndroid Build Coastguard Worker * mapping, any single component of S (mapping to multiple components of
1348*61046927SAndroid Build Coastguard Worker * L) maps its lower-ordered bits to the lower-numbered components of L."
1349*61046927SAndroid Build Coastguard Worker */
1350*61046927SAndroid Build Coastguard Worker
1351*61046927SAndroid Build Coastguard Worker struct vtn_type *type = vtn_get_type(b, w[1]);
1352*61046927SAndroid Build Coastguard Worker if (type->base_type == vtn_base_type_cooperative_matrix) {
1353*61046927SAndroid Build Coastguard Worker vtn_handle_cooperative_instruction(b, SpvOpBitcast, w, count);
1354*61046927SAndroid Build Coastguard Worker return;
1355*61046927SAndroid Build Coastguard Worker }
1356*61046927SAndroid Build Coastguard Worker
1357*61046927SAndroid Build Coastguard Worker struct nir_def *src = vtn_get_nir_ssa(b, w[3]);
1358*61046927SAndroid Build Coastguard Worker
1359*61046927SAndroid Build Coastguard Worker vtn_fail_if(src->num_components * src->bit_size !=
1360*61046927SAndroid Build Coastguard Worker glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
1361*61046927SAndroid Build Coastguard Worker "Source (%%%u) and destination (%%%u) of OpBitcast must have the same "
1362*61046927SAndroid Build Coastguard Worker "total number of bits", w[3], w[2]);
1363*61046927SAndroid Build Coastguard Worker nir_def *val =
1364*61046927SAndroid Build Coastguard Worker nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
1365*61046927SAndroid Build Coastguard Worker vtn_push_nir_ssa(b, w[2], val);
1366*61046927SAndroid Build Coastguard Worker }
1367