1 /*
2 * Copyright © 2020 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #ifndef NIR_CONVERSION_BUILDER_H
25 #define NIR_CONVERSION_BUILDER_H
26
27 #include "util/u_math.h"
28 #include "nir_builder.h"
29 #include "nir_builtin_builder.h"
30
31 #ifdef __cplusplus
32 extern "C" {
33 #endif
34
35 static inline nir_def *
nir_round_float_to_int(nir_builder * b,nir_def * src,nir_rounding_mode round)36 nir_round_float_to_int(nir_builder *b, nir_def *src,
37 nir_rounding_mode round)
38 {
39 switch (round) {
40 case nir_rounding_mode_ru:
41 return nir_fceil(b, src);
42
43 case nir_rounding_mode_rd:
44 return nir_ffloor(b, src);
45
46 case nir_rounding_mode_rtne:
47 return nir_fround_even(b, src);
48
49 case nir_rounding_mode_undef:
50 case nir_rounding_mode_rtz:
51 break;
52 }
53 unreachable("unexpected rounding mode");
54 }
55
56 static inline nir_def *
nir_round_float_to_float(nir_builder * b,nir_def * src,unsigned dest_bit_size,nir_rounding_mode round)57 nir_round_float_to_float(nir_builder *b, nir_def *src,
58 unsigned dest_bit_size,
59 nir_rounding_mode round)
60 {
61 unsigned src_bit_size = src->bit_size;
62 if (dest_bit_size > src_bit_size)
63 return src; /* No rounding is needed for an up-convert */
64
65 nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66 nir_type_float | dest_bit_size,
67 nir_rounding_mode_undef);
68 nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69 nir_type_float | src_bit_size,
70 nir_rounding_mode_undef);
71
72 switch (round) {
73 case nir_rounding_mode_ru: {
74 /* If lower-precision conversion results in a lower value, push it
75 * up one ULP. */
76 nir_def *lower_prec =
77 nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78 nir_def *roundtrip =
79 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80 nir_def *cmp = nir_flt(b, roundtrip, src);
81 nir_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83 }
84 case nir_rounding_mode_rd: {
85 /* If lower-precision conversion results in a higher value, push it
86 * down one ULP. */
87 nir_def *lower_prec =
88 nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89 nir_def *roundtrip =
90 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91 nir_def *cmp = nir_flt(b, src, roundtrip);
92 nir_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94 }
95 case nir_rounding_mode_rtz:
96 return nir_bcsel(b, nir_flt_imm(b, src, 1),
97 nir_round_float_to_float(b, src, dest_bit_size,
98 nir_rounding_mode_ru),
99 nir_round_float_to_float(b, src, dest_bit_size,
100 nir_rounding_mode_rd));
101 case nir_rounding_mode_rtne:
102 case nir_rounding_mode_undef:
103 break;
104 }
105 unreachable("unexpected rounding mode");
106 }
107
108 static inline nir_def *
nir_round_int_to_float(nir_builder * b,nir_def * src,nir_alu_type src_type,unsigned dest_bit_size,nir_rounding_mode round)109 nir_round_int_to_float(nir_builder *b, nir_def *src,
110 nir_alu_type src_type,
111 unsigned dest_bit_size,
112 nir_rounding_mode round)
113 {
114 /* We only care whether or not its signed */
115 src_type = nir_alu_type_get_base_type(src_type);
116
117 unsigned mantissa_bits;
118 switch (dest_bit_size) {
119 case 16:
120 mantissa_bits = 10;
121 break;
122 case 32:
123 mantissa_bits = 23;
124 break;
125 case 64:
126 mantissa_bits = 52;
127 break;
128 default:
129 unreachable("Unsupported bit size");
130 }
131
132 if (src->bit_size < mantissa_bits)
133 return src;
134
135 if (src_type == nir_type_int) {
136 nir_def *sign =
137 nir_i2b(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
138 nir_def *abs = nir_iabs(b, src);
139 nir_def *positive_rounded =
140 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
141 nir_def *max_positive =
142 nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
143 switch (round) {
144 case nir_rounding_mode_rtz:
145 return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
146 positive_rounded);
147 break;
148 case nir_rounding_mode_ru:
149 return nir_bcsel(b, sign,
150 nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
151 nir_umin(b, positive_rounded, max_positive));
152 break;
153 case nir_rounding_mode_rd:
154 return nir_bcsel(b, sign,
155 nir_ineg(b,
156 nir_umin(b, max_positive,
157 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
158 positive_rounded);
159 case nir_rounding_mode_rtne:
160 case nir_rounding_mode_undef:
161 break;
162 }
163 unreachable("unexpected rounding mode");
164 } else {
165 nir_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
166 nir_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
167 nir_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
168 nir_def *one = nir_imm_intN_t(b, 1, src->bit_size);
169 nir_def *adjust = nir_ishl(b, one, bits_to_lose);
170 nir_def *mask = nir_inot(b, nir_isub(b, adjust, one));
171 nir_def *truncated = nir_iand(b, src, mask);
172 switch (round) {
173 case nir_rounding_mode_rtz:
174 case nir_rounding_mode_rd:
175 return truncated;
176 break;
177 case nir_rounding_mode_ru:
178 return nir_bcsel(b, nir_ieq(b, src, truncated),
179 src, nir_uadd_sat(b, truncated, adjust));
180 case nir_rounding_mode_rtne:
181 case nir_rounding_mode_undef:
182 break;
183 }
184 unreachable("unexpected rounding mode");
185 }
186 }
187
188 /** Returns true if the representable range of a contains the representable
189 * range of b.
190 */
191 static inline bool
nir_alu_type_range_contains_type_range(nir_alu_type a,nir_alu_type b)192 nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
193 {
194 /* Split types from bit sizes */
195 nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
196 nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
197 unsigned a_bit_size = nir_alu_type_get_type_size(a);
198 unsigned b_bit_size = nir_alu_type_get_type_size(b);
199
200 /* This requires sized types */
201 assert(a_bit_size > 0 && b_bit_size > 0);
202
203 if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
204 return true;
205
206 if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
207 a_bit_size > b_bit_size)
208 return true;
209
210 /* 16-bit floats fit in 32-bit integers */
211 if (a_base_type == nir_type_int && a_bit_size >= 32 &&
212 b == nir_type_float16)
213 return true;
214
215 /* All signed or unsigned ints can fit in float or above. A uint8 can fit
216 * in a float16.
217 */
218 if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
219 (a_bit_size >= 32 || b_bit_size == 8))
220 return true;
221
222 return false;
223 }
224
225 /**
226 * Retrieves limits used for clamping a value of the src type into
227 * the widest representable range of the dst type via cmp + bcsel
228 */
229 static inline void
nir_get_clamp_limits(nir_builder * b,nir_alu_type src_type,nir_alu_type dest_type,nir_def ** low,nir_def ** high)230 nir_get_clamp_limits(nir_builder *b,
231 nir_alu_type src_type,
232 nir_alu_type dest_type,
233 nir_def **low, nir_def **high)
234 {
235 /* Split types from bit sizes */
236 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
237 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
238 unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
239 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
240 assert(dest_bit_size != 0 && src_bit_size != 0);
241
242 *low = NULL;
243 *high = NULL;
244
245 /* limits of the destination type, expressed in the source type */
246 switch (dest_base_type) {
247 case nir_type_int: {
248 int64_t ilow, ihigh;
249 if (dest_bit_size == 64) {
250 ilow = INT64_MIN;
251 ihigh = INT64_MAX;
252 } else {
253 ilow = -(1ll << (dest_bit_size - 1));
254 ihigh = (1ll << (dest_bit_size - 1)) - 1;
255 }
256
257 if (src_base_type == nir_type_int) {
258 *low = nir_imm_intN_t(b, ilow, src_bit_size);
259 *high = nir_imm_intN_t(b, ihigh, src_bit_size);
260 } else if (src_base_type == nir_type_uint) {
261 assert(src_bit_size >= dest_bit_size);
262 *high = nir_imm_intN_t(b, ihigh, src_bit_size);
263 } else {
264 *low = nir_imm_floatN_t(b, ilow, src_bit_size);
265 *high = nir_imm_floatN_t(b, ihigh, src_bit_size);
266 }
267 break;
268 }
269 case nir_type_uint: {
270 uint64_t uhigh = dest_bit_size == 64 ? ~0ull : (1ull << dest_bit_size) - 1;
271 if (src_base_type != nir_type_float) {
272 *low = nir_imm_intN_t(b, 0, src_bit_size);
273 if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size)
274 *high = nir_imm_intN_t(b, uhigh, src_bit_size);
275 } else {
276 *low = nir_imm_floatN_t(b, 0.0f, src_bit_size);
277 *high = nir_imm_floatN_t(b, uhigh, src_bit_size);
278 }
279 break;
280 }
281 case nir_type_float: {
282 double flow, fhigh;
283 switch (dest_bit_size) {
284 case 16:
285 flow = -65504.0f;
286 fhigh = 65504.0f;
287 break;
288 case 32:
289 flow = -FLT_MAX;
290 fhigh = FLT_MAX;
291 break;
292 case 64:
293 flow = -DBL_MAX;
294 fhigh = DBL_MAX;
295 break;
296 default:
297 unreachable("Unhandled bit size");
298 }
299
300 switch (src_base_type) {
301 case nir_type_int: {
302 int64_t src_ilow, src_ihigh;
303 if (src_bit_size == 64) {
304 src_ilow = INT64_MIN;
305 src_ihigh = INT64_MAX;
306 } else {
307 src_ilow = -(1ll << (src_bit_size - 1));
308 src_ihigh = (1ll << (src_bit_size - 1)) - 1;
309 }
310 if (src_ilow < flow)
311 *low = nir_imm_intN_t(b, flow, src_bit_size);
312 if (src_ihigh > fhigh)
313 *high = nir_imm_intN_t(b, fhigh, src_bit_size);
314 break;
315 }
316 case nir_type_uint: {
317 uint64_t src_uhigh = src_bit_size == 64 ? ~0ull : (1ull << src_bit_size) - 1;
318 if (src_uhigh > fhigh)
319 *high = nir_imm_intN_t(b, fhigh, src_bit_size);
320 break;
321 }
322 case nir_type_float:
323 *low = nir_imm_floatN_t(b, flow, src_bit_size);
324 *high = nir_imm_floatN_t(b, fhigh, src_bit_size);
325 break;
326 default:
327 unreachable("Clamping from unknown type");
328 }
329 break;
330 }
331 default:
332 unreachable("clamping to unknown type");
333 break;
334 }
335 }
336
337 /**
338 * Clamp the value into the widest representatble range of the
339 * destination type with cmp + bcsel.
340 *
341 * val/val_type: The variables used for bcsel
342 * src/src_type: The variables used for comparison
343 * dest_type: The type which determines the range used for comparison
344 */
345 static inline nir_def *
nir_clamp_to_type_range(nir_builder * b,nir_def * val,nir_alu_type val_type,nir_def * src,nir_alu_type src_type,nir_alu_type dest_type)346 nir_clamp_to_type_range(nir_builder *b,
347 nir_def *val, nir_alu_type val_type,
348 nir_def *src, nir_alu_type src_type,
349 nir_alu_type dest_type)
350 {
351 assert(nir_alu_type_get_type_size(src_type) == 0 ||
352 nir_alu_type_get_type_size(src_type) == src->bit_size);
353 src_type |= src->bit_size;
354 if (nir_alu_type_range_contains_type_range(dest_type, src_type))
355 return val;
356
357 /* limits of the destination type, expressed in the source type */
358 nir_def *low = NULL, *high = NULL;
359 nir_get_clamp_limits(b, src_type, dest_type, &low, &high);
360
361 nir_def *low_cond = NULL, *high_cond = NULL;
362 switch (nir_alu_type_get_base_type(src_type)) {
363 case nir_type_int:
364 low_cond = low ? nir_ilt(b, src, low) : NULL;
365 high_cond = high ? nir_ilt(b, high, src) : NULL;
366 break;
367 case nir_type_uint:
368 low_cond = low ? nir_ult(b, src, low) : NULL;
369 high_cond = high ? nir_ult(b, high, src) : NULL;
370 break;
371 case nir_type_float:
372 low_cond = low ? nir_fge(b, low, src) : NULL;
373 high_cond = high ? nir_fge(b, src, high) : NULL;
374 break;
375 default:
376 unreachable("clamping from unknown type");
377 }
378
379 nir_def *val_low = low, *val_high = high;
380 if (val_type != src_type) {
381 nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high);
382 }
383
384 nir_def *res = val;
385 if (low_cond && val_low)
386 res = nir_bcsel(b, low_cond, val_low, res);
387 if (high_cond && val_high)
388 res = nir_bcsel(b, high_cond, val_high, res);
389
390 return res;
391 }
392
393 static inline nir_rounding_mode
nir_simplify_conversion_rounding(nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode rounding)394 nir_simplify_conversion_rounding(nir_alu_type src_type,
395 nir_alu_type dest_type,
396 nir_rounding_mode rounding)
397 {
398 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
399 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
400 unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
401 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
402 assert(src_bit_size > 0 && dest_bit_size > 0);
403
404 if (rounding == nir_rounding_mode_undef)
405 return rounding;
406
407 /* Pure integer conversion doesn't have any rounding */
408 if (src_base_type != nir_type_float &&
409 dest_base_type != nir_type_float)
410 return nir_rounding_mode_undef;
411
412 /* Float down-casts don't round */
413 if (src_base_type == nir_type_float &&
414 dest_base_type == nir_type_float &&
415 dest_bit_size >= src_bit_size)
416 return nir_rounding_mode_undef;
417
418 /* Regular float to int conversions are RTZ */
419 if (src_base_type == nir_type_float &&
420 dest_base_type != nir_type_float &&
421 rounding == nir_rounding_mode_rtz)
422 return nir_rounding_mode_undef;
423
424 /* The CL spec requires regular conversions to float to be RTNE */
425 if (dest_base_type == nir_type_float &&
426 rounding == nir_rounding_mode_rtne)
427 return nir_rounding_mode_undef;
428
429 /* Couldn't simplify */
430 return rounding;
431 }
432
433 static inline nir_def *
nir_convert_with_rounding(nir_builder * b,nir_def * src,nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode round,bool clamp)434 nir_convert_with_rounding(nir_builder *b,
435 nir_def *src, nir_alu_type src_type,
436 nir_alu_type dest_type,
437 nir_rounding_mode round,
438 bool clamp)
439 {
440 /* Some stuff wants sized types */
441 assert(nir_alu_type_get_type_size(src_type) == 0 ||
442 nir_alu_type_get_type_size(src_type) == src->bit_size);
443 src_type |= src->bit_size;
444
445 /* Split types from bit sizes */
446 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
447 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
448 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
449
450 /* Try to simplify the conversion if we can */
451 clamp = clamp &&
452 !nir_alu_type_range_contains_type_range(dest_type, src_type);
453 round = nir_simplify_conversion_rounding(src_type, dest_type, round);
454
455 /* For float -> int/uint conversions, we might not be able to represent
456 * the destination range in the source float accurately. For these cases,
457 * do the comparison in float range, but the bcsel in the destination range.
458 */
459 bool clamp_after_conversion = clamp &&
460 src_base_type == nir_type_float &&
461 dest_base_type != nir_type_float;
462
463 /*
464 * If we don't care about rounding and clamping, we can just use NIR's
465 * built-in ops. There is also a special case for SPIR-V in shaders, where
466 * f32/f64 -> f16 conversions can have one of two rounding modes applied,
467 * which NIR has built-in opcodes for.
468 *
469 * For the rest, we have our own implementation of rounding and clamping.
470 */
471 bool trivial_convert;
472 if (!clamp && round == nir_rounding_mode_undef) {
473 trivial_convert = true;
474 } else if (!clamp && src_type == nir_type_float32 &&
475 dest_type == nir_type_float16 &&
476 (round == nir_rounding_mode_rtne ||
477 round == nir_rounding_mode_rtz)) {
478 trivial_convert = true;
479 } else {
480 trivial_convert = false;
481 }
482
483 if (trivial_convert)
484 return nir_type_convert(b, src, src_type, dest_type, round);
485
486 nir_def *dest = src;
487
488 /* clamp the result into range */
489 if (clamp && !clamp_after_conversion)
490 dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type);
491
492 /* round with selected rounding mode */
493 if (!trivial_convert && round != nir_rounding_mode_undef) {
494 if (src_base_type == nir_type_float) {
495 if (dest_base_type == nir_type_float) {
496 dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
497 } else {
498 dest = nir_round_float_to_int(b, dest, round);
499 }
500 } else {
501 dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
502 }
503
504 round = nir_rounding_mode_undef;
505 }
506
507 /* now we can convert the value */
508 nir_op op = nir_type_conversion_op(src_type, dest_type, round);
509 dest = nir_build_alu(b, op, dest, NULL, NULL, NULL);
510
511 if (clamp_after_conversion)
512 dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type);
513
514 return dest;
515 }
516
517 #ifdef __cplusplus
518 }
519 #endif
520
521 #endif /* NIR_CONVERSION_BUILDER_H */
522