xref: /aosp_15_r20/external/angle/src/compiler/translator/ConstantUnion.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2016 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ConstantUnion: Constant folding helper class.
7 
8 #include "compiler/translator/ConstantUnion.h"
9 
10 #include "common/mathutil.h"
11 #include "compiler/translator/Diagnostics.h"
12 #include "compiler/translator/util.h"
13 
14 namespace sh
15 {
16 
17 namespace
18 {
19 
CheckedSum(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)20 float CheckedSum(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
21 {
22     float result = lhs + rhs;
23     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
24     {
25         diag->warning(line, "Constant folded undefined addition generated NaN", "+");
26     }
27     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
28     {
29         diag->warning(line, "Constant folded addition overflowed to infinity", "+");
30     }
31     return result;
32 }
33 
CheckedDiff(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)34 float CheckedDiff(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
35 {
36     float result = lhs - rhs;
37     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
38     {
39         diag->warning(line, "Constant folded undefined subtraction generated NaN", "-");
40     }
41     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
42     {
43         diag->warning(line, "Constant folded subtraction overflowed to infinity", "-");
44     }
45     return result;
46 }
47 
CheckedMul(float lhs,float rhs,TDiagnostics * diag,const TSourceLoc & line)48 float CheckedMul(float lhs, float rhs, TDiagnostics *diag, const TSourceLoc &line)
49 {
50     float result = lhs * rhs;
51     if (gl::isNaN(result) && !gl::isNaN(lhs) && !gl::isNaN(rhs))
52     {
53         diag->warning(line, "Constant folded undefined multiplication generated NaN", "*");
54     }
55     else if (gl::isInf(result) && !gl::isInf(lhs) && !gl::isInf(rhs))
56     {
57         diag->warning(line, "Constant folded multiplication overflowed to infinity", "*");
58     }
59     return result;
60 }
61 
IsValidShiftOffset(const TConstantUnion & rhs)62 bool IsValidShiftOffset(const TConstantUnion &rhs)
63 {
64     return (rhs.getType() == EbtInt && (rhs.getIConst() >= 0 && rhs.getIConst() <= 31)) ||
65            (rhs.getType() == EbtUInt && rhs.getUConst() <= 31u);
66 }
67 
68 }  // anonymous namespace
69 
TConstantUnion()70 TConstantUnion::TConstantUnion() : iConst(0), type(EbtVoid) {}
71 
TConstantUnion(int i)72 TConstantUnion::TConstantUnion(int i) : iConst(i), type(EbtInt) {}
73 
TConstantUnion(unsigned int u)74 TConstantUnion::TConstantUnion(unsigned int u) : uConst(u), type(EbtUInt) {}
75 
TConstantUnion(float f)76 TConstantUnion::TConstantUnion(float f) : fConst(f), type(EbtFloat) {}
77 
TConstantUnion(bool b)78 TConstantUnion::TConstantUnion(bool b) : bConst(b), type(EbtBool) {}
79 
getIConst() const80 int TConstantUnion::getIConst() const
81 {
82     ASSERT(type == EbtInt);
83     return iConst;
84 }
85 
getUConst() const86 unsigned int TConstantUnion::getUConst() const
87 {
88     ASSERT(type == EbtUInt);
89     return uConst;
90 }
91 
getFConst() const92 float TConstantUnion::getFConst() const
93 {
94     switch (type)
95     {
96         case EbtInt:
97             return static_cast<float>(iConst);
98         case EbtUInt:
99             return static_cast<float>(uConst);
100         default:
101             ASSERT(type == EbtFloat);
102             return fConst;
103     }
104 }
105 
getBConst() const106 bool TConstantUnion::getBConst() const
107 {
108     ASSERT(type == EbtBool);
109     return bConst;
110 }
111 
isZero() const112 bool TConstantUnion::isZero() const
113 {
114     switch (type)
115     {
116         case EbtInt:
117             return getIConst() == 0;
118         case EbtUInt:
119             return getUConst() == 0;
120         case EbtFloat:
121             return getFConst() == 0.0f;
122         case EbtBool:
123             return getBConst() == false;
124         default:
125             return false;
126     }
127 }
128 
getYuvCscStandardEXTConst() const129 TYuvCscStandardEXT TConstantUnion::getYuvCscStandardEXTConst() const
130 {
131     ASSERT(type == EbtYuvCscStandardEXT);
132     return yuvCscStandardEXTConst;
133 }
134 
cast(TBasicType newType,const TConstantUnion & constant)135 bool TConstantUnion::cast(TBasicType newType, const TConstantUnion &constant)
136 {
137     switch (newType)
138     {
139         case EbtFloat:
140             switch (constant.type)
141             {
142                 case EbtInt:
143                     setFConst(static_cast<float>(constant.getIConst()));
144                     break;
145                 case EbtUInt:
146                     setFConst(static_cast<float>(constant.getUConst()));
147                     break;
148                 case EbtBool:
149                     setFConst(static_cast<float>(constant.getBConst()));
150                     break;
151                 case EbtFloat:
152                     setFConst(static_cast<float>(constant.getFConst()));
153                     break;
154                 default:
155                     return false;
156             }
157             break;
158         case EbtInt:
159             switch (constant.type)
160             {
161                 case EbtInt:
162                     setIConst(static_cast<int>(constant.getIConst()));
163                     break;
164                 case EbtUInt:
165                     setIConst(static_cast<int>(constant.getUConst()));
166                     break;
167                 case EbtBool:
168                     setIConst(static_cast<int>(constant.getBConst()));
169                     break;
170                 case EbtFloat:
171                     setIConst(static_cast<int>(constant.getFConst()));
172                     break;
173                 default:
174                     return false;
175             }
176             break;
177         case EbtUInt:
178             switch (constant.type)
179             {
180                 case EbtInt:
181                     setUConst(static_cast<unsigned int>(constant.getIConst()));
182                     break;
183                 case EbtUInt:
184                     setUConst(static_cast<unsigned int>(constant.getUConst()));
185                     break;
186                 case EbtBool:
187                     setUConst(static_cast<unsigned int>(constant.getBConst()));
188                     break;
189                 case EbtFloat:
190                     if (constant.getFConst() < 0.0f)
191                     {
192                         // Avoid undefined behavior in C++ by first casting to signed int.
193                         setUConst(
194                             static_cast<unsigned int>(static_cast<int>(constant.getFConst())));
195                     }
196                     else
197                     {
198                         setUConst(static_cast<unsigned int>(constant.getFConst()));
199                     }
200                     break;
201                 default:
202                     return false;
203             }
204             break;
205         case EbtBool:
206             switch (constant.type)
207             {
208                 case EbtInt:
209                     setBConst(constant.getIConst() != 0);
210                     break;
211                 case EbtUInt:
212                     setBConst(constant.getUConst() != 0);
213                     break;
214                 case EbtBool:
215                     setBConst(constant.getBConst());
216                     break;
217                 case EbtFloat:
218                     setBConst(constant.getFConst() != 0.0f);
219                     break;
220                 default:
221                     return false;
222             }
223             break;
224         case EbtStruct:  // Struct fields don't get cast
225             switch (constant.type)
226             {
227                 case EbtInt:
228                     setIConst(constant.getIConst());
229                     break;
230                 case EbtUInt:
231                     setUConst(constant.getUConst());
232                     break;
233                 case EbtBool:
234                     setBConst(constant.getBConst());
235                     break;
236                 case EbtFloat:
237                     setFConst(constant.getFConst());
238                     break;
239                 default:
240                     return false;
241             }
242             break;
243         case EbtYuvCscStandardEXT:
244             switch (constant.type)
245             {
246                 case EbtYuvCscStandardEXT:
247                     setYuvCscStandardEXTConst(constant.getYuvCscStandardEXTConst());
248                     break;
249                 default:
250                     return false;
251             }
252             break;
253         default:
254             return false;
255     }
256 
257     return true;
258 }
259 
operator ==(const int i) const260 bool TConstantUnion::operator==(const int i) const
261 {
262     switch (type)
263     {
264         case EbtFloat:
265             return static_cast<float>(i) == fConst;
266         default:
267             return i == iConst;
268     }
269 }
270 
operator ==(const unsigned int u) const271 bool TConstantUnion::operator==(const unsigned int u) const
272 {
273     switch (type)
274     {
275         case EbtFloat:
276             return static_cast<float>(u) == fConst;
277         default:
278             return u == uConst;
279     }
280 }
281 
operator ==(const float f) const282 bool TConstantUnion::operator==(const float f) const
283 {
284     switch (type)
285     {
286         case EbtInt:
287             return f == static_cast<float>(iConst);
288         case EbtUInt:
289             return f == static_cast<float>(uConst);
290         default:
291             return f == fConst;
292     }
293 }
294 
operator ==(const bool b) const295 bool TConstantUnion::operator==(const bool b) const
296 {
297     return b == bConst;
298 }
299 
operator ==(const TYuvCscStandardEXT s) const300 bool TConstantUnion::operator==(const TYuvCscStandardEXT s) const
301 {
302     return s == yuvCscStandardEXTConst;
303 }
304 
operator ==(const TConstantUnion & constant) const305 bool TConstantUnion::operator==(const TConstantUnion &constant) const
306 {
307     switch (type)
308     {
309         case EbtInt:
310             return constant.iConst == iConst;
311         case EbtUInt:
312             return constant.uConst == uConst;
313         case EbtFloat:
314             return constant.fConst == fConst;
315         case EbtBool:
316             return constant.bConst == bConst;
317         case EbtYuvCscStandardEXT:
318             return constant.yuvCscStandardEXTConst == yuvCscStandardEXTConst;
319         default:
320             return false;
321     }
322 }
323 
operator !=(const int i) const324 bool TConstantUnion::operator!=(const int i) const
325 {
326     return !operator==(i);
327 }
328 
operator !=(const unsigned int u) const329 bool TConstantUnion::operator!=(const unsigned int u) const
330 {
331     return !operator==(u);
332 }
333 
operator !=(const float f) const334 bool TConstantUnion::operator!=(const float f) const
335 {
336     return !operator==(f);
337 }
338 
operator !=(const bool b) const339 bool TConstantUnion::operator!=(const bool b) const
340 {
341     return !operator==(b);
342 }
343 
operator !=(const TYuvCscStandardEXT s) const344 bool TConstantUnion::operator!=(const TYuvCscStandardEXT s) const
345 {
346     return !operator==(s);
347 }
348 
operator !=(const TConstantUnion & constant) const349 bool TConstantUnion::operator!=(const TConstantUnion &constant) const
350 {
351     return !operator==(constant);
352 }
353 
operator >(const TConstantUnion & constant) const354 bool TConstantUnion::operator>(const TConstantUnion &constant) const
355 {
356 
357     switch (type)
358     {
359         case EbtInt:
360             return iConst > constant.iConst;
361         case EbtUInt:
362             return uConst > constant.uConst;
363         case EbtFloat:
364             return fConst > constant.fConst;
365         default:
366             return false;  // Invalid operation, handled at semantic analysis
367     }
368 }
369 
operator <(const TConstantUnion & constant) const370 bool TConstantUnion::operator<(const TConstantUnion &constant) const
371 {
372     switch (type)
373     {
374         case EbtInt:
375             return iConst < constant.iConst;
376         case EbtUInt:
377             return uConst < constant.uConst;
378         case EbtFloat:
379             return fConst < constant.fConst;
380         default:
381             return false;  // Invalid operation, handled at semantic analysis
382     }
383 }
384 
385 // static
add(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)386 TConstantUnion TConstantUnion::add(const TConstantUnion &lhs,
387                                    const TConstantUnion &rhs,
388                                    TDiagnostics *diag,
389                                    const TSourceLoc &line)
390 {
391     TConstantUnion returnValue;
392 
393     switch (lhs.type)
394     {
395         case EbtInt:
396             returnValue.setIConst(gl::WrappingSum<int>(lhs.iConst, rhs.iConst));
397             break;
398         case EbtUInt:
399             returnValue.setUConst(gl::WrappingSum<unsigned int>(lhs.uConst, rhs.uConst));
400             break;
401         case EbtFloat:
402             returnValue.setFConst(CheckedSum(lhs.fConst, rhs.fConst, diag, line));
403             break;
404         default:
405             UNREACHABLE();
406     }
407 
408     return returnValue;
409 }
410 
411 // static
sub(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)412 TConstantUnion TConstantUnion::sub(const TConstantUnion &lhs,
413                                    const TConstantUnion &rhs,
414                                    TDiagnostics *diag,
415                                    const TSourceLoc &line)
416 {
417     TConstantUnion returnValue;
418 
419     switch (lhs.type)
420     {
421         case EbtInt:
422             returnValue.setIConst(gl::WrappingDiff<int>(lhs.iConst, rhs.iConst));
423             break;
424         case EbtUInt:
425             returnValue.setUConst(gl::WrappingDiff<unsigned int>(lhs.uConst, rhs.uConst));
426             break;
427         case EbtFloat:
428             returnValue.setFConst(CheckedDiff(lhs.fConst, rhs.fConst, diag, line));
429             break;
430         default:
431             UNREACHABLE();
432     }
433 
434     return returnValue;
435 }
436 
437 // static
mul(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)438 TConstantUnion TConstantUnion::mul(const TConstantUnion &lhs,
439                                    const TConstantUnion &rhs,
440                                    TDiagnostics *diag,
441                                    const TSourceLoc &line)
442 {
443     TConstantUnion returnValue;
444 
445     switch (lhs.type)
446     {
447         case EbtInt:
448             returnValue.setIConst(gl::WrappingMul(lhs.iConst, rhs.iConst));
449             break;
450         case EbtUInt:
451             // Unsigned integer math in C++ is defined to be done in modulo 2^n, so we rely
452             // on that to implement wrapping multiplication.
453             returnValue.setUConst(lhs.uConst * rhs.uConst);
454             break;
455         case EbtFloat:
456             returnValue.setFConst(CheckedMul(lhs.fConst, rhs.fConst, diag, line));
457             break;
458         default:
459             UNREACHABLE();
460     }
461 
462     return returnValue;
463 }
464 
// Integer remainder of two constants of the same basic type (EbtInt or EbtUInt).
// A zero divisor yields 0. For EbtInt a divisor of -1 also yields 0, which is the mathematical
// result for every dividend.
TConstantUnion TConstantUnion::operator%(const TConstantUnion &constant) const
{
    TConstantUnion returnValue;
    ASSERT(type == constant.type);
    switch (type)
    {
        case EbtInt:
            // In C++, x % 0 is undefined behavior, and so is INT_MIN % -1 (the implied quotient
            // overflows); both can trap. Callers may already reject these operands; the guard
            // keeps this operator well-defined on its own.
            if (constant.iConst == 0 || constant.iConst == -1)
            {
                returnValue.setIConst(0);
            }
            else
            {
                returnValue.setIConst(iConst % constant.iConst);
            }
            break;
        case EbtUInt:
            // Unsigned x % 0 is also undefined behavior in C++.
            if (constant.uConst == 0u)
            {
                returnValue.setUConst(0u);
            }
            else
            {
                returnValue.setUConst(uConst % constant.uConst);
            }
            break;
        default:
            UNREACHABLE();
    }

    return returnValue;
}
483 
484 // static
rshift(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)485 TConstantUnion TConstantUnion::rshift(const TConstantUnion &lhs,
486                                       const TConstantUnion &rhs,
487                                       TDiagnostics *diag,
488                                       const TSourceLoc &line)
489 {
490     TConstantUnion returnValue;
491     ASSERT(lhs.type == EbtInt || lhs.type == EbtUInt);
492     ASSERT(rhs.type == EbtInt || rhs.type == EbtUInt);
493     if (!IsValidShiftOffset(rhs))
494     {
495         diag->warning(line, "Undefined shift (operand out of range)", ">>");
496         switch (lhs.type)
497         {
498             case EbtInt:
499                 returnValue.setIConst(0);
500                 break;
501             case EbtUInt:
502                 returnValue.setUConst(0u);
503                 break;
504             default:
505                 UNREACHABLE();
506         }
507         return returnValue;
508     }
509 
510     switch (lhs.type)
511     {
512         case EbtInt:
513         {
514             unsigned int shiftOffset = 0;
515             switch (rhs.type)
516             {
517                 case EbtInt:
518                     shiftOffset = static_cast<unsigned int>(rhs.iConst);
519                     break;
520                 case EbtUInt:
521                     shiftOffset = rhs.uConst;
522                     break;
523                 default:
524                     UNREACHABLE();
525             }
526             if (shiftOffset > 0)
527             {
528                 // ESSL 3.00.6 section 5.9: "If E1 is a signed integer, the right-shift will extend
529                 // the sign bit." In C++ shifting negative integers is undefined, so we implement
530                 // extending the sign bit manually.
531                 int lhsSafe = lhs.iConst;
532                 if (lhsSafe == std::numeric_limits<int>::min())
533                 {
534                     // The min integer needs special treatment because only bit it has set is the
535                     // sign bit, which we clear later to implement safe right shift of negative
536                     // numbers.
537                     lhsSafe = -0x40000000;
538                     --shiftOffset;
539                 }
540                 if (shiftOffset > 0)
541                 {
542                     bool extendSignBit = false;
543                     if (lhsSafe < 0)
544                     {
545                         extendSignBit = true;
546                         // Clear the sign bit so that bitshift right is defined in C++.
547                         lhsSafe &= 0x7fffffff;
548                         ASSERT(lhsSafe > 0);
549                     }
550                     returnValue.setIConst(lhsSafe >> shiftOffset);
551 
552                     // Manually fill in the extended sign bit if necessary.
553                     if (extendSignBit)
554                     {
555                         int extendedSignBit = static_cast<int>(0xffffffffu << (31 - shiftOffset));
556                         returnValue.setIConst(returnValue.getIConst() | extendedSignBit);
557                     }
558                 }
559                 else
560                 {
561                     returnValue.setIConst(lhsSafe);
562                 }
563             }
564             else
565             {
566                 returnValue.setIConst(lhs.iConst);
567             }
568             break;
569         }
570         case EbtUInt:
571             switch (rhs.type)
572             {
573                 case EbtInt:
574                     returnValue.setUConst(lhs.uConst >> rhs.iConst);
575                     break;
576                 case EbtUInt:
577                     returnValue.setUConst(lhs.uConst >> rhs.uConst);
578                     break;
579                 default:
580                     UNREACHABLE();
581             }
582             break;
583 
584         default:
585             UNREACHABLE();
586     }
587     return returnValue;
588 }
589 
590 // static
lshift(const TConstantUnion & lhs,const TConstantUnion & rhs,TDiagnostics * diag,const TSourceLoc & line)591 TConstantUnion TConstantUnion::lshift(const TConstantUnion &lhs,
592                                       const TConstantUnion &rhs,
593                                       TDiagnostics *diag,
594                                       const TSourceLoc &line)
595 {
596     TConstantUnion returnValue;
597     ASSERT(lhs.type == EbtInt || lhs.type == EbtUInt);
598     ASSERT(rhs.type == EbtInt || rhs.type == EbtUInt);
599     if (!IsValidShiftOffset(rhs))
600     {
601         diag->warning(line, "Undefined shift (operand out of range)", "<<");
602         switch (lhs.type)
603         {
604             case EbtInt:
605                 returnValue.setIConst(0);
606                 break;
607             case EbtUInt:
608                 returnValue.setUConst(0u);
609                 break;
610             default:
611                 UNREACHABLE();
612         }
613         return returnValue;
614     }
615 
616     switch (lhs.type)
617     {
618         case EbtInt:
619             switch (rhs.type)
620             {
621                 // Cast to unsigned integer before shifting, since ESSL 3.00.6 section 5.9 says that
622                 // lhs is "interpreted as a bit pattern". This also avoids the possibility of signed
623                 // integer overflow or undefined shift of a negative integer.
624                 case EbtInt:
625                     returnValue.setIConst(
626                         static_cast<int>(static_cast<uint32_t>(lhs.iConst) << rhs.iConst));
627                     break;
628                 case EbtUInt:
629                     returnValue.setIConst(
630                         static_cast<int>(static_cast<uint32_t>(lhs.iConst) << rhs.uConst));
631                     break;
632                 default:
633                     UNREACHABLE();
634             }
635             break;
636 
637         case EbtUInt:
638             switch (rhs.type)
639             {
640                 case EbtInt:
641                     returnValue.setUConst(lhs.uConst << rhs.iConst);
642                     break;
643                 case EbtUInt:
644                     returnValue.setUConst(lhs.uConst << rhs.uConst);
645                     break;
646                 default:
647                     UNREACHABLE();
648             }
649             break;
650 
651         default:
652             UNREACHABLE();
653     }
654     return returnValue;
655 }
656 
// Bitwise AND of integer constants; this constant's type (EbtInt or EbtUInt) selects the member
// that is combined. Only |constant|'s type is asserted to be an integer type here.
TConstantUnion TConstantUnion::operator&(const TConstantUnion &constant) const
{
    ASSERT(constant.type == EbtInt || constant.type == EbtUInt);

    TConstantUnion result;
    if (type == EbtInt)
    {
        result.setIConst(iConst & constant.iConst);
    }
    else if (type == EbtUInt)
    {
        result.setUConst(uConst & constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }
    return result;
}
675 
// Bitwise OR of two integer constants of the same basic type (EbtInt or EbtUInt).
TConstantUnion TConstantUnion::operator|(const TConstantUnion &constant) const
{
    ASSERT(type == constant.type);

    TConstantUnion result;
    if (type == EbtInt)
    {
        result.setIConst(iConst | constant.iConst);
    }
    else if (type == EbtUInt)
    {
        result.setUConst(uConst | constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }
    return result;
}
694 
// Bitwise XOR of two integer constants of the same basic type (EbtInt or EbtUInt).
TConstantUnion TConstantUnion::operator^(const TConstantUnion &constant) const
{
    ASSERT(type == constant.type);

    TConstantUnion result;
    if (type == EbtInt)
    {
        result.setIConst(iConst ^ constant.iConst);
    }
    else if (type == EbtUInt)
    {
        result.setUConst(uConst ^ constant.uConst);
    }
    else
    {
        UNREACHABLE();
    }
    return result;
}
713 
// Logical AND of two boolean constants; any other type is unreachable.
TConstantUnion TConstantUnion::operator&&(const TConstantUnion &constant) const
{
    ASSERT(type == constant.type);

    TConstantUnion result;
    if (type == EbtBool)
    {
        result.setBConst(bConst && constant.bConst);
    }
    else
    {
        UNREACHABLE();
    }
    return result;
}
729 
// Logical OR of two boolean constants; any other type is unreachable.
TConstantUnion TConstantUnion::operator||(const TConstantUnion &constant) const
{
    ASSERT(type == constant.type);

    TConstantUnion result;
    if (type == EbtBool)
    {
        result.setBConst(bConst || constant.bConst);
    }
    else
    {
        UNREACHABLE();
    }
    return result;
}
745 
746 }  // namespace sh
747