1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2018 The Khronos Group Inc.
6  * Copyright (c) 2015 Samsung Electronics Co., Ltd.
7  * Copyright (c) 2016 The Android Open Source Project
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  *//*!
22  * \file
23  * \brief Precision and range tests for builtins and types.
24  *
25  *//*--------------------------------------------------------------------*/
26 
27 #include "vktShaderBuiltinPrecisionTests.hpp"
28 #include "vktShaderExecutor.hpp"
29 #include "amber/vktAmberTestCase.hpp"
30 
31 #include "deMath.h"
32 #include "deMemory.h"
33 #include "deFloat16.h"
34 #include "deDefs.hpp"
35 #include "deRandom.hpp"
36 #include "deSTLUtil.hpp"
37 #include "deStringUtil.hpp"
38 #include "deUniquePtr.hpp"
39 #include "deSharedPtr.hpp"
40 #include "deArrayUtil.hpp"
41 
42 #include "tcuCommandLine.hpp"
43 #include "tcuFloatFormat.hpp"
44 #include "tcuInterval.hpp"
45 #include "tcuTestLog.hpp"
46 #include "tcuVector.hpp"
47 #include "tcuMatrix.hpp"
48 #include "tcuResultCollector.hpp"
49 #include "tcuMaybe.hpp"
50 
51 #include "gluContextInfo.hpp"
52 #include "gluVarType.hpp"
53 #include "gluRenderContext.hpp"
54 #include "glwDefs.hpp"
55 
56 #include <cmath>
57 #include <string>
58 #include <sstream>
59 #include <iostream>
60 #include <map>
61 #include <utility>
62 #include <limits>
63 
64 // Uncomment this to get evaluation trace dumps to std::cerr
65 // #define GLS_ENABLE_TRACE
66 
67 // set this to true to dump even passing results
68 #define GLS_LOG_ALL_RESULTS false
69 
70 #define FLOAT16_1_0 0x3C00 //1.0 float16bit
71 #define FLOAT16_2_0 0x4000 //2.0 float16bit
72 #define FLOAT16_3_0 0x4200 //3.0 float16bit
73 #define FLOAT16_0_5 0x3800 //0.5 float16bit
74 #define FLOAT16_0_0 0x0000 //0.0 float16bit
75 
76 using tcu::Vector;
77 typedef Vector<deFloat16, 1> Vec1_16Bit;
78 typedef Vector<deFloat16, 2> Vec2_16Bit;
79 typedef Vector<deFloat16, 3> Vec3_16Bit;
80 typedef Vector<deFloat16, 4> Vec4_16Bit;
81 
82 typedef Vector<double, 1> Vec1_64Bit;
83 typedef Vector<double, 2> Vec2_64Bit;
84 typedef Vector<double, 3> Vec3_64Bit;
85 typedef Vector<double, 4> Vec4_64Bit;
86 
87 enum
88 {
89     // Computing reference intervals can take a non-trivial amount of time, especially on
90     // platforms where toggling floating-point rounding mode is slow (emulated arm on x86).
91     // As a workaround watchdog is kept happy by touching it periodically during reference
92     // interval computation.
93     TOUCH_WATCHDOG_VALUE_FREQUENCY = 512
94 };
95 
96 namespace vkt
97 {
98 namespace shaderexecutor
99 {
100 
101 using std::map;
102 using std::ostream;
103 using std::ostringstream;
104 using std::pair;
105 using std::set;
106 using std::string;
107 using std::vector;
108 
109 using de::MovePtr;
110 using de::Random;
111 using de::SharedPtr;
112 using de::UniquePtr;
113 using glu::DataType;
114 using glu::Precision;
115 using glu::ShaderType;
116 using glu::VarType;
117 using tcu::FloatFormat;
118 using tcu::Interval;
119 using tcu::Matrix;
120 using tcu::MessageBuilder;
121 using tcu::TestLog;
122 using tcu::Vector;
123 
//! Optional device features a precision test may depend on. Values are single bits and are
//! combined into a PrecisionTestFeatures mask. Bit 0 is not assigned.
enum PrecisionTestFeatureBits
{
    PRECISION_TEST_FEATURES_NONE                                    = 0x00u,
    PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS                     = 0x02u, // bit 1
    PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS = 0x04u, // bit 2
    PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT                     = 0x08u, // bit 3
    PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT                      = 0x10u, // bit 4
    PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT                      = 0x20u, // bit 5
    PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT                      = 0x40u, // bit 6
};

//! Bit mask built from PrecisionTestFeatureBits values.
using PrecisionTestFeatures = uint32_t;
135 
areFeaturesSupported(const Context & context,uint32_t toCheck)136 void areFeaturesSupported(const Context &context, uint32_t toCheck)
137 {
138     if (toCheck == PRECISION_TEST_FEATURES_NONE)
139         return;
140 
141     const vk::VkPhysicalDevice16BitStorageFeatures &extensionFeatures = context.get16BitStorageFeatures();
142 
143     if ((toCheck & PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS) != 0 &&
144         extensionFeatures.storageBuffer16BitAccess == VK_FALSE)
145         TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
146 
147     if ((toCheck & PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS) != 0 &&
148         extensionFeatures.uniformAndStorageBuffer16BitAccess == VK_FALSE)
149         TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
150 
151     if ((toCheck & PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT) != 0 &&
152         extensionFeatures.storagePushConstant16 == VK_FALSE)
153         TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
154 
155     if ((toCheck & PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT) != 0 &&
156         extensionFeatures.storageInputOutput16 == VK_FALSE)
157         TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
158 
159     if ((toCheck & PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT) != 0 &&
160         context.getShaderFloat16Int8Features().shaderFloat16 == VK_FALSE)
161         TCU_THROW(NotSupportedError, "Requested 16-bit floats (halfs) are not supported in shader code");
162 
163     if ((toCheck & PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT) != 0 &&
164         context.getDeviceFeatures().shaderFloat64 == VK_FALSE)
165         TCU_THROW(NotSupportedError, "Requested 64-bit floats are not supported in shader code");
166 }
167 
/*--------------------------------------------------------------------*//*!
 * \brief Generic singleton accessor.
 *
 * instance<T>() returns a reference to one value-initialized object of
 * type T that is shared by every caller. It exists mainly for the GLSL
 * function implementations: each function is an object of its own
 * distinct class, and maintaining a separate context object holding an
 * instance of every such class would be very tedious, so global singleton
 * instances are used instead.
 *
 *//*--------------------------------------------------------------------*/
template <typename T>
const T &instance(void)
{
    // Function-local static: created on the first call and reused by all later calls.
    static const T s_singleton((T()));
    return s_singleton;
}
185 
186 /*--------------------------------------------------------------------*//*!
187  * \brief Empty placeholder type for unused template parameters.
188  *
189  * In the precision tests we are dealing with functions of different arities.
190  * To minimize code duplication, we only define templates with the maximum
191  * number of arguments, currently four. If a function's arity is less than the
192  * maximum, Void us used as the type for unused arguments.
193  *
194  * Although Voids are not used at run-time, they still must be compilable, so
195  * they must support all operations that other types do.
196  *
197  *//*--------------------------------------------------------------------*/
198 struct Void
199 {
200     typedef Void Element;
201     enum
202     {
203         SIZE = 0,
204     };
205 
206     template <typename T>
Voidvkt::shaderexecutor::Void207     explicit Void(const T &)
208     {
209     }
Voidvkt::shaderexecutor::Void210     Void(void)
211     {
212     }
operator doublevkt::shaderexecutor::Void213     operator double(void) const
214     {
215         return TCU_NAN;
216     }
217 
218     // These are used to make Voids usable as containers in container-generic code.
operator []vkt::shaderexecutor::Void219     Void &operator[](int)
220     {
221         return *this;
222     }
operator []vkt::shaderexecutor::Void223     const Void &operator[](int) const
224     {
225         return *this;
226     }
227 };
228 
operator <<(ostream & os,Void)229 ostream &operator<<(ostream &os, Void)
230 {
231     return os << "()";
232 }
233 
234 //! Returns true for all other types except Void
235 template <typename T>
isTypeValid(void)236 bool isTypeValid(void)
237 {
238     return true;
239 }
240 template <>
isTypeValid(void)241 bool isTypeValid<Void>(void)
242 {
243     return false;
244 }
245 
246 template <typename T>
isInteger(void)247 bool isInteger(void)
248 {
249     return false;
250 }
251 template <>
isInteger(void)252 bool isInteger<int>(void)
253 {
254     return true;
255 }
256 template <>
isInteger(void)257 bool isInteger<tcu::IVec2>(void)
258 {
259     return true;
260 }
261 template <>
isInteger(void)262 bool isInteger<tcu::IVec3>(void)
263 {
264     return true;
265 }
266 template <>
isInteger(void)267 bool isInteger<tcu::IVec4>(void)
268 {
269     return true;
270 }
271 
272 //! Utility function for getting the name of a data type.
273 //! This is used in vector and matrix constructors.
274 template <typename T>
dataTypeNameOf(void)275 const char *dataTypeNameOf(void)
276 {
277     return glu::getDataTypeName(glu::dataTypeOf<T>());
278 }
279 
280 template <>
dataTypeNameOf(void)281 const char *dataTypeNameOf<Void>(void)
282 {
283     DE_FATAL("Impossible");
284     return DE_NULL;
285 }
286 
287 template <typename T>
getVarTypeOf(Precision prec=glu::PRECISION_LAST)288 VarType getVarTypeOf(Precision prec = glu::PRECISION_LAST)
289 {
290     return glu::varTypeOf<T>(prec);
291 }
292 
293 //! A hack to get Void support for VarType.
294 template <>
getVarTypeOf(Precision)295 VarType getVarTypeOf<Void>(Precision)
296 {
297     DE_FATAL("Impossible");
298     return VarType();
299 }
300 
/*--------------------------------------------------------------------*//*!
 * \brief Type traits for generalized interval types.
 *
 * We are trying to compute sets of acceptable values not only for
 * float-valued expressions but also for compound values: vectors and
 * matrices. We approximate a set of vectors as a vector of intervals and
 * likewise for matrices.
 *
 * We now need generalized operations for each type and its interval
 * approximation. These are given in the type Traits<T>.
 *
 * The type Traits<T>::IVal is the approximation of T: it is `Interval` for
 * scalar types, and a vector or matrix of intervals for container types.
 *
 * To allow template inference to take place, there are function wrappers for
 * the actual operations in Traits<T>. Hence we can just use:
 *
 * makeIVal(someFloat)
 *
 * instead of:
 *
 * Traits<float>::doMakeIVal(someFloat)
 *
 * The primary template is only declared here; each supported type provides
 * its own specialization.
 *
 *//*--------------------------------------------------------------------*/

template <typename T>
struct Traits;
328 
329 //! Create container from elementwise singleton values.
330 template <typename T>
makeIVal(const T & value)331 typename Traits<T>::IVal makeIVal(const T &value)
332 {
333     return Traits<T>::doMakeIVal(value);
334 }
335 
336 //! Elementwise union of intervals.
337 template <typename T>
unionIVal(const typename Traits<T>::IVal & a,const typename Traits<T>::IVal & b)338 typename Traits<T>::IVal unionIVal(const typename Traits<T>::IVal &a, const typename Traits<T>::IVal &b)
339 {
340     return Traits<T>::doUnion(a, b);
341 }
342 
343 //! Returns true iff every element of `ival` contains the corresponding element of `value`.
344 template <typename T, typename U = Void>
contains(const typename Traits<T>::IVal & ival,const T & value,bool is16Bit=false,const tcu::Maybe<U> & modularDivisor=tcu::Nothing)345 bool contains(const typename Traits<T>::IVal &ival, const T &value, bool is16Bit = false,
346               const tcu::Maybe<U> &modularDivisor = tcu::Nothing)
347 {
348     return Traits<T>::doContains(ival, value, is16Bit, modularDivisor);
349 }
350 
351 //! Print out an interval with the precision of `fmt`.
352 template <typename T>
printIVal(const FloatFormat & fmt,const typename Traits<T>::IVal & ival,ostream & os)353 void printIVal(const FloatFormat &fmt, const typename Traits<T>::IVal &ival, ostream &os)
354 {
355     Traits<T>::doPrintIVal(fmt, ival, os);
356 }
357 
358 template <typename T>
intervalToString(const FloatFormat & fmt,const typename Traits<T>::IVal & ival)359 string intervalToString(const FloatFormat &fmt, const typename Traits<T>::IVal &ival)
360 {
361     ostringstream oss;
362     printIVal<T>(fmt, ival, oss);
363     return oss.str();
364 }
365 
366 //! Print out a value with the precision of `fmt`.
367 template <typename T>
printValue16(const FloatFormat & fmt,const T & value,ostream & os)368 void printValue16(const FloatFormat &fmt, const T &value, ostream &os)
369 {
370     Traits<T>::doPrintValue16(fmt, value, os);
371 }
372 
373 template <typename T>
value16ToString(const FloatFormat & fmt,const T & val)374 string value16ToString(const FloatFormat &fmt, const T &val)
375 {
376     ostringstream oss;
377     printValue16(fmt, val, oss);
378     return oss.str();
379 }
380 
getComparisonOperation(const int ndx)381 const std::string getComparisonOperation(const int ndx)
382 {
383     const int operationCount = 10;
384     DE_ASSERT(de::inBounds(ndx, 0, operationCount));
385     const std::string operations[operationCount] = {
386         "OpFOrdEqual\t\t\t",        "OpFOrdGreaterThan\t",  "OpFOrdLessThan\t\t",    "OpFOrdGreaterThanEqual",
387         "OpFOrdLessThanEqual\t",    "OpFUnordEqual\t\t",    "OpFUnordGreaterThan\t", "OpFUnordLessThan\t",
388         "OpFUnordGreaterThanEqual", "OpFUnordLessThanEqual"};
389     return operations[ndx];
390 }
391 
392 template <typename T>
comparisonMessage(const T & val)393 string comparisonMessage(const T &val)
394 {
395     DE_UNREF(val);
396     return "";
397 }
398 
399 template <>
comparisonMessage(const int & val)400 string comparisonMessage(const int &val)
401 {
402     ostringstream oss;
403 
404     int flags = val;
405     for (int ndx = 0; ndx < 10; ++ndx)
406     {
407         oss << getComparisonOperation(ndx) << "\t:\t" << ((flags & 1) == 1 ? "TRUE" : "FALSE") << "\n";
408         flags = flags >> 1;
409     }
410     return oss.str();
411 }
412 
413 template <>
comparisonMessage(const tcu::IVec2 & val)414 string comparisonMessage(const tcu::IVec2 &val)
415 {
416     ostringstream oss;
417     tcu::IVec2 flags = val;
418     for (int ndx = 0; ndx < 10; ++ndx)
419     {
420         oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
421             << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
422         flags.x() = flags.x() >> 1;
423         flags.y() = flags.y() >> 1;
424     }
425     return oss.str();
426 }
427 
428 template <>
comparisonMessage(const tcu::IVec3 & val)429 string comparisonMessage(const tcu::IVec3 &val)
430 {
431     ostringstream oss;
432     tcu::IVec3 flags = val;
433     for (int ndx = 0; ndx < 10; ++ndx)
434     {
435         oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
436             << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
437         flags.x() = flags.x() >> 1;
438         flags.y() = flags.y() >> 1;
439         flags.z() = flags.z() >> 1;
440     }
441     return oss.str();
442 }
443 
444 template <>
comparisonMessage(const tcu::IVec4 & val)445 string comparisonMessage(const tcu::IVec4 &val)
446 {
447     ostringstream oss;
448     tcu::IVec4 flags = val;
449     for (int ndx = 0; ndx < 10; ++ndx)
450     {
451         oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
452             << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
453             << ((flags.w() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
454         flags.x() = flags.x() >> 1;
455         flags.y() = flags.y() >> 1;
456         flags.z() = flags.z() >> 1;
457         flags.w() = flags.z() >> 1;
458     }
459     return oss.str();
460 }
461 //! Print out a value with the precision of `fmt`.
462 template <typename T>
printValue32(const FloatFormat & fmt,const T & value,ostream & os)463 void printValue32(const FloatFormat &fmt, const T &value, ostream &os)
464 {
465     Traits<T>::doPrintValue32(fmt, value, os);
466 }
467 
468 template <typename T>
value32ToString(const FloatFormat & fmt,const T & val)469 string value32ToString(const FloatFormat &fmt, const T &val)
470 {
471     ostringstream oss;
472     printValue32(fmt, val, oss);
473     return oss.str();
474 }
475 
476 template <typename T>
printValue64(const FloatFormat & fmt,const T & value,ostream & os)477 void printValue64(const FloatFormat &fmt, const T &value, ostream &os)
478 {
479     Traits<T>::doPrintValue64(fmt, value, os);
480 }
481 
482 template <typename T>
value64ToString(const FloatFormat & fmt,const T & val)483 string value64ToString(const FloatFormat &fmt, const T &val)
484 {
485     ostringstream oss;
486     printValue64(fmt, val, oss);
487     return oss.str();
488 }
489 
490 //! Approximate `value` elementwise to the float precision defined in `fmt`.
491 //! The resulting interval might not be a singleton if rounding in both
492 //! directions is allowed.
493 template <typename T>
round(const FloatFormat & fmt,const T & value)494 typename Traits<T>::IVal round(const FloatFormat &fmt, const T &value)
495 {
496     return Traits<T>::doRound(fmt, value);
497 }
498 
499 template <typename T>
convert(const FloatFormat & fmt,const typename Traits<T>::IVal & value)500 typename Traits<T>::IVal convert(const FloatFormat &fmt, const typename Traits<T>::IVal &value)
501 {
502     return Traits<T>::doConvert(fmt, value);
503 }
504 
505 // Matching input and output types. We may be in a modulo case and modularDivisor may have an actual value.
506 template <typename T>
intervalContains(const Interval & interval,T value,const tcu::Maybe<T> & modularDivisor)507 bool intervalContains(const Interval &interval, T value, const tcu::Maybe<T> &modularDivisor)
508 {
509     bool contained = interval.contains(value);
510 
511     if (!contained && modularDivisor)
512     {
513         const T divisor = modularDivisor.get();
514 
515         // In a modulo operation, if the calculated answer contains the divisor, allow exactly 0.0 as a replacement. Alternatively,
516         // if the calculated answer contains 0.0, allow exactly the divisor as a replacement.
517         if (interval.contains(static_cast<double>(divisor)))
518             contained |= (value == 0.0);
519         if (interval.contains(0.0))
520             contained |= (value == divisor);
521     }
522     return contained;
523 }
524 
525 // When the input and output types do not match, we are not in a real modulo operation. Do not take the divisor into account. This
526 // version is provided for syntactical compatibility only.
527 template <typename T, typename U>
intervalContains(const Interval & interval,T value,const tcu::Maybe<U> & modularDivisor)528 bool intervalContains(const Interval &interval, T value, const tcu::Maybe<U> &modularDivisor)
529 {
530     DE_UNREF(modularDivisor); // For release builds.
531     DE_ASSERT(!modularDivisor);
532     return interval.contains(value);
533 }
534 
535 //! Common traits for scalar types.
536 template <typename T>
537 struct ScalarTraits
538 {
539     typedef Interval IVal;
540 
doMakeIValvkt::shaderexecutor::ScalarTraits541     static Interval doMakeIVal(const T &value)
542     {
543         // Thankfully all scalar types have a well-defined conversion to `double`,
544         // hence Interval can represent their ranges without problems.
545         return Interval(double(value));
546     }
547 
doUnionvkt::shaderexecutor::ScalarTraits548     static Interval doUnion(const Interval &a, const Interval &b)
549     {
550         return a | b;
551     }
552 
doContainsvkt::shaderexecutor::ScalarTraits553     static bool doContains(const Interval &a, T value)
554     {
555         return a.contains(double(value));
556     }
557 
doConvertvkt::shaderexecutor::ScalarTraits558     static Interval doConvert(const FloatFormat &fmt, const IVal &ival)
559     {
560         return fmt.convert(ival);
561     }
562 
doConvertvkt::shaderexecutor::ScalarTraits563     static Interval doConvert(const FloatFormat &fmt, const IVal &ival, bool is16Bit)
564     {
565         DE_UNREF(is16Bit);
566         return fmt.convert(ival);
567     }
568 
doRoundvkt::shaderexecutor::ScalarTraits569     static Interval doRound(const FloatFormat &fmt, T value)
570     {
571         return fmt.roundOut(double(value), false);
572     }
573 };
574 
575 template <>
576 struct ScalarTraits<uint16_t>
577 {
578     typedef Interval IVal;
579 
doMakeIValvkt::shaderexecutor::ScalarTraits580     static Interval doMakeIVal(const uint16_t &value)
581     {
582         // Thankfully all scalar types have a well-defined conversion to `double`,
583         // hence Interval can represent their ranges without problems.
584         return Interval(double(deFloat16To32(value)));
585     }
586 
doUnionvkt::shaderexecutor::ScalarTraits587     static Interval doUnion(const Interval &a, const Interval &b)
588     {
589         return a | b;
590     }
591 
doConvertvkt::shaderexecutor::ScalarTraits592     static Interval doConvert(const FloatFormat &fmt, const IVal &ival)
593     {
594         return fmt.convert(ival);
595     }
596 
doRoundvkt::shaderexecutor::ScalarTraits597     static Interval doRound(const FloatFormat &fmt, uint16_t value)
598     {
599         return fmt.roundOut(double(deFloat16To32(value)), false);
600     }
601 };
602 
603 template <>
604 struct Traits<float> : ScalarTraits<float>
605 {
doPrintIValvkt::shaderexecutor::Traits606     static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
607     {
608         os << fmt.intervalToHex(ival);
609     }
610 
doPrintValue16vkt::shaderexecutor::Traits611     static void doPrintValue16(const FloatFormat &fmt, const float &value, ostream &os)
612     {
613         const uint32_t iRep = reinterpret_cast<const uint32_t &>(value);
614         float res0          = deFloat16To32((deFloat16)(iRep & 0xFFFF));
615         float res1          = deFloat16To32((deFloat16)(iRep >> 16));
616         os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
617     }
618 
doPrintValue32vkt::shaderexecutor::Traits619     static void doPrintValue32(const FloatFormat &fmt, const float &value, ostream &os)
620     {
621         os << fmt.floatToHex(value);
622     }
623 
doPrintValue64vkt::shaderexecutor::Traits624     static void doPrintValue64(const FloatFormat &fmt, const float &value, ostream &os)
625     {
626         os << fmt.floatToHex(value);
627     }
628 
629     template <typename U>
doContainsvkt::shaderexecutor::Traits630     static bool doContains(const Interval &a, const float &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
631     {
632         if (is16Bit)
633         {
634             // Note: for deFloat16s packed in 32 bits, the original divisor is provided as a float to the shader in the input
635             // buffer, so U is also float here and we call the right interlvalContains() version.
636             const uint32_t iRep = reinterpret_cast<const uint32_t &>(value);
637             float res0          = deFloat16To32((deFloat16)(iRep & 0xFFFF));
638             float res1          = deFloat16To32((deFloat16)(iRep >> 16));
639             return intervalContains(a, res0, modularDivisor) && (res1 == -1.0);
640         }
641         return intervalContains(a, value, modularDivisor);
642     }
643 };
644 
645 template <>
646 struct Traits<double> : ScalarTraits<double>
647 {
doPrintIValvkt::shaderexecutor::Traits648     static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
649     {
650         os << fmt.intervalToHex(ival);
651     }
652 
doPrintValue16vkt::shaderexecutor::Traits653     static void doPrintValue16(const FloatFormat &fmt, const double &value, ostream &os)
654     {
655         const uint64_t iRep = reinterpret_cast<const uint64_t &>(value);
656         double byte0        = deFloat16To64((deFloat16)((iRep)&0xffff));
657         double byte1        = deFloat16To64((deFloat16)((iRep >> 16) & 0xffff));
658         double byte2        = deFloat16To64((deFloat16)((iRep >> 32) & 0xffff));
659         double byte3        = deFloat16To64((deFloat16)((iRep >> 48) & 0xffff));
660         os << fmt.floatToHex(byte0) << " " << fmt.floatToHex(byte1) << " " << fmt.floatToHex(byte2) << " "
661            << fmt.floatToHex(byte3);
662     }
663 
doPrintValue32vkt::shaderexecutor::Traits664     static void doPrintValue32(const FloatFormat &fmt, const double &value, ostream &os)
665     {
666         const uint64_t iRep = reinterpret_cast<const uint64_t &>(value);
667         double res0         = static_cast<double>((float)((iRep)&0xffffffff));
668         double res1         = static_cast<double>((float)((iRep >> 32) & 0xffffffff));
669         os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
670     }
671 
doPrintValue64vkt::shaderexecutor::Traits672     static void doPrintValue64(const FloatFormat &fmt, const double &value, ostream &os)
673     {
674         os << fmt.floatToHex(value);
675     }
676 
677     template <class U>
doContainsvkt::shaderexecutor::Traits678     static bool doContains(const Interval &a, const double &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
679     {
680         DE_UNREF(is16Bit);
681         DE_ASSERT(!is16Bit);
682         return intervalContains(a, value, modularDivisor);
683     }
684 };
685 
686 template <>
687 struct Traits<deFloat16> : ScalarTraits<deFloat16>
688 {
doPrintIValvkt::shaderexecutor::Traits689     static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
690     {
691         os << fmt.intervalToHex(ival);
692     }
693 
doPrintValue16vkt::shaderexecutor::Traits694     static void doPrintValue16(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
695     {
696         const float res0 = deFloat16To32(value);
697         os << fmt.floatToHex(static_cast<double>(res0));
698     }
doPrintValue32vkt::shaderexecutor::Traits699     static void doPrintValue32(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
700     {
701         const float res0 = deFloat16To32(value);
702         os << fmt.floatToHex(static_cast<double>(res0));
703     }
704 
doPrintValue64vkt::shaderexecutor::Traits705     static void doPrintValue64(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
706     {
707         const double res0 = deFloat16To64(value);
708         os << fmt.floatToHex(res0);
709     }
710 
711     // When the value and divisor are both deFloat16, convert both to float to call the right intervalContains version.
doContainsvkt::shaderexecutor::Traits712     static bool doContains(const Interval &a, const deFloat16 &value, bool is16Bit,
713                            const tcu::Maybe<deFloat16> &modularDivisor)
714     {
715         DE_UNREF(is16Bit);
716         float res0 = deFloat16To32(value);
717         const tcu::Maybe<float> convertedDivisor =
718             (modularDivisor ? tcu::just(deFloat16To32(modularDivisor.get())) : tcu::Nothing);
719         return intervalContains(a, res0, convertedDivisor);
720     }
721 
722     // If the types don't match we should not be in a modulo operation, so no conversion should take place.
723     template <class U>
doContainsvkt::shaderexecutor::Traits724     static bool doContains(const Interval &a, const deFloat16 &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
725     {
726         DE_UNREF(is16Bit);
727         float res0 = deFloat16To32(value);
728         return intervalContains(a, res0, modularDivisor);
729     }
730 };
731 
732 template <>
733 struct Traits<bool> : ScalarTraits<bool>
734 {
doPrintValue16vkt::shaderexecutor::Traits735     static void doPrintValue16(const FloatFormat &, const float &value, ostream &os)
736     {
737         os << (value != 0.0f ? "true" : "false");
738     }
739 
doPrintValue32vkt::shaderexecutor::Traits740     static void doPrintValue32(const FloatFormat &, const float &value, ostream &os)
741     {
742         os << (value != 0.0f ? "true" : "false");
743     }
744 
doPrintValue64vkt::shaderexecutor::Traits745     static void doPrintValue64(const FloatFormat &, const float &value, ostream &os)
746     {
747         os << (value != 0.0f ? "true" : "false");
748     }
749 
doPrintIValvkt::shaderexecutor::Traits750     static void doPrintIVal(const FloatFormat &, const Interval &ival, ostream &os)
751     {
752         os << "{";
753         if (ival.contains(false))
754             os << "false";
755         if (ival.contains(false) && ival.contains(true))
756             os << ", ";
757         if (ival.contains(true))
758             os << "true";
759         os << "}";
760     }
761 };
762 
763 template <>
764 struct Traits<int> : ScalarTraits<int>
765 {
doPrintValue16vkt::shaderexecutor::Traits766     static void doPrintValue16(const FloatFormat &, const int &value, ostream &os)
767     {
768         int res0 = value & 0xFFFF;
769         int res1 = value >> 16;
770         os << res0 << " " << res1;
771     }
772 
doPrintValue32vkt::shaderexecutor::Traits773     static void doPrintValue32(const FloatFormat &, const int &value, ostream &os)
774     {
775         os << value;
776     }
777 
doPrintValue64vkt::shaderexecutor::Traits778     static void doPrintValue64(const FloatFormat &, const int &value, ostream &os)
779     {
780         os << value;
781     }
782 
doPrintIValvkt::shaderexecutor::Traits783     static void doPrintIVal(const FloatFormat &, const Interval &ival, ostream &os)
784     {
785         os << "[" << int(ival.lo()) << ", " << int(ival.hi()) << "]";
786     }
787 
788     template <typename U>
doContainsvkt::shaderexecutor::Traits789     static bool doContains(const Interval &a, const int &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
790     {
791         DE_UNREF(is16Bit);
792         return intervalContains(a, value, modularDivisor);
793     }
794 };
795 
796 //! Common traits for containers, i.e. vectors and matrices.
797 //! T is the container type itself, I is the same type with interval elements.
798 template <typename T, typename I>
799 struct ContainerTraits
800 {
801     typedef typename T::Element Element;
802     typedef I IVal;
803 
doMakeIValvkt::shaderexecutor::ContainerTraits804     static IVal doMakeIVal(const T &value)
805     {
806         IVal ret;
807 
808         for (int ndx = 0; ndx < T::SIZE; ++ndx)
809             ret[ndx] = makeIVal(value[ndx]);
810 
811         return ret;
812     }
813 
doUnionvkt::shaderexecutor::ContainerTraits814     static IVal doUnion(const IVal &a, const IVal &b)
815     {
816         IVal ret;
817 
818         for (int ndx = 0; ndx < T::SIZE; ++ndx)
819             ret[ndx] = unionIVal<Element>(a[ndx], b[ndx]);
820 
821         return ret;
822     }
823 
824     // When the input and output types match, we may be in a modulo operation. If the divisor is provided, use each of its
825     // components to determine if the obtained result is fine.
doContainsvkt::shaderexecutor::ContainerTraits826     static bool doContains(const IVal &ival, const T &value, bool is16Bit, const tcu::Maybe<T> &modularDivisor)
827     {
828         using DivisorElement = typename T::Element;
829 
830         for (int ndx = 0; ndx < T::SIZE; ++ndx)
831         {
832             const tcu::Maybe<DivisorElement> divisorElement =
833                 (modularDivisor ? tcu::just((*modularDivisor)[ndx]) : tcu::Nothing);
834             if (!contains(ival[ndx], value[ndx], is16Bit, divisorElement))
835                 return false;
836         }
837 
838         return true;
839     }
840 
841     // When the input and output types do not match we should not be in a modulo operation. This version is provided for syntactical
842     // compatibility.
843     template <typename U>
doContainsvkt::shaderexecutor::ContainerTraits844     static bool doContains(const IVal &ival, const T &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
845     {
846         for (int ndx = 0; ndx < T::SIZE; ++ndx)
847         {
848             if (!contains(ival[ndx], value[ndx], is16Bit, modularDivisor))
849                 return false;
850         }
851 
852         return true;
853     }
854 
doPrintIValvkt::shaderexecutor::ContainerTraits855     static void doPrintIVal(const FloatFormat &fmt, const IVal ival, ostream &os)
856     {
857         os << "(";
858 
859         for (int ndx = 0; ndx < T::SIZE; ++ndx)
860         {
861             if (ndx > 0)
862                 os << ", ";
863 
864             printIVal<Element>(fmt, ival[ndx], os);
865         }
866 
867         os << ")";
868     }
869 
doPrintValue16vkt::shaderexecutor::ContainerTraits870     static void doPrintValue16(const FloatFormat &fmt, const T &value, ostream &os)
871     {
872         os << dataTypeNameOf<T>() << "(";
873 
874         for (int ndx = 0; ndx < T::SIZE; ++ndx)
875         {
876             if (ndx > 0)
877                 os << ", ";
878 
879             printValue16<Element>(fmt, value[ndx], os);
880         }
881 
882         os << ")";
883     }
884 
doPrintValue32vkt::shaderexecutor::ContainerTraits885     static void doPrintValue32(const FloatFormat &fmt, const T &value, ostream &os)
886     {
887         os << dataTypeNameOf<T>() << "(";
888 
889         for (int ndx = 0; ndx < T::SIZE; ++ndx)
890         {
891             if (ndx > 0)
892                 os << ", ";
893 
894             printValue32<Element>(fmt, value[ndx], os);
895         }
896 
897         os << ")";
898     }
899 
doPrintValue64vkt::shaderexecutor::ContainerTraits900     static void doPrintValue64(const FloatFormat &fmt, const T &value, ostream &os)
901     {
902         os << dataTypeNameOf<T>() << "(";
903 
904         for (int ndx = 0; ndx < T::SIZE; ++ndx)
905         {
906             if (ndx > 0)
907                 os << ", ";
908 
909             printValue64<Element>(fmt, value[ndx], os);
910         }
911 
912         os << ")";
913     }
914 
doConvertvkt::shaderexecutor::ContainerTraits915     static IVal doConvert(const FloatFormat &fmt, const IVal &value)
916     {
917         IVal ret;
918 
919         for (int ndx = 0; ndx < T::SIZE; ++ndx)
920             ret[ndx] = convert<Element>(fmt, value[ndx]);
921 
922         return ret;
923     }
924 
doRoundvkt::shaderexecutor::ContainerTraits925     static IVal doRound(const FloatFormat &fmt, T value)
926     {
927         IVal ret;
928 
929         for (int ndx = 0; ndx < T::SIZE; ++ndx)
930             ret[ndx] = round(fmt, value[ndx]);
931 
932         return ret;
933     }
934 };
935 
//! Traits for vectors: all operations come from ContainerTraits and are applied
//! entry-wise. The interval type is a Vector of the element's interval type.
template <typename T, int Size>
struct Traits<Vector<T, Size>> : ContainerTraits<Vector<T, Size>, Vector<typename Traits<T>::IVal, Size>>
{
};

//! Traits for matrices: all operations come from ContainerTraits and are applied
//! over the T::SIZE entries reachable through Matrix::operator[]. The interval
//! type is a Matrix of the same shape over the element's interval type.
template <typename T, int Rows, int Cols>
struct Traits<Matrix<T, Rows, Cols>>
    : ContainerTraits<Matrix<T, Rows, Cols>, Matrix<typename Traits<T>::IVal, Rows, Cols>>
{
};
946 
947 //! Void traits. These are just dummies, but technically valid: a Void is a
948 //! unit type with a single possible value.
949 template <>
950 struct Traits<Void>
951 {
952     typedef Void IVal;
953 
doMakeIValvkt::shaderexecutor::Traits954     static Void doMakeIVal(const Void &value)
955     {
956         return value;
957     }
doUnionvkt::shaderexecutor::Traits958     static Void doUnion(const Void &, const Void &)
959     {
960         return Void();
961     }
doContainsvkt::shaderexecutor::Traits962     static bool doContains(const Void &, Void)
963     {
964         return true;
965     }
966     template <typename U>
doContainsvkt::shaderexecutor::Traits967     static bool doContains(const Void &, const Void &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
968     {
969         DE_UNREF(value);
970         DE_UNREF(is16Bit);
971         DE_UNREF(modularDivisor);
972         return true;
973     }
doRoundvkt::shaderexecutor::Traits974     static Void doRound(const FloatFormat &, const Void &value)
975     {
976         return value;
977     }
doConvertvkt::shaderexecutor::Traits978     static Void doConvert(const FloatFormat &, const Void &value)
979     {
980         return value;
981     }
982 
doPrintValue16vkt::shaderexecutor::Traits983     static void doPrintValue16(const FloatFormat &, const Void &, ostream &os)
984     {
985         os << "()";
986     }
987 
doPrintValue32vkt::shaderexecutor::Traits988     static void doPrintValue32(const FloatFormat &, const Void &, ostream &os)
989     {
990         os << "()";
991     }
992 
doPrintValue64vkt::shaderexecutor::Traits993     static void doPrintValue64(const FloatFormat &, const Void &, ostream &os)
994     {
995         os << "()";
996     }
997 
doPrintIValvkt::shaderexecutor::Traits998     static void doPrintIVal(const FloatFormat &, const Void &, ostream &os)
999     {
1000         os << "()";
1001     }
1002 };
1003 
1004 //! This is needed for container-generic operations.
1005 //! We want a scalar type T to be its own "one-element vector".
1006 template <typename T, int Size>
1007 struct ContainerOf
1008 {
1009     typedef Vector<T, Size> Container;
1010 };
1011 
1012 template <typename T>
1013 struct ContainerOf<T, 1>
1014 {
1015     typedef T Container;
1016 };
1017 template <int Size>
1018 struct ContainerOf<Void, Size>
1019 {
1020     typedef Void Container;
1021 };
1022 
// This is a kludge that is only needed to get the ExprP::operator[] syntactic sugar to work.
// ElementOf<T>::Element forwards T::Element for container types. For the scalar
// types listed below it is void, so naming it does not fail for scalars.
template <typename T>
struct ElementOf
{
    using Element = typename T::Element;
};

template <>
struct ElementOf<float>
{
    using Element = void;
};

template <>
struct ElementOf<double>
{
    using Element = void;
};

template <>
struct ElementOf<bool>
{
    using Element = void;
};

template <>
struct ElementOf<int>
{
    using Element = void;
};
1049 
1050 template <typename T>
comparisonMessageInterval(const typename Traits<T>::IVal & val)1051 string comparisonMessageInterval(const typename Traits<T>::IVal &val)
1052 {
1053     DE_UNREF(val);
1054     return "";
1055 }
1056 
1057 template <>
comparisonMessageInterval(const Traits<int>::IVal & val)1058 string comparisonMessageInterval<int>(const Traits<int>::IVal &val)
1059 {
1060     return comparisonMessage(static_cast<int>(val.lo()));
1061 }
1062 
1063 template <>
comparisonMessageInterval(const Traits<float>::IVal & val)1064 string comparisonMessageInterval<float>(const Traits<float>::IVal &val)
1065 {
1066     return comparisonMessage(static_cast<int>(val.lo()));
1067 }
1068 
1069 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,2> & val)1070 string comparisonMessageInterval<tcu::Vector<int, 2>>(const tcu::Vector<tcu::Interval, 2> &val)
1071 {
1072     tcu::IVec2 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()));
1073     return comparisonMessage(result);
1074 }
1075 
1076 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,3> & val)1077 string comparisonMessageInterval<tcu::Vector<int, 3>>(const tcu::Vector<tcu::Interval, 3> &val)
1078 {
1079     tcu::IVec3 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()));
1080     return comparisonMessage(result);
1081 }
1082 
1083 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,4> & val)1084 string comparisonMessageInterval<tcu::Vector<int, 4>>(const tcu::Vector<tcu::Interval, 4> &val)
1085 {
1086     tcu::IVec4 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()),
1087                       static_cast<int>(val[3].lo()));
1088     return comparisonMessage(result);
1089 }
1090 
/*--------------------------------------------------------------------*//*!
 *
 * \name Abstract syntax for expressions and statements.
 *
 * We represent GLSL programs as syntax objects: an Expr<T> represents an
 * expression whose GLSL type corresponds to the C++ type T, and a Statement
 * represents a statement.
 *
 * To ease memory management, we use shared pointers to refer to expressions
 * and statements. ExprP<T> is a shared pointer to an Expr<T>, and StatementP
 * is a shared pointer to a Statement.
 *
 * \{
 *
 *//*--------------------------------------------------------------------*/

// Forward declarations: the classes below refer to each other, so they are
// declared up front and defined further down in this file.
class ExprBase;
class ExpandContext;
class Statement;
class StatementP;
class FuncBase;
template <typename T>
class ExprP;
template <typename T>
class Variable;
template <typename T>
class VariableP;
template <typename T>
class DefaultSampling;

//! A set of functions, used to collect the functions referenced by statements and expressions.
typedef set<const FuncBase *> FuncSet;

// Factory functions used before their definitions (e.g. by ExpandContext::genSym).
template <typename T>
VariableP<T> variable(const string &name);
StatementP compoundStatement(const vector<StatementP> &statements);
1126 
1127 /*--------------------------------------------------------------------*//*!
1128  * \brief A variable environment.
1129  *
1130  * An Environment object maintains the mapping between variables of the
1131  * abstract syntax tree and their values.
1132  *
1133  * \todo [2014-03-28 lauri] At least run-time type safety.
1134  *
1135  *//*--------------------------------------------------------------------*/
1136 class Environment
1137 {
1138 public:
1139     template <typename T>
bind(const Variable<T> & variable,const typename Traits<T>::IVal & value)1140     void bind(const Variable<T> &variable, const typename Traits<T>::IVal &value)
1141     {
1142         uint8_t *const data = new uint8_t[sizeof(value)];
1143 
1144         deMemcpy(data, &value, sizeof(value));
1145         de::insert(m_map, variable.getName(), SharedPtr<uint8_t>(data, de::ArrayDeleter<uint8_t>()));
1146     }
1147 
1148     template <typename T>
lookup(const Variable<T> & variable) const1149     typename Traits<T>::IVal &lookup(const Variable<T> &variable) const
1150     {
1151         uint8_t *const data = de::lookup(m_map, variable.getName()).get();
1152 
1153         return *reinterpret_cast<typename Traits<T>::IVal *>(data);
1154     }
1155 
1156 private:
1157     map<string, SharedPtr<uint8_t>> m_map;
1158 };
1159 
1160 /*--------------------------------------------------------------------*//*!
1161  * \brief Evaluation context.
1162  *
1163  * The evaluation context contains everything that separates one execution of
1164  * an expression from the next. Currently this means the desired floating
1165  * point precision and the current variable environment.
1166  *
1167  *//*--------------------------------------------------------------------*/
1168 struct EvalContext
1169 {
EvalContextvkt::shaderexecutor::EvalContext1170     EvalContext(const FloatFormat &format_, Precision floatPrecision_, Environment &env_, int callDepth_)
1171         : format(format_)
1172         , floatPrecision(floatPrecision_)
1173         , env(env_)
1174         , callDepth(callDepth_)
1175     {
1176     }
1177 
1178     FloatFormat format;
1179     Precision floatPrecision;
1180     Environment &env;
1181     int callDepth;
1182 };
1183 
/*--------------------------------------------------------------------*//*!
 * \brief Simple incremental counter.
 *
 * Each call returns the current count and then advances it by one. It is used
 * to make sure that different ExpandContexts will not produce overlapping
 * temporary names.
 *
 *//*--------------------------------------------------------------------*/
class Counter
{
public:
    Counter(int count = 0) : m_count(count)
    {
    }

    //! Return the next value (post-increment semantics).
    int operator()(void)
    {
        const int current = m_count;
        m_count += 1;
        return current;
    }

private:
    int m_count;
};
1205 
1206 class ExpandContext
1207 {
1208 public:
ExpandContext(Counter & symCounter)1209     ExpandContext(Counter &symCounter) : m_symCounter(symCounter)
1210     {
1211     }
ExpandContext(const ExpandContext & parent)1212     ExpandContext(const ExpandContext &parent) : m_symCounter(parent.m_symCounter)
1213     {
1214     }
1215 
1216     template <typename T>
genSym(const string & baseName)1217     VariableP<T> genSym(const string &baseName)
1218     {
1219         return variable<T>(baseName + de::toString(m_symCounter()));
1220     }
1221 
addStatement(const StatementP & stmt)1222     void addStatement(const StatementP &stmt)
1223     {
1224         m_statements.push_back(stmt);
1225     }
1226 
getStatements(void) const1227     vector<StatementP> getStatements(void) const
1228     {
1229         return m_statements;
1230     }
1231 
1232 private:
1233     Counter &m_symCounter;
1234     vector<StatementP> m_statements;
1235 };
1236 
1237 /*--------------------------------------------------------------------*//*!
1238  * \brief A statement or declaration.
1239  *
1240  * Statements have no values. Instead, they are executed for their side
1241  * effects only: the execute() method should modify at least one variable in
1242  * the environment.
1243  *
1244  * As a bit of a kludge, a Statement object can also represent a declaration:
1245  * when it is evaluated, it can add a variable binding to the environment
1246  * instead of modifying a current one.
1247  *
1248  *//*--------------------------------------------------------------------*/
1249 class Statement
1250 {
1251 public:
~Statement(void)1252     virtual ~Statement(void)
1253     {
1254     }
1255     //! Execute the statement, modifying the environment of `ctx`
execute(EvalContext & ctx) const1256     void execute(EvalContext &ctx) const
1257     {
1258         this->doExecute(ctx);
1259     }
print(ostream & os) const1260     void print(ostream &os) const
1261     {
1262         this->doPrint(os);
1263     }
1264     //! Add the functions used in this statement to `dst`.
getUsedFuncs(FuncSet & dst) const1265     void getUsedFuncs(FuncSet &dst) const
1266     {
1267         this->doGetUsedFuncs(dst);
1268     }
failed(EvalContext & ctx) const1269     void failed(EvalContext &ctx) const
1270     {
1271         this->doFail(ctx);
1272     }
1273 
1274 protected:
1275     virtual void doPrint(ostream &os) const         = 0;
1276     virtual void doExecute(EvalContext &ctx) const  = 0;
1277     virtual void doGetUsedFuncs(FuncSet &dst) const = 0;
doFail(EvalContext & ctx) const1278     virtual void doFail(EvalContext &ctx) const
1279     {
1280         DE_UNREF(ctx);
1281     }
1282 };
1283 
operator <<(ostream & os,const Statement & stmt)1284 ostream &operator<<(ostream &os, const Statement &stmt)
1285 {
1286     stmt.print(os);
1287     return os;
1288 }
1289 
1290 /*--------------------------------------------------------------------*//*!
1291  * \brief Smart pointer for statements (and declarations)
1292  *
1293  *//*--------------------------------------------------------------------*/
1294 class StatementP : public SharedPtr<const Statement>
1295 {
1296 public:
1297     typedef SharedPtr<const Statement> Super;
1298 
StatementP(void)1299     StatementP(void)
1300     {
1301     }
StatementP(const Statement * ptr)1302     explicit StatementP(const Statement *ptr) : Super(ptr)
1303     {
1304     }
StatementP(const Super & ptr)1305     StatementP(const Super &ptr) : Super(ptr)
1306     {
1307     }
1308 };
1309 
1310 /*--------------------------------------------------------------------*//*!
1311  * \brief
1312  *
1313  * A statement that modifies a variable or a declaration that binds a variable.
1314  *
1315  *//*--------------------------------------------------------------------*/
1316 template <typename T>
1317 class VariableStatement : public Statement
1318 {
1319 public:
VariableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1320     VariableStatement(const VariableP<T> &variable, const ExprP<T> &value, bool isDeclaration)
1321         : m_variable(variable)
1322         , m_value(value)
1323         , m_isDeclaration(isDeclaration)
1324     {
1325     }
1326 
1327 protected:
doPrint(ostream & os) const1328     void doPrint(ostream &os) const
1329     {
1330         if (m_isDeclaration)
1331             os << glu::declare(getVarTypeOf<T>(), m_variable->getName());
1332         else
1333             os << m_variable->getName();
1334 
1335         os << " = ";
1336         os << *m_value << ";\n";
1337     }
1338 
doExecute(EvalContext & ctx) const1339     void doExecute(EvalContext &ctx) const
1340     {
1341         if (m_isDeclaration)
1342             ctx.env.bind(*m_variable, m_value->evaluate(ctx));
1343         else
1344             ctx.env.lookup(*m_variable) = m_value->evaluate(ctx);
1345     }
1346 
doGetUsedFuncs(FuncSet & dst) const1347     void doGetUsedFuncs(FuncSet &dst) const
1348     {
1349         m_value->getUsedFuncs(dst);
1350     }
1351 
doFail(EvalContext & ctx) const1352     virtual void doFail(EvalContext &ctx) const
1353     {
1354         if (m_isDeclaration)
1355             ctx.env.bind(*m_variable, m_value->fails(ctx));
1356         else
1357             ctx.env.lookup(*m_variable) = m_value->fails(ctx);
1358     }
1359 
1360     VariableP<T> m_variable;
1361     ExprP<T> m_value;
1362     bool m_isDeclaration;
1363 };
1364 
1365 template <typename T>
variableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1366 StatementP variableStatement(const VariableP<T> &variable, const ExprP<T> &value, bool isDeclaration)
1367 {
1368     return StatementP(new VariableStatement<T>(variable, value, isDeclaration));
1369 }
1370 
1371 template <typename T>
variableDeclaration(const VariableP<T> & variable,const ExprP<T> & definiens)1372 StatementP variableDeclaration(const VariableP<T> &variable, const ExprP<T> &definiens)
1373 {
1374     return variableStatement(variable, definiens, true);
1375 }
1376 
1377 template <typename T>
variableAssignment(const VariableP<T> & variable,const ExprP<T> & value)1378 StatementP variableAssignment(const VariableP<T> &variable, const ExprP<T> &value)
1379 {
1380     return variableStatement(variable, value, false);
1381 }
1382 
1383 /*--------------------------------------------------------------------*//*!
1384  * \brief A compound statement, i.e. a block.
1385  *
1386  * A compound statement is executed by executing its constituent statements in
1387  * sequence.
1388  *
1389  *//*--------------------------------------------------------------------*/
1390 class CompoundStatement : public Statement
1391 {
1392 public:
CompoundStatement(const vector<StatementP> & statements)1393     CompoundStatement(const vector<StatementP> &statements) : m_statements(statements)
1394     {
1395     }
1396 
1397 protected:
doPrint(ostream & os) const1398     void doPrint(ostream &os) const
1399     {
1400         os << "{\n";
1401 
1402         for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1403             os << *m_statements[ndx];
1404 
1405         os << "}\n";
1406     }
1407 
doExecute(EvalContext & ctx) const1408     void doExecute(EvalContext &ctx) const
1409     {
1410         for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1411             m_statements[ndx]->execute(ctx);
1412     }
1413 
doGetUsedFuncs(FuncSet & dst) const1414     void doGetUsedFuncs(FuncSet &dst) const
1415     {
1416         for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1417             m_statements[ndx]->getUsedFuncs(dst);
1418     }
1419 
1420     vector<StatementP> m_statements;
1421 };
1422 
compoundStatement(const vector<StatementP> & statements)1423 StatementP compoundStatement(const vector<StatementP> &statements)
1424 {
1425     return StatementP(new CompoundStatement(statements));
1426 }
1427 
1428 //! Common base class for all expressions regardless of their type.
1429 class ExprBase
1430 {
1431 public:
~ExprBase(void)1432     virtual ~ExprBase(void)
1433     {
1434     }
printExpr(ostream & os) const1435     void printExpr(ostream &os) const
1436     {
1437         this->doPrintExpr(os);
1438     }
1439 
1440     //! Output the functions that this expression refers to
getUsedFuncs(FuncSet & dst) const1441     void getUsedFuncs(FuncSet &dst) const
1442     {
1443         this->doGetUsedFuncs(dst);
1444     }
1445 
1446 protected:
doPrintExpr(ostream &) const1447     virtual void doPrintExpr(ostream &) const
1448     {
1449     }
doGetUsedFuncs(FuncSet &) const1450     virtual void doGetUsedFuncs(FuncSet &) const
1451     {
1452     }
1453 };
1454 
1455 //! Type-specific operations for an expression representing type T.
1456 template <typename T>
1457 class Expr : public ExprBase
1458 {
1459 public:
1460     typedef T Val;
1461     typedef typename Traits<T>::IVal IVal;
1462 
1463     IVal evaluate(const EvalContext &ctx) const;
fails(const EvalContext & ctx) const1464     IVal fails(const EvalContext &ctx) const
1465     {
1466         return this->doFails(ctx);
1467     }
1468 
1469 protected:
1470     virtual IVal doEvaluate(const EvalContext &ctx) const = 0;
doFails(const EvalContext & ctx) const1471     virtual IVal doFails(const EvalContext &ctx) const
1472     {
1473         return doEvaluate(ctx);
1474     }
1475 };
1476 
1477 //! Evaluate an expression with the given context, optionally tracing the calls to stderr.
1478 template <typename T>
evaluate(const EvalContext & ctx) const1479 typename Traits<T>::IVal Expr<T>::evaluate(const EvalContext &ctx) const
1480 {
1481 #ifdef GLS_ENABLE_TRACE
1482     static const FloatFormat highpFmt(-126, 127, 23, true, tcu::MAYBE, tcu::YES, tcu::MAYBE);
1483     EvalContext newCtx(ctx.format, ctx.floatPrecision, ctx.env, ctx.callDepth + 1);
1484     const IVal ret = this->doEvaluate(newCtx);
1485 
1486     if (isTypeValid<T>())
1487     {
1488         std::cerr << string(ctx.callDepth, ' ');
1489         this->printExpr(std::cerr);
1490         std::cerr << " -> " << intervalToString<T>(highpFmt, ret) << std::endl;
1491     }
1492     return ret;
1493 #else
1494     return this->doEvaluate(ctx);
1495 #endif
1496 }
1497 
//! Shared pointer to a constant Expr<T>; the common base of every ExprP<T> variant.
template <typename T>
class ExprPBase : public SharedPtr<const Expr<T>>
{
public:
};
1503 
operator <<(ostream & os,const ExprBase & expr)1504 ostream &operator<<(ostream &os, const ExprBase &expr)
1505 {
1506     expr.printExpr(os);
1507     return os;
1508 }
1509 
/*--------------------------------------------------------------------*//*!
 * \brief Shared pointer to an expression of a container type.
 *
 * Container types (i.e. vectors and matrices) support the subscription
 * operator. This class provides a bit of syntactic sugar to allow us to use
 * the C++ subscription operator to create a subscription expression.
 *//*--------------------------------------------------------------------*/
template <typename T>
class ContainerExprPBase : public ExprPBase<T>
{
public:
    //! Build an expression selecting entry `i` of this container expression (defined elsewhere in this file).
    ExprP<typename T::Element> operator[](int i) const;
};

//! Shared pointer to an expression of type T. Scalar types get no subscription operator.
template <typename T>
class ExprP : public ExprPBase<T>
{
};

// We treat Voids as containers since the unused parameters in generalized
// vector functions are represented as Voids.
template <>
class ExprP<Void> : public ContainerExprPBase<Void>
{
};

//! Vector expressions support subscription.
template <typename T, int Size>
class ExprP<Vector<T, Size>> : public ContainerExprPBase<Vector<T, Size>>
{
};

//! Matrix expressions support subscription.
template <typename T, int Rows, int Cols>
class ExprP<Matrix<T, Rows, Cols>> : public ContainerExprPBase<Matrix<T, Rows, Cols>>
{
};
1545 
1546 template <typename T>
exprP(void)1547 ExprP<T> exprP(void)
1548 {
1549     return ExprP<T>();
1550 }
1551 
1552 template <typename T>
exprP(const SharedPtr<const Expr<T>> & ptr)1553 ExprP<T> exprP(const SharedPtr<const Expr<T>> &ptr)
1554 {
1555     ExprP<T> ret;
1556     static_cast<SharedPtr<const Expr<T>> &>(ret) = ptr;
1557     return ret;
1558 }
1559 
1560 template <typename T>
exprP(const Expr<T> * ptr)1561 ExprP<T> exprP(const Expr<T> *ptr)
1562 {
1563     return exprP(SharedPtr<const Expr<T>>(ptr));
1564 }
1565 
1566 /*--------------------------------------------------------------------*//*!
1567  * \brief A shared pointer to a variable expression.
1568  *
1569  * This is just a narrowing of ExprP for the operations that require a variable
1570  * instead of an arbitrary expression.
1571  *
1572  *//*--------------------------------------------------------------------*/
1573 template <typename T>
1574 class VariableP : public SharedPtr<const Variable<T>>
1575 {
1576 public:
1577     typedef SharedPtr<const Variable<T>> Super;
VariableP(const Variable<T> * ptr)1578     explicit VariableP(const Variable<T> *ptr) : Super(ptr)
1579     {
1580     }
VariableP(void)1581     VariableP(void)
1582     {
1583     }
VariableP(const Super & ptr)1584     VariableP(const Super &ptr) : Super(ptr)
1585     {
1586     }
1587 
operator ExprP<T>(void) const1588     operator ExprP<T>(void) const
1589     {
1590         SharedPtr<const Expr<T>> ptr = *this;
1591         return exprP(ptr);
1592     }
1593 };
1594 
/*--------------------------------------------------------------------*//*!
 * \name Syntactic sugar operators for expressions.
 *
 * @{
 *
 * These operators allow the use of C++ syntax to construct GLSL expressions
 * containing operators: e.g. "a+b" creates an addition expression with
 * operands a and b, and so on. Only the declarations are here; the
 * definitions appear elsewhere in this file.
 *
 *//*--------------------------------------------------------------------*/
// Scalar addition (per floating-point type).
ExprP<float> operator+(const ExprP<float> &arg0, const ExprP<float> &arg1);
ExprP<deFloat16> operator+(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1);
ExprP<double> operator+(const ExprP<double> &arg0, const ExprP<double> &arg1);
// Generic negation and subtraction.
template <typename T>
ExprP<T> operator-(const ExprP<T> &arg0);
template <typename T>
ExprP<T> operator-(const ExprP<T> &arg0, const ExprP<T> &arg1);
// Matrix-matrix product: (Left x Mid) * (Mid x Right) -> (Left x Right).
template <int Left, int Mid, int Right, typename T>
ExprP<Matrix<T, Left, Right>> operator*(const ExprP<Matrix<T, Left, Mid>> &left,
                                        const ExprP<Matrix<T, Mid, Right>> &right);
// Scalar multiplication (per floating-point type).
ExprP<float> operator*(const ExprP<float> &arg0, const ExprP<float> &arg1);
ExprP<deFloat16> operator*(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1);
ExprP<double> operator*(const ExprP<double> &arg0, const ExprP<double> &arg1);
// Generic division.
template <typename T>
ExprP<T> operator/(const ExprP<T> &arg0, const ExprP<T> &arg1);
// Vector negation and subtraction.
template <typename T, int Size>
ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0);
template <typename T, int Size>
ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1);
// Vector-scalar and vector-vector multiplication.
template <int Size, typename T>
ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1);
template <typename T, int Size>
ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1);
// Vector-matrix and matrix-vector products.
template <int Rows, int Cols, typename T>
ExprP<Vector<T, Rows>> operator*(const ExprP<Vector<T, Cols>> &left, const ExprP<Matrix<T, Rows, Cols>> &right);
template <int Rows, int Cols, typename T>
ExprP<Vector<T, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<Vector<T, Rows>> &right);
// Matrix-scalar product.
template <int Rows, int Cols, typename T>
ExprP<Matrix<T, Rows, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<T> &right);
// Matrix addition (per floating-point element type).
template <int Rows, int Cols>
ExprP<Matrix<float, Rows, Cols>> operator+(const ExprP<Matrix<float, Rows, Cols>> &left,
                                           const ExprP<Matrix<float, Rows, Cols>> &right);
template <int Rows, int Cols>
ExprP<Matrix<deFloat16, Rows, Cols>> operator+(const ExprP<Matrix<deFloat16, Rows, Cols>> &left,
                                               const ExprP<Matrix<deFloat16, Rows, Cols>> &right);
template <int Rows, int Cols>
ExprP<Matrix<double, Rows, Cols>> operator+(const ExprP<Matrix<double, Rows, Cols>> &left,
                                            const ExprP<Matrix<double, Rows, Cols>> &right);
// Matrix negation.
template <typename T, int Rows, int Cols>
ExprP<Matrix<T, Rows, Cols>> operator-(const ExprP<Matrix<T, Rows, Cols>> &mat);

//! @}
1647 
1648 /*--------------------------------------------------------------------*//*!
1649  * \brief Variable expression.
1650  *
1651  * A variable is evaluated by looking up its range of possible values from an
1652  * environment.
1653  *//*--------------------------------------------------------------------*/
1654 template <typename T>
1655 class Variable : public Expr<T>
1656 {
1657 public:
1658     typedef typename Expr<T>::IVal IVal;
1659 
Variable(const string & name)1660     Variable(const string &name) : m_name(name)
1661     {
1662     }
getName(void) const1663     string getName(void) const
1664     {
1665         return m_name;
1666     }
1667 
1668 protected:
doPrintExpr(ostream & os) const1669     void doPrintExpr(ostream &os) const
1670     {
1671         os << m_name;
1672     }
doEvaluate(const EvalContext & ctx) const1673     IVal doEvaluate(const EvalContext &ctx) const
1674     {
1675         return ctx.env.lookup<T>(*this);
1676     }
1677 
1678 private:
1679     string m_name;
1680 };
1681 
1682 template <typename T>
variable(const string & name)1683 VariableP<T> variable(const string &name)
1684 {
1685     return VariableP<T>(new Variable<T>(name));
1686 }
1687 
1688 template <typename T>
bindExpression(const string & name,ExpandContext & ctx,const ExprP<T> & expr)1689 VariableP<T> bindExpression(const string &name, ExpandContext &ctx, const ExprP<T> &expr)
1690 {
1691     VariableP<T> var = ctx.genSym<T>(name);
1692     ctx.addStatement(variableDeclaration(var, expr));
1693     return var;
1694 }
1695 
1696 /*--------------------------------------------------------------------*//*!
1697  * \brief Constant expression.
1698  *
1699  * A constant is evaluated by rounding it to a set of possible values allowed
1700  * by the current floating point precision.
1701  * TODO: For whatever reason this doesn't happen, the constant is converted to
1702  *       type T and the interval contains only (T)value. See FloatConstant, below.
1703  *//*--------------------------------------------------------------------*/
1704 template <typename T>
1705 class Constant : public Expr<T>
1706 {
1707 public:
1708     typedef typename Expr<T>::IVal IVal;
1709 
Constant(const T & value)1710     Constant(const T &value) : m_value(value)
1711     {
1712     }
1713 
1714 protected:
doPrintExpr(ostream & os) const1715     void doPrintExpr(ostream &os) const
1716     {
1717         os << m_value;
1718     }
doEvaluate(const EvalContext &) const1719     IVal doEvaluate(const EvalContext &) const
1720     {
1721         return makeIVal(m_value);
1722     }
1723 
1724 private:
1725     T m_value;
1726 };
1727 
1728 template <typename T>
constant(const T & value)1729 ExprP<T> constant(const T &value)
1730 {
1731     return exprP(new Constant<T>(value));
1732 }
1733 
1734 template <typename T>
1735 class FloatConstant : public Expr<T>
1736 {
1737 public:
1738     typedef typename Expr<T>::IVal IVal;
1739 
FloatConstant(double value)1740     FloatConstant(double value) : m_value(value)
1741     {
1742     }
1743 
1744 protected:
doPrintExpr(ostream & os) const1745     void doPrintExpr(ostream &os) const
1746     {
1747         os << m_value;
1748     }
1749     // TODO: This should probably roundOut to T, not ctx.format, but the templates don't work like that.
doEvaluate(const EvalContext & ctx) const1750     IVal doEvaluate(const EvalContext &ctx) const
1751     {
1752         return ctx.format.roundOut(makeIVal(m_value), true);
1753     }
1754 
1755 private:
1756     double m_value;
1757 };
1758 
f16Constant(double value)1759 ExprP<deFloat16> f16Constant(double value)
1760 {
1761     return exprP(new FloatConstant<deFloat16>(value));
1762 }
f32Constant(double value)1763 ExprP<float> f32Constant(double value)
1764 {
1765     return exprP(new FloatConstant<float>(value));
1766 }
1767 
1768 //! Return a reference to a singleton void constant.
voidP(void)1769 const ExprP<Void> &voidP(void)
1770 {
1771     static const ExprP<Void> singleton = constant(Void());
1772 
1773     return singleton;
1774 }
1775 
1776 /*--------------------------------------------------------------------*//*!
1777  * \brief Four-element tuple.
1778  *
1779  * This is used for various things where we need one thing for each possible
1780  * function parameter. Currently the maximum supported number of parameters is
1781  * four.
1782  *//*--------------------------------------------------------------------*/
1783 template <typename T0 = Void, typename T1 = Void, typename T2 = Void, typename T3 = Void>
1784 struct Tuple4
1785 {
Tuple4vkt::shaderexecutor::Tuple41786     explicit Tuple4(const T0 e0 = T0(), const T1 e1 = T1(), const T2 e2 = T2(), const T3 e3 = T3())
1787         : a(e0)
1788         , b(e1)
1789         , c(e2)
1790         , d(e3)
1791     {
1792     }
1793 
1794     T0 a;
1795     T1 b;
1796     T2 c;
1797     T3 d;
1798 };
1799 
1800 /*--------------------------------------------------------------------*//*!
1801  * \brief Function signature.
1802  *
1803  * This is a purely compile-time structure used to bundle all types in a
1804  * function signature together. This makes passing the signature around in
1805  * templates easier, since we only need to take and pass a single Sig instead
1806  * of a bunch of parameter types and a return type.
1807  *
1808  *//*--------------------------------------------------------------------*/
1809 template <typename R, typename P0 = Void, typename P1 = Void, typename P2 = Void, typename P3 = Void>
1810 struct Signature
1811 {
1812     typedef R Ret;
1813     typedef P0 Arg0;
1814     typedef P1 Arg1;
1815     typedef P2 Arg2;
1816     typedef P3 Arg3;
1817     typedef typename Traits<Ret>::IVal IRet;
1818     typedef typename Traits<Arg0>::IVal IArg0;
1819     typedef typename Traits<Arg1>::IVal IArg1;
1820     typedef typename Traits<Arg2>::IVal IArg2;
1821     typedef typename Traits<Arg3>::IVal IArg3;
1822 
1823     typedef Tuple4<const Arg0 &, const Arg1 &, const Arg2 &, const Arg3 &> Args;
1824     typedef Tuple4<const IArg0 &, const IArg1 &, const IArg2 &, const IArg3 &> IArgs;
1825     typedef Tuple4<ExprP<Arg0>, ExprP<Arg1>, ExprP<Arg2>, ExprP<Arg3>> ArgExprs;
1826 };
1827 
1828 typedef vector<const ExprBase *> BaseArgExprs;
1829 
1830 /*--------------------------------------------------------------------*//*!
1831  * \brief Type-independent operations for function objects.
1832  *
1833  *//*--------------------------------------------------------------------*/
1834 class FuncBase
1835 {
1836 public:
~FuncBase(void)1837     virtual ~FuncBase(void)
1838     {
1839     }
1840     virtual string getName(void) const = 0;
1841     //! Name of extension that this function requires, or empty.
getRequiredExtension(void) const1842     virtual string getRequiredExtension(void) const
1843     {
1844         return "";
1845     }
getInputRange(const bool is16bit) const1846     virtual Interval getInputRange(const bool is16bit) const
1847     {
1848         DE_UNREF(is16bit);
1849         return Interval(true, -TCU_INFINITY, TCU_INFINITY);
1850     }
1851     virtual void print(ostream &, const BaseArgExprs &) const = 0;
1852     //! Index of output parameter, or -1 if none of the parameters is output.
getOutParamIndex(void) const1853     virtual int getOutParamIndex(void) const
1854     {
1855         return -1;
1856     }
1857 
getSpirvCase(void) const1858     virtual SpirVCaseT getSpirvCase(void) const
1859     {
1860         return SPIRV_CASETYPE_NONE;
1861     }
1862 
printDefinition(ostream & os) const1863     void printDefinition(ostream &os) const
1864     {
1865         doPrintDefinition(os);
1866     }
1867 
getUsedFuncs(FuncSet & dst) const1868     void getUsedFuncs(FuncSet &dst) const
1869     {
1870         this->doGetUsedFuncs(dst);
1871     }
1872 
1873 protected:
1874     virtual void doPrintDefinition(ostream &os) const = 0;
1875     virtual void doGetUsedFuncs(FuncSet &dst) const   = 0;
1876 };
1877 
1878 typedef Tuple4<string, string, string, string> ParamNames;
1879 
1880 /*--------------------------------------------------------------------*//*!
1881  * \brief Function objects.
1882  *
1883  * Each Func object represents a GLSL function. It can be applied to interval
1884  * arguments, and it returns the an interval that is a conservative
1885  * approximation of the image of the GLSL function over the argument
1886  * intervals. That is, it is given a set of possible arguments and it returns
1887  * the set of possible values.
1888  *
1889  *//*--------------------------------------------------------------------*/
1890 template <typename Sig_>
1891 class Func : public FuncBase
1892 {
1893 public:
1894     typedef Sig_ Sig;
1895     typedef typename Sig::Ret Ret;
1896     typedef typename Sig::Arg0 Arg0;
1897     typedef typename Sig::Arg1 Arg1;
1898     typedef typename Sig::Arg2 Arg2;
1899     typedef typename Sig::Arg3 Arg3;
1900     typedef typename Sig::IRet IRet;
1901     typedef typename Sig::IArg0 IArg0;
1902     typedef typename Sig::IArg1 IArg1;
1903     typedef typename Sig::IArg2 IArg2;
1904     typedef typename Sig::IArg3 IArg3;
1905     typedef typename Sig::Args Args;
1906     typedef typename Sig::IArgs IArgs;
1907     typedef typename Sig::ArgExprs ArgExprs;
1908 
print(ostream & os,const BaseArgExprs & args) const1909     void print(ostream &os, const BaseArgExprs &args) const
1910     {
1911         this->doPrint(os, args);
1912     }
1913 
apply(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1914     IRet apply(const EvalContext &ctx, const IArg0 &arg0 = IArg0(), const IArg1 &arg1 = IArg1(),
1915                const IArg2 &arg2 = IArg2(), const IArg3 &arg3 = IArg3()) const
1916     {
1917         return this->applyArgs(ctx, IArgs(arg0, arg1, arg2, arg3));
1918     }
1919 
fail(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1920     IRet fail(const EvalContext &ctx, const IArg0 &arg0 = IArg0(), const IArg1 &arg1 = IArg1(),
1921               const IArg2 &arg2 = IArg2(), const IArg3 &arg3 = IArg3()) const
1922     {
1923         return this->doFail(ctx, IArgs(arg0, arg1, arg2, arg3));
1924     }
applyArgs(const EvalContext & ctx,const IArgs & args) const1925     IRet applyArgs(const EvalContext &ctx, const IArgs &args) const
1926     {
1927         return this->doApply(ctx, args);
1928     }
1929     ExprP<Ret> operator()(const ExprP<Arg0> &arg0 = voidP(), const ExprP<Arg1> &arg1 = voidP(),
1930                           const ExprP<Arg2> &arg2 = voidP(), const ExprP<Arg3> &arg3 = voidP()) const;
1931 
getParamNames(void) const1932     const ParamNames &getParamNames(void) const
1933     {
1934         return this->doGetParamNames();
1935     }
1936 
1937 protected:
1938     virtual IRet doApply(const EvalContext &, const IArgs &) const = 0;
doFail(const EvalContext & ctx,const IArgs & args) const1939     virtual IRet doFail(const EvalContext &ctx, const IArgs &args) const
1940     {
1941         return this->doApply(ctx, args);
1942     }
doPrint(ostream & os,const BaseArgExprs & args) const1943     virtual void doPrint(ostream &os, const BaseArgExprs &args) const
1944     {
1945         os << getName() << "(";
1946 
1947         if (isTypeValid<Arg0>())
1948             os << *args[0];
1949 
1950         if (isTypeValid<Arg1>())
1951             os << ", " << *args[1];
1952 
1953         if (isTypeValid<Arg2>())
1954             os << ", " << *args[2];
1955 
1956         if (isTypeValid<Arg3>())
1957             os << ", " << *args[3];
1958 
1959         os << ")";
1960     }
1961 
doGetParamNames(void) const1962     virtual const ParamNames &doGetParamNames(void) const
1963     {
1964         static ParamNames names("a", "b", "c", "d");
1965         return names;
1966     }
1967 };
1968 
1969 template <typename Sig>
1970 class Apply : public Expr<typename Sig::Ret>
1971 {
1972 public:
1973     typedef typename Sig::Ret Ret;
1974     typedef typename Sig::Arg0 Arg0;
1975     typedef typename Sig::Arg1 Arg1;
1976     typedef typename Sig::Arg2 Arg2;
1977     typedef typename Sig::Arg3 Arg3;
1978     typedef typename Expr<Ret>::Val Val;
1979     typedef typename Expr<Ret>::IVal IVal;
1980     typedef Func<Sig> ApplyFunc;
1981     typedef typename ApplyFunc::ArgExprs ArgExprs;
1982 
Apply(const ApplyFunc & func,const ExprP<Arg0> & arg0=voidP (),const ExprP<Arg1> & arg1=voidP (),const ExprP<Arg2> & arg2=voidP (),const ExprP<Arg3> & arg3=voidP ())1983     Apply(const ApplyFunc &func, const ExprP<Arg0> &arg0 = voidP(), const ExprP<Arg1> &arg1 = voidP(),
1984           const ExprP<Arg2> &arg2 = voidP(), const ExprP<Arg3> &arg3 = voidP())
1985         : m_func(func)
1986         , m_args(arg0, arg1, arg2, arg3)
1987     {
1988     }
1989 
Apply(const ApplyFunc & func,const ArgExprs & args)1990     Apply(const ApplyFunc &func, const ArgExprs &args) : m_func(func), m_args(args)
1991     {
1992     }
1993 
1994 protected:
doPrintExpr(ostream & os) const1995     void doPrintExpr(ostream &os) const
1996     {
1997         BaseArgExprs args;
1998         args.push_back(m_args.a.get());
1999         args.push_back(m_args.b.get());
2000         args.push_back(m_args.c.get());
2001         args.push_back(m_args.d.get());
2002         m_func.print(os, args);
2003     }
2004 
doEvaluate(const EvalContext & ctx) const2005     IVal doEvaluate(const EvalContext &ctx) const
2006     {
2007         return m_func.apply(ctx, m_args.a->evaluate(ctx), m_args.b->evaluate(ctx), m_args.c->evaluate(ctx),
2008                             m_args.d->evaluate(ctx));
2009     }
2010 
doGetUsedFuncs(FuncSet & dst) const2011     void doGetUsedFuncs(FuncSet &dst) const
2012     {
2013         m_func.getUsedFuncs(dst);
2014         m_args.a->getUsedFuncs(dst);
2015         m_args.b->getUsedFuncs(dst);
2016         m_args.c->getUsedFuncs(dst);
2017         m_args.d->getUsedFuncs(dst);
2018     }
2019 
2020     const ApplyFunc &m_func;
2021     ArgExprs m_args;
2022 };
2023 
2024 template <typename T>
2025 class Alternatives : public Func<Signature<T, T, T>>
2026 {
2027 public:
2028     typedef typename Alternatives::Sig Sig;
2029 
2030 protected:
2031     typedef typename Alternatives::IRet IRet;
2032     typedef typename Alternatives::IArgs IArgs;
2033 
getName(void) const2034     virtual string getName(void) const
2035     {
2036         return "alternatives";
2037     }
doPrintDefinition(std::ostream &) const2038     virtual void doPrintDefinition(std::ostream &) const
2039     {
2040     }
doGetUsedFuncs(FuncSet &) const2041     void doGetUsedFuncs(FuncSet &) const
2042     {
2043     }
2044 
doApply(const EvalContext &,const IArgs & args) const2045     virtual IRet doApply(const EvalContext &, const IArgs &args) const
2046     {
2047         return unionIVal<T>(args.a, args.b);
2048     }
2049 
doPrint(ostream & os,const BaseArgExprs & args) const2050     virtual void doPrint(ostream &os, const BaseArgExprs &args) const
2051     {
2052         os << "{" << *args[0] << " | " << *args[1] << "}";
2053     }
2054 };
2055 
2056 template <typename Sig>
createApply(const Func<Sig> & func,const typename Func<Sig>::ArgExprs & args)2057 ExprP<typename Sig::Ret> createApply(const Func<Sig> &func, const typename Func<Sig>::ArgExprs &args)
2058 {
2059     return exprP(new Apply<Sig>(func, args));
2060 }
2061 
2062 template <typename Sig>
createApply(const Func<Sig> & func,const ExprP<typename Sig::Arg0> & arg0=voidP (),const ExprP<typename Sig::Arg1> & arg1=voidP (),const ExprP<typename Sig::Arg2> & arg2=voidP (),const ExprP<typename Sig::Arg3> & arg3=voidP ())2063 ExprP<typename Sig::Ret> createApply(const Func<Sig> &func, const ExprP<typename Sig::Arg0> &arg0 = voidP(),
2064                                      const ExprP<typename Sig::Arg1> &arg1 = voidP(),
2065                                      const ExprP<typename Sig::Arg2> &arg2 = voidP(),
2066                                      const ExprP<typename Sig::Arg3> &arg3 = voidP())
2067 {
2068     return exprP(new Apply<Sig>(func, arg0, arg1, arg2, arg3));
2069 }
2070 
2071 template <typename Sig>
operator ()(const ExprP<typename Sig::Arg0> & arg0,const ExprP<typename Sig::Arg1> & arg1,const ExprP<typename Sig::Arg2> & arg2,const ExprP<typename Sig::Arg3> & arg3) const2072 ExprP<typename Sig::Ret> Func<Sig>::operator()(const ExprP<typename Sig::Arg0> &arg0,
2073                                                const ExprP<typename Sig::Arg1> &arg1,
2074                                                const ExprP<typename Sig::Arg2> &arg2,
2075                                                const ExprP<typename Sig::Arg3> &arg3) const
2076 {
2077     return createApply(*this, arg0, arg1, arg2, arg3);
2078 }
2079 
2080 template <typename F>
app(const ExprP<typename F::Arg0> & arg0=voidP (),const ExprP<typename F::Arg1> & arg1=voidP (),const ExprP<typename F::Arg2> & arg2=voidP (),const ExprP<typename F::Arg3> & arg3=voidP ())2081 ExprP<typename F::Ret> app(const ExprP<typename F::Arg0> &arg0 = voidP(), const ExprP<typename F::Arg1> &arg1 = voidP(),
2082                            const ExprP<typename F::Arg2> &arg2 = voidP(), const ExprP<typename F::Arg3> &arg3 = voidP())
2083 {
2084     return createApply(instance<F>(), arg0, arg1, arg2, arg3);
2085 }
2086 
2087 template <typename F>
call(const EvalContext & ctx,const typename F::IArg0 & arg0=Void (),const typename F::IArg1 & arg1=Void (),const typename F::IArg2 & arg2=Void (),const typename F::IArg3 & arg3=Void ())2088 typename F::IRet call(const EvalContext &ctx, const typename F::IArg0 &arg0 = Void(),
2089                       const typename F::IArg1 &arg1 = Void(), const typename F::IArg2 &arg2 = Void(),
2090                       const typename F::IArg3 &arg3 = Void())
2091 {
2092     return instance<F>().apply(ctx, arg0, arg1, arg2, arg3);
2093 }
2094 
2095 template <typename T>
alternatives(const ExprP<T> & arg0,const ExprP<T> & arg1)2096 ExprP<T> alternatives(const ExprP<T> &arg0, const ExprP<T> &arg1)
2097 {
2098     return createApply<typename Alternatives<T>::Sig>(instance<Alternatives<T>>(), arg0, arg1);
2099 }
2100 
2101 template <typename Sig>
2102 class ApplyVar : public Apply<Sig>
2103 {
2104 public:
2105     typedef typename Sig::Ret Ret;
2106     typedef typename Sig::Arg0 Arg0;
2107     typedef typename Sig::Arg1 Arg1;
2108     typedef typename Sig::Arg2 Arg2;
2109     typedef typename Sig::Arg3 Arg3;
2110     typedef typename Expr<Ret>::Val Val;
2111     typedef typename Expr<Ret>::IVal IVal;
2112     typedef Func<Sig> ApplyFunc;
2113     typedef typename ApplyFunc::ArgExprs ArgExprs;
2114 
ApplyVar(const ApplyFunc & func,const VariableP<Arg0> & arg0,const VariableP<Arg1> & arg1,const VariableP<Arg2> & arg2,const VariableP<Arg3> & arg3)2115     ApplyVar(const ApplyFunc &func, const VariableP<Arg0> &arg0, const VariableP<Arg1> &arg1,
2116              const VariableP<Arg2> &arg2, const VariableP<Arg3> &arg3)
2117         : Apply<Sig>(func, arg0, arg1, arg2, arg3)
2118     {
2119     }
2120 
2121 protected:
doEvaluate(const EvalContext & ctx) const2122     IVal doEvaluate(const EvalContext &ctx) const
2123     {
2124         const Variable<Arg0> &var0 = static_cast<const Variable<Arg0> &>(*this->m_args.a);
2125         const Variable<Arg1> &var1 = static_cast<const Variable<Arg1> &>(*this->m_args.b);
2126         const Variable<Arg2> &var2 = static_cast<const Variable<Arg2> &>(*this->m_args.c);
2127         const Variable<Arg3> &var3 = static_cast<const Variable<Arg3> &>(*this->m_args.d);
2128         return this->m_func.apply(ctx, ctx.env.lookup(var0), ctx.env.lookup(var1), ctx.env.lookup(var2),
2129                                   ctx.env.lookup(var3));
2130     }
2131 
doFails(const EvalContext & ctx) const2132     IVal doFails(const EvalContext &ctx) const
2133     {
2134         const Variable<Arg0> &var0 = static_cast<const Variable<Arg0> &>(*this->m_args.a);
2135         const Variable<Arg1> &var1 = static_cast<const Variable<Arg1> &>(*this->m_args.b);
2136         const Variable<Arg2> &var2 = static_cast<const Variable<Arg2> &>(*this->m_args.c);
2137         const Variable<Arg3> &var3 = static_cast<const Variable<Arg3> &>(*this->m_args.d);
2138         return this->m_func.fail(ctx, ctx.env.lookup(var0), ctx.env.lookup(var1), ctx.env.lookup(var2),
2139                                  ctx.env.lookup(var3));
2140     }
2141 };
2142 
2143 template <typename Sig>
applyVar(const Func<Sig> & func,const VariableP<typename Sig::Arg0> & arg0,const VariableP<typename Sig::Arg1> & arg1,const VariableP<typename Sig::Arg2> & arg2,const VariableP<typename Sig::Arg3> & arg3)2144 ExprP<typename Sig::Ret> applyVar(const Func<Sig> &func, const VariableP<typename Sig::Arg0> &arg0,
2145                                   const VariableP<typename Sig::Arg1> &arg1, const VariableP<typename Sig::Arg2> &arg2,
2146                                   const VariableP<typename Sig::Arg3> &arg3)
2147 {
2148     return exprP(new ApplyVar<Sig>(func, arg0, arg1, arg2, arg3));
2149 }
2150 
/*! \brief Function defined in terms of other expressions.
 *
 * Subclasses implement doExpand(), which builds the function body (a list of
 * statements plus a return expression) over parameter variables. The body is
 * used both to print a GLSL definition and to evaluate the function on
 * intervals. It is built lazily on first use by initialize(), which stores it
 * in mutable members without any locking.
 */
template <typename Sig_>
class DerivedFunc : public Func<Sig_>
{
public:
    typedef typename DerivedFunc::ArgExprs ArgExprs;
    typedef typename DerivedFunc::IRet IRet;
    typedef typename DerivedFunc::IArgs IArgs;
    typedef typename DerivedFunc::Ret Ret;
    typedef typename DerivedFunc::Arg0 Arg0;
    typedef typename DerivedFunc::Arg1 Arg1;
    typedef typename DerivedFunc::Arg2 Arg2;
    typedef typename DerivedFunc::Arg3 Arg3;
    typedef typename DerivedFunc::IArg0 IArg0;
    typedef typename DerivedFunc::IArg1 IArg1;
    typedef typename DerivedFunc::IArg2 IArg2;
    typedef typename DerivedFunc::IArg3 IArg3;

protected:
    //! Print "RetType name(ArgType a, ...)\n{\n<body>return <ret>;\n}\n",
    //! listing only the non-Void parameters.
    void doPrintDefinition(ostream &os) const
    {
        const ParamNames &paramNames = this->getParamNames();

        initialize();

        os << dataTypeNameOf<Ret>() << " " << this->getName() << "(";
        if (isTypeValid<Arg0>())
            os << dataTypeNameOf<Arg0>() << " " << paramNames.a;
        if (isTypeValid<Arg1>())
            os << ", " << dataTypeNameOf<Arg1>() << " " << paramNames.b;
        if (isTypeValid<Arg2>())
            os << ", " << dataTypeNameOf<Arg2>() << " " << paramNames.c;
        if (isTypeValid<Arg3>())
            os << ", " << dataTypeNameOf<Arg3>() << " " << paramNames.d;
        os << ")\n{\n";

        for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
            os << *m_body[ndx];
        os << "return " << *m_ret << ";\n";
        os << "}\n";
    }

    //! Evaluate by interpreting the body in a fresh environment in which the
    //! parameter variables are bound to the argument intervals.
    IRet doApply(const EvalContext &ctx, const IArgs &args) const
    {
        Environment funEnv;
        // The argument tuple holds const references; they are written back
        // to below, after the body has run.
        IArgs &mutArgs = const_cast<IArgs &>(args);
        IRet ret;

        initialize();

        funEnv.bind(*m_var0, args.a);
        funEnv.bind(*m_var1, args.b);
        funEnv.bind(*m_var2, args.c);
        funEnv.bind(*m_var3, args.d);

        {
            EvalContext funCtx(ctx.format, ctx.floatPrecision, funEnv, ctx.callDepth);

            for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
                m_body[ndx]->execute(funCtx);

            ret = m_ret->evaluate(funCtx);
        }

        // Copy the final parameter values back into the caller's argument
        // intervals. Presumably this is how output parameters reach the
        // caller (see getOutParamIndex) - confirm.
        // \todo [lauri] Store references instead of values in environment
        const_cast<IArg0 &>(mutArgs.a) = funEnv.lookup(*m_var0);
        const_cast<IArg1 &>(mutArgs.b) = funEnv.lookup(*m_var1);
        const_cast<IArg2 &>(mutArgs.c) = funEnv.lookup(*m_var2);
        const_cast<IArg3 &>(mutArgs.d) = funEnv.lookup(*m_var3);

        return ret;
    }

    //! Insert this function into dst and, only on first insertion, recurse
    //! into the functions used by its body and return expression.
    void doGetUsedFuncs(FuncSet &dst) const
    {
        initialize();
        if (dst.insert(this).second)
        {
            for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
                m_body[ndx]->getUsedFuncs(dst);
            m_ret->getUsedFuncs(dst);
        }
    }

    //! Build the function body: append statements to ctx and return the
    //! result expression, given the parameter variables in args_.
    virtual ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args_) const = 0;

    // These are transparently initialized when first needed. They cannot be
    // initialized in the constructor because they depend on the doExpand
    // method of the subclass.

    mutable VariableP<Arg0> m_var0;
    mutable VariableP<Arg1> m_var1;
    mutable VariableP<Arg2> m_var2;
    mutable VariableP<Arg3> m_var3;
    mutable vector<StatementP> m_body;
    mutable ExprP<Ret> m_ret;

private:
    //! Build m_var*, m_body and m_ret on first call; m_ret doubles as the
    //! "already initialized" flag.
    void initialize(void) const
    {
        if (!m_ret)
        {
            const ParamNames &paramNames = this->getParamNames();
            Counter symCounter;
            ExpandContext ctx(symCounter);
            ArgExprs args;

            args.a = m_var0 = variable<Arg0>(paramNames.a);
            args.b = m_var1 = variable<Arg1>(paramNames.b);
            args.c = m_var2 = variable<Arg2>(paramNames.c);
            args.d = m_var3 = variable<Arg3>(paramNames.d);

            m_ret  = this->doExpand(ctx, args);
            m_body = ctx.getStatements();
        }
    }
};
2267 
2268 template <typename Sig>
2269 class PrimitiveFunc : public Func<Sig>
2270 {
2271 public:
2272     typedef typename PrimitiveFunc::Ret Ret;
2273     typedef typename PrimitiveFunc::ArgExprs ArgExprs;
2274 
2275 protected:
doPrintDefinition(ostream &) const2276     void doPrintDefinition(ostream &) const
2277     {
2278     }
doGetUsedFuncs(FuncSet &) const2279     void doGetUsedFuncs(FuncSet &) const
2280     {
2281     }
2282 };
2283 
2284 template <typename T>
2285 class Cond : public PrimitiveFunc<Signature<T, bool, T, T>>
2286 {
2287 public:
2288     typedef typename Cond::IArgs IArgs;
2289     typedef typename Cond::IRet IRet;
2290 
getName(void) const2291     string getName(void) const
2292     {
2293         return "_cond";
2294     }
2295 
2296 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2297     void doPrint(ostream &os, const BaseArgExprs &args) const
2298     {
2299         os << "(" << *args[0] << " ? " << *args[1] << " : " << *args[2] << ")";
2300     }
2301 
doApply(const EvalContext &,const IArgs & iargs) const2302     IRet doApply(const EvalContext &, const IArgs &iargs) const
2303     {
2304         IRet ret;
2305 
2306         if (iargs.a.contains(true))
2307             ret = unionIVal<T>(ret, iargs.b);
2308 
2309         if (iargs.a.contains(false))
2310             ret = unionIVal<T>(ret, iargs.c);
2311 
2312         return ret;
2313     }
2314 };
2315 
2316 template <typename T>
2317 class CompareOperator : public PrimitiveFunc<Signature<bool, T, T>>
2318 {
2319 public:
2320     typedef typename CompareOperator::IArgs IArgs;
2321     typedef typename CompareOperator::IArg0 IArg0;
2322     typedef typename CompareOperator::IArg1 IArg1;
2323     typedef typename CompareOperator::IRet IRet;
2324 
2325 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2326     void doPrint(ostream &os, const BaseArgExprs &args) const
2327     {
2328         os << "(" << *args[0] << getSymbol() << *args[1] << ")";
2329     }
2330 
doApply(const EvalContext &,const IArgs & iargs) const2331     Interval doApply(const EvalContext &, const IArgs &iargs) const
2332     {
2333         const IArg0 &arg0 = iargs.a;
2334         const IArg1 &arg1 = iargs.b;
2335         IRet ret;
2336 
2337         if (canSucceed(arg0, arg1))
2338             ret |= true;
2339         if (canFail(arg0, arg1))
2340             ret |= false;
2341 
2342         return ret;
2343     }
2344 
2345     virtual string getSymbol(void) const                        = 0;
2346     virtual bool canSucceed(const IArg0 &, const IArg1 &) const = 0;
2347     virtual bool canFail(const IArg0 &, const IArg1 &) const    = 0;
2348 };
2349 
2350 template <typename T>
2351 class LessThan : public CompareOperator<T>
2352 {
2353 public:
getName(void) const2354     string getName(void) const
2355     {
2356         return "lessThan";
2357     }
2358 
2359 protected:
getSymbol(void) const2360     string getSymbol(void) const
2361     {
2362         return "<";
2363     }
2364 
canSucceed(const Interval & a,const Interval & b) const2365     bool canSucceed(const Interval &a, const Interval &b) const
2366     {
2367         return (a.lo() < b.hi());
2368     }
2369 
canFail(const Interval & a,const Interval & b) const2370     bool canFail(const Interval &a, const Interval &b) const
2371     {
2372         return !(a.hi() < b.lo());
2373     }
2374 };
2375 
2376 template <typename T>
operator <(const ExprP<T> & a,const ExprP<T> & b)2377 ExprP<bool> operator<(const ExprP<T> &a, const ExprP<T> &b)
2378 {
2379     return app<LessThan<T>>(a, b);
2380 }
2381 
2382 template <typename T>
cond(const ExprP<bool> & test,const ExprP<T> & consequent,const ExprP<T> & alternative)2383 ExprP<T> cond(const ExprP<bool> &test, const ExprP<T> &consequent, const ExprP<T> &alternative)
2384 {
2385     return app<Cond<T>>(test, consequent, alternative);
2386 }
2387 
2388 /*--------------------------------------------------------------------*//*!
2389  *
2390  * @}
2391  *
2392  *//*--------------------------------------------------------------------*/
/*! \brief Base for unary floating-point primitive functions.
 *
 * Valid choices for the template parameter T:
 *   Signature<float, float>        32-bit tests
 *   Signature<float, deFloat16>    16-bit tests
 *   Signature<double, double>      64-bit tests
 *
 * Subclasses provide precision() and usually applyExact(); they may also
 * override innerExtrema() and getCodomain().
 */
template <class T>
class FloatFunc1 : public PrimitiveFunc<T>
{
protected:
    Interval doApply(const EvalContext &ctx,
                     const typename Signature<typename T::Ret, typename T::Arg0>::IArgs &iargs) const
    {
        return this->applyMonotone(ctx, iargs.a);
    }

    //! Interval result for the argument interval iarg0.
    //! TCU_INTERVAL_APPLY_MONOTONE1 (tcuInterval.hpp) presumably evaluates
    //! applyPoint() at the bounds of iarg0, which suffices only for monotone
    //! functions; innerExtrema() adds any interior extrema. The result is then
    //! clamped to the codomain (with NaN kept) and converted to ctx.format.
    Interval applyMonotone(const EvalContext &ctx, const Interval &iarg0) const
    {
        Interval ret;

        TCU_INTERVAL_APPLY_MONOTONE1(ret, arg0, iarg0, val,
                                     TCU_SET_INTERVAL(val, point, point = this->applyPoint(ctx, arg0)));

        ret |= innerExtrema(ctx, iarg0);
        ret &= (this->getCodomain(ctx) | TCU_NAN);

        return ctx.format.convert(ret);
    }

    //! Extra values attained strictly inside the argument interval; none by default.
    virtual Interval innerExtrema(const EvalContext &, const Interval &) const
    {
        return Interval(); // empty interval, i.e. no extrema
    }

    //! Exact result widened by +/- precision() on both sides.
    virtual Interval applyPoint(const EvalContext &ctx, double arg0) const
    {
        const double exact = this->applyExact(arg0);
        const double prec  = this->precision(ctx, exact, arg0);

        return exact + Interval(-prec, prec);
    }

    //! Reference value of the function; throws InternalError unless overridden.
    virtual double applyExact(double) const
    {
        TCU_THROW(InternalError, "Cannot apply");
    }

    //! Range of possible results; unbounded by default.
    virtual Interval getCodomain(const EvalContext &) const
    {
        return Interval::unbounded(true);
    }

    //! Allowed absolute error for the given (exact result, argument) pair.
    virtual double precision(const EvalContext &ctx, double, double) const = 0;
};
2445 
2446 /*Proper parameters for template T
2447     Signature<double, double>    64bit tests
2448     Signature<float, float>        32bit tests
2449     Signature<float, deFloat16>    16bit tests*/
2450 template <class T>
2451 class CFloatFunc1 : public FloatFunc1<T>
2452 {
2453 public:
CFloatFunc1(const string & name,tcu::DoubleFunc1 & func)2454     CFloatFunc1(const string &name, tcu::DoubleFunc1 &func) : m_name(name), m_func(func)
2455     {
2456     }
2457 
getName(void) const2458     string getName(void) const
2459     {
2460         return m_name;
2461     }
2462 
2463 protected:
applyExact(double x) const2464     double applyExact(double x) const
2465     {
2466         return m_func(x);
2467     }
2468 
2469     const string m_name;
2470     tcu::DoubleFunc1 &m_func;
2471 };
2472 
/*! \brief Base for binary floating-point primitive functions.
 *
 * Valid choices for the template parameter T:
 *   Signature<float, deFloat16, deFloat16>
 *   Signature<float, float, float>
 *   Signature<double, double, double>
 *
 * Subclasses provide precision() and usually applyExact(); they may also
 * override innerExtrema() and getCodomain().
 */
template <class T>
class FloatFunc2 : public PrimitiveFunc<T>
{
protected:
    Interval doApply(const EvalContext &ctx,
                     const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
    {
        return this->applyMonotone(ctx, iargs.a, iargs.b);
    }

    //! Interval result for argument intervals xi and yi.
    //! TCU_INTERVAL_APPLY_MONOTONE2 (tcuInterval.hpp) presumably evaluates
    //! applyPoint() at the bounds of xi and yi, which suffices only for
    //! functions monotone in each argument; innerExtrema() adds interior
    //! extrema. The result is clamped to the codomain (with NaN kept) and
    //! converted to ctx.format.
    Interval applyMonotone(const EvalContext &ctx, const Interval &xi, const Interval &yi) const
    {
        Interval reti;

        TCU_INTERVAL_APPLY_MONOTONE2(reti, x, xi, y, yi, ret,
                                     TCU_SET_INTERVAL(ret, point, point = this->applyPoint(ctx, x, y)));
        reti |= innerExtrema(ctx, xi, yi);
        reti &= (this->getCodomain(ctx) | TCU_NAN);

        return ctx.format.convert(reti);
    }

    //! Extra values attained strictly inside the argument intervals; none by default.
    virtual Interval innerExtrema(const EvalContext &, const Interval &, const Interval &) const
    {
        return Interval(); // empty interval, i.e. no extrema
    }

    //! Exact result widened by +/- precision() on both sides.
    virtual Interval applyPoint(const EvalContext &ctx, double x, double y) const
    {
        const double exact = this->applyExact(x, y);
        const double prec  = this->precision(ctx, exact, x, y);

        return exact + Interval(-prec, prec);
    }

    //! Reference value of the function; throws InternalError unless overridden.
    virtual double applyExact(double, double) const
    {
        TCU_THROW(InternalError, "Cannot apply");
    }

    //! Range of possible results; unbounded by default.
    virtual Interval getCodomain(const EvalContext &) const
    {
        return Interval::unbounded(true);
    }

    //! Allowed absolute error for result ret at arguments (x, y).
    virtual double precision(const EvalContext &ctx, double ret, double x, double y) const = 0;
};
2523 
2524 template <class T>
2525 class CFloatFunc2 : public FloatFunc2<T>
2526 {
2527 public:
CFloatFunc2(const string & name,tcu::DoubleFunc2 & func)2528     CFloatFunc2(const string &name, tcu::DoubleFunc2 &func) : m_name(name), m_func(func)
2529     {
2530     }
2531 
getName(void) const2532     string getName(void) const
2533     {
2534         return m_name;
2535     }
2536 
2537 protected:
applyExact(double x,double y) const2538     double applyExact(double x, double y) const
2539     {
2540         return m_func(x, y);
2541     }
2542 
2543     const string m_name;
2544     tcu::DoubleFunc2 &m_func;
2545 };
2546 
2547 template <class T>
2548 class InfixOperator : public FloatFunc2<T>
2549 {
2550 protected:
2551     virtual string getSymbol(void) const = 0;
2552 
doPrint(ostream & os,const BaseArgExprs & args) const2553     void doPrint(ostream &os, const BaseArgExprs &args) const
2554     {
2555         os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2556     }
2557 
applyPoint(const EvalContext & ctx,double x,double y) const2558     Interval applyPoint(const EvalContext &ctx, double x, double y) const
2559     {
2560         const double exact = this->applyExact(x, y);
2561 
2562         // Allow either representable number on both sides of the exact value,
2563         // but require exactly representable values to be preserved.
2564         return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2565     }
2566 
precision(const EvalContext &,double,double,double) const2567     double precision(const EvalContext &, double, double, double) const
2568     {
2569         return 0.0;
2570     }
2571 };
2572 
2573 class InfixOperator16Bit : public FloatFunc2<Signature<float, deFloat16, deFloat16>>
2574 {
2575 protected:
2576     virtual string getSymbol(void) const = 0;
2577 
doPrint(ostream & os,const BaseArgExprs & args) const2578     void doPrint(ostream &os, const BaseArgExprs &args) const
2579     {
2580         os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2581     }
2582 
applyPoint(const EvalContext & ctx,double x,double y) const2583     Interval applyPoint(const EvalContext &ctx, double x, double y) const
2584     {
2585         const double exact = this->applyExact(x, y);
2586 
2587         // Allow either representable number on both sides of the exact value,
2588         // but require exactly representable values to be preserved.
2589         return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2590     }
2591 
precision(const EvalContext &,double,double,double) const2592     double precision(const EvalContext &, double, double, double) const
2593     {
2594         return 0.0;
2595     }
2596 };
2597 
2598 template <class T>
2599 class FloatFunc3 : public PrimitiveFunc<T>
2600 {
2601 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1,typename T::Arg2>::IArgs & iargs) const2602     Interval doApply(const EvalContext &ctx,
2603                      const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1,
2604                                               typename T::Arg2>::IArgs &iargs) const
2605     {
2606         return this->applyMonotone(ctx, iargs.a, iargs.b, iargs.c);
2607     }
2608 
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi,const Interval & zi) const2609     Interval applyMonotone(const EvalContext &ctx, const Interval &xi, const Interval &yi, const Interval &zi) const
2610     {
2611         Interval reti;
2612         TCU_INTERVAL_APPLY_MONOTONE3(reti, x, xi, y, yi, z, zi, ret,
2613                                      TCU_SET_INTERVAL(ret, point, point = this->applyPoint(ctx, x, y, z)));
2614         return ctx.format.convert(reti);
2615     }
2616 
applyPoint(const EvalContext & ctx,double x,double y,double z) const2617     virtual Interval applyPoint(const EvalContext &ctx, double x, double y, double z) const
2618     {
2619         const double exact = this->applyExact(x, y, z);
2620         const double prec  = this->precision(ctx, exact, x, y, z);
2621         return exact + Interval(-prec, prec);
2622     }
2623 
applyExact(double,double,double) const2624     virtual double applyExact(double, double, double) const
2625     {
2626         TCU_THROW(InternalError, "Cannot apply");
2627     }
2628 
2629     virtual double precision(const EvalContext &ctx, double result, double x, double y, double z) const = 0;
2630 };
2631 
2632 // We define syntactic sugar functions for expression constructors. Since
2633 // these have the same names as ordinary mathematical operations (sin, log
2634 // etc.), it's better to give them a dedicated namespace.
2635 namespace Functions
2636 {
2637 
2638 using namespace tcu;
2639 
2640 template <class T>
2641 class Comparison : public InfixOperator<T>
2642 {
2643 public:
getName(void) const2644     string getName(void) const
2645     {
2646         return "comparison";
2647     }
getSymbol(void) const2648     string getSymbol(void) const
2649     {
2650         return "";
2651     }
2652 
getSpirvCase() const2653     SpirVCaseT getSpirvCase() const
2654     {
2655         return SPIRV_CASETYPE_COMPARE;
2656     }
2657 
doApply(const EvalContext & ctx,const typename Comparison<T>::IArgs & iargs) const2658     Interval doApply(const EvalContext &ctx, const typename Comparison<T>::IArgs &iargs) const
2659     {
2660         DE_UNREF(ctx);
2661         if (iargs.a.hasNaN() || iargs.b.hasNaN())
2662         {
2663             return TCU_NAN; // one of the floats is NaN: block analysis
2664         }
2665 
2666         int operationFlag = 1;
2667         int result        = 0;
2668         const double a    = iargs.a.midpoint();
2669         const double b    = iargs.b.midpoint();
2670 
2671         for (int i = 0; i < 2; ++i)
2672         {
2673             if (a == b)
2674                 result += operationFlag;
2675             operationFlag = operationFlag << 1;
2676 
2677             if (a > b)
2678                 result += operationFlag;
2679             operationFlag = operationFlag << 1;
2680 
2681             if (a < b)
2682                 result += operationFlag;
2683             operationFlag = operationFlag << 1;
2684 
2685             if (a >= b)
2686                 result += operationFlag;
2687             operationFlag = operationFlag << 1;
2688 
2689             if (a <= b)
2690                 result += operationFlag;
2691             operationFlag = operationFlag << 1;
2692         }
2693         return result;
2694     }
2695 };
2696 
2697 template <class T>
2698 class Add : public InfixOperator<T>
2699 {
2700 public:
getName(void) const2701     string getName(void) const
2702     {
2703         return "add";
2704     }
getSymbol(void) const2705     string getSymbol(void) const
2706     {
2707         return "+";
2708     }
2709 
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2710     Interval doApply(const EvalContext &ctx,
2711                      const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2712     {
2713         // Fast-path for common case
2714         if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2715         {
2716             Interval ret;
2717             TCU_SET_INTERVAL_BOUNDS(ret, sum, sum = iargs.a.lo() + iargs.b.lo(), sum = iargs.a.hi() + iargs.b.hi());
2718             return ctx.format.convert(ctx.format.roundOut(ret, true));
2719         }
2720         return this->applyMonotone(ctx, iargs.a, iargs.b);
2721     }
2722 
2723 protected:
applyExact(double x,double y) const2724     double applyExact(double x, double y) const
2725     {
2726         return x + y;
2727     }
2728 };
2729 
2730 template <class T>
2731 class Mul : public InfixOperator<T>
2732 {
2733 public:
getName(void) const2734     string getName(void) const
2735     {
2736         return "mul";
2737     }
getSymbol(void) const2738     string getSymbol(void) const
2739     {
2740         return "*";
2741     }
2742 
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2743     Interval doApply(const EvalContext &ctx,
2744                      const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2745     {
2746         Interval a = iargs.a;
2747         Interval b = iargs.b;
2748 
2749         // Fast-path for common case
2750         if (a.isOrdinary(ctx.format.getMaxValue()) && b.isOrdinary(ctx.format.getMaxValue()))
2751         {
2752             Interval ret;
2753             if (a.hi() < 0)
2754             {
2755                 a = -a;
2756                 b = -b;
2757             }
2758             if (a.lo() >= 0 && b.lo() >= 0)
2759             {
2760                 TCU_SET_INTERVAL_BOUNDS(ret, prod, prod = a.lo() * b.lo(), prod = a.hi() * b.hi());
2761                 return ctx.format.convert(ctx.format.roundOut(ret, true));
2762             }
2763             if (a.lo() >= 0 && b.hi() <= 0)
2764             {
2765                 TCU_SET_INTERVAL_BOUNDS(ret, prod, prod = a.hi() * b.lo(), prod = a.lo() * b.hi());
2766                 return ctx.format.convert(ctx.format.roundOut(ret, true));
2767             }
2768         }
2769         return this->applyMonotone(ctx, iargs.a, iargs.b);
2770     }
2771 
2772 protected:
applyExact(double x,double y) const2773     double applyExact(double x, double y) const
2774     {
2775         return x * y;
2776     }
2777 
innerExtrema(const EvalContext &,const Interval & xi,const Interval & yi) const2778     Interval innerExtrema(const EvalContext &, const Interval &xi, const Interval &yi) const
2779     {
2780         if (((xi.contains(-TCU_INFINITY) || xi.contains(TCU_INFINITY)) && yi.contains(0.0)) ||
2781             ((yi.contains(-TCU_INFINITY) || yi.contains(TCU_INFINITY)) && xi.contains(0.0)))
2782             return Interval(TCU_NAN);
2783 
2784         return Interval();
2785     }
2786 };
2787 
2788 template <class T>
2789 class Sub : public InfixOperator<T>
2790 {
2791 public:
getName(void) const2792     string getName(void) const
2793     {
2794         return "sub";
2795     }
getSymbol(void) const2796     string getSymbol(void) const
2797     {
2798         return "-";
2799     }
2800 
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2801     Interval doApply(const EvalContext &ctx,
2802                      const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2803     {
2804         // Fast-path for common case
2805         if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2806         {
2807             Interval ret;
2808 
2809             TCU_SET_INTERVAL_BOUNDS(ret, diff, diff = iargs.a.lo() - iargs.b.hi(), diff = iargs.a.hi() - iargs.b.lo());
2810             return ctx.format.convert(ctx.format.roundOut(ret, true));
2811         }
2812         else
2813         {
2814             return this->applyMonotone(ctx, iargs.a, iargs.b);
2815         }
2816     }
2817 
2818 protected:
applyExact(double x,double y) const2819     double applyExact(double x, double y) const
2820     {
2821         return x - y;
2822     }
2823 };
2824 
2825 template <class T>
2826 class Negate : public FloatFunc1<T>
2827 {
2828 public:
getName(void) const2829     string getName(void) const
2830     {
2831         return "_negate";
2832     }
doPrint(ostream & os,const BaseArgExprs & args) const2833     void doPrint(ostream &os, const BaseArgExprs &args) const
2834     {
2835         os << "-" << *args[0];
2836     }
2837 
2838 protected:
precision(const EvalContext &,double,double) const2839     double precision(const EvalContext &, double, double) const
2840     {
2841         return 0.0;
2842     }
applyExact(double x) const2843     double applyExact(double x) const
2844     {
2845         return -x;
2846     }
2847 };
2848 
2849 template <class T>
2850 class Div : public InfixOperator<T>
2851 {
2852 public:
getName(void) const2853     string getName(void) const
2854     {
2855         return "div";
2856     }
2857 
2858 protected:
getSymbol(void) const2859     string getSymbol(void) const
2860     {
2861         return "/";
2862     }
2863 
innerExtrema(const EvalContext &,const Interval & nom,const Interval & den) const2864     Interval innerExtrema(const EvalContext &, const Interval &nom, const Interval &den) const
2865     {
2866         Interval ret;
2867 
2868         if (den.contains(0.0))
2869         {
2870             if (nom.contains(0.0))
2871                 ret |= TCU_NAN;
2872 
2873             if (nom.lo() < 0.0 || nom.hi() > 0.0)
2874                 ret |= Interval::unbounded();
2875         }
2876 
2877         return ret;
2878     }
2879 
applyExact(double x,double y) const2880     double applyExact(double x, double y) const
2881     {
2882         return x / y;
2883     }
2884 
applyPoint(const EvalContext & ctx,double x,double y) const2885     Interval applyPoint(const EvalContext &ctx, double x, double y) const
2886     {
2887         Interval ret = FloatFunc2<T>::applyPoint(ctx, x, y);
2888 
2889         if (!deIsInf(x) && !deIsInf(y) && y != 0.0)
2890         {
2891             const Interval dst = ctx.format.convert(ret);
2892             if (dst.contains(-TCU_INFINITY))
2893                 ret |= -ctx.format.getMaxValue();
2894             if (dst.contains(+TCU_INFINITY))
2895                 ret |= +ctx.format.getMaxValue();
2896         }
2897 
2898         return ret;
2899     }
2900 
precision(const EvalContext & ctx,double ret,double,double den) const2901     double precision(const EvalContext &ctx, double ret, double, double den) const
2902     {
2903         const FloatFormat &fmt = ctx.format;
2904 
2905         // \todo [2014-03-05 lauri] Check that the limits in GLSL 3.10 are actually correct.
2906         // For now, we assume that division's precision is 2.5 ULP when the value is within
2907         // [2^MINEXP, 2^MAXEXP-1]
2908 
2909         if (den == 0.0)
2910             return 0.0; // Result must be exactly inf
2911         else if (de::inBounds(deAbs(den), deLdExp(1.0, fmt.getMinExp()), deLdExp(1.0, fmt.getMaxExp() - 1)))
2912             return fmt.ulp(ret, 2.5);
2913         else
2914             return TCU_INFINITY; // Can be any number, but must be a number.
2915     }
2916 };
2917 
2918 template <class T>
2919 class InverseSqrt : public FloatFunc1<T>
2920 {
2921 public:
getName(void) const2922     string getName(void) const
2923     {
2924         return "inversesqrt";
2925     }
2926 
2927 protected:
applyExact(double x) const2928     double applyExact(double x) const
2929     {
2930         return 1.0 / deSqrt(x);
2931     }
2932 
precision(const EvalContext & ctx,double ret,double x) const2933     double precision(const EvalContext &ctx, double ret, double x) const
2934     {
2935         return x <= 0 ? TCU_NAN : ctx.format.ulp(ret, 2.0);
2936     }
2937 
getCodomain(const EvalContext &) const2938     Interval getCodomain(const EvalContext &) const
2939     {
2940         return Interval(0.0, TCU_INFINITY);
2941     }
2942 };
2943 
2944 template <class T>
2945 class ExpFunc : public CFloatFunc1<T>
2946 {
2947 public:
ExpFunc(const string & name,DoubleFunc1 & func)2948     ExpFunc(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
2949     {
2950     }
2951 
2952 protected:
2953     double precision(const EvalContext &ctx, double ret, double x) const;
getCodomain(const EvalContext &) const2954     Interval getCodomain(const EvalContext &) const
2955     {
2956         return Interval(0.0, TCU_INFINITY);
2957     }
2958 };
2959 
2960 template <>
precision(const EvalContext & ctx,double ret,double x) const2961 double ExpFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double x) const
2962 {
2963     switch (ctx.floatPrecision)
2964     {
2965     case glu::PRECISION_HIGHP:
2966         return ctx.format.ulp(ret, 3.0 + 2.0 * deAbs(x));
2967     case glu::PRECISION_MEDIUMP:
2968     case glu::PRECISION_LAST:
2969         return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2970     default:
2971         DE_FATAL("Impossible");
2972     }
2973 
2974     return 0.0;
2975 }
2976 
2977 template <>
precision(const EvalContext & ctx,double ret,double x) const2978 double ExpFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double x) const
2979 {
2980     return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2981 }
2982 
2983 template <>
precision(const EvalContext & ctx,double ret,double x) const2984 double ExpFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double x) const
2985 {
2986     return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2987 }
2988 
2989 template <class T>
2990 class Exp2 : public ExpFunc<T>
2991 {
2992 public:
Exp2(void)2993     Exp2(void) : ExpFunc<T>("exp2", deExp2)
2994     {
2995     }
2996 };
2997 template <class T>
2998 class Exp : public ExpFunc<T>
2999 {
3000 public:
Exp(void)3001     Exp(void) : ExpFunc<T>("exp", deExp)
3002     {
3003     }
3004 };
3005 
3006 template <typename T>
exp2(const ExprP<T> & x)3007 ExprP<T> exp2(const ExprP<T> &x)
3008 {
3009     return app<Exp2<Signature<T, T>>>(x);
3010 }
3011 template <typename T>
exp(const ExprP<T> & x)3012 ExprP<T> exp(const ExprP<T> &x)
3013 {
3014     return app<Exp<Signature<T, T>>>(x);
3015 }
3016 
3017 template <class T>
3018 class LogFunc : public CFloatFunc1<T>
3019 {
3020 public:
LogFunc(const string & name,DoubleFunc1 & func)3021     LogFunc(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
3022     {
3023     }
3024 
3025 protected:
3026     double precision(const EvalContext &ctx, double ret, double x) const;
3027 };
3028 
3029 template <>
precision(const EvalContext & ctx,double ret,double x) const3030 double LogFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double x) const
3031 {
3032     if (x <= 0)
3033         return TCU_NAN;
3034 
3035     switch (ctx.floatPrecision)
3036     {
3037     case glu::PRECISION_HIGHP:
3038         return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
3039     case glu::PRECISION_MEDIUMP:
3040     case glu::PRECISION_LAST:
3041         return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
3042     default:
3043         DE_FATAL("Impossible");
3044     }
3045 
3046     return 0;
3047 }
3048 
3049 template <>
precision(const EvalContext & ctx,double ret,double x) const3050 double LogFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double x) const
3051 {
3052     if (x <= 0)
3053         return TCU_NAN;
3054     return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
3055 }
3056 
3057 // Spec: "The precision of double-precision instructions is at least that of single precision."
3058 // Lets pick float high precision as a reference.
3059 template <>
precision(const EvalContext & ctx,double ret,double x) const3060 double LogFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double x) const
3061 {
3062     if (x <= 0)
3063         return TCU_NAN;
3064     return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
3065 }
3066 
3067 template <class T>
3068 class Log2 : public LogFunc<T>
3069 {
3070 public:
Log2(void)3071     Log2(void) : LogFunc<T>("log2", deLog2)
3072     {
3073     }
3074 };
3075 template <class T>
3076 class Log : public LogFunc<T>
3077 {
3078 public:
Log(void)3079     Log(void) : LogFunc<T>("log", deLog)
3080     {
3081     }
3082 };
3083 
log2(const ExprP<float> & x)3084 ExprP<float> log2(const ExprP<float> &x)
3085 {
3086     return app<Log2<Signature<float, float>>>(x);
3087 }
log(const ExprP<float> & x)3088 ExprP<float> log(const ExprP<float> &x)
3089 {
3090     return app<Log<Signature<float, float>>>(x);
3091 }
3092 
log2(const ExprP<deFloat16> & x)3093 ExprP<deFloat16> log2(const ExprP<deFloat16> &x)
3094 {
3095     return app<Log2<Signature<deFloat16, deFloat16>>>(x);
3096 }
log(const ExprP<deFloat16> & x)3097 ExprP<deFloat16> log(const ExprP<deFloat16> &x)
3098 {
3099     return app<Log<Signature<deFloat16, deFloat16>>>(x);
3100 }
3101 
log2(const ExprP<double> & x)3102 ExprP<double> log2(const ExprP<double> &x)
3103 {
3104     return app<Log2<Signature<double, double>>>(x);
3105 }
log(const ExprP<double> & x)3106 ExprP<double> log(const ExprP<double> &x)
3107 {
3108     return app<Log<Signature<double, double>>>(x);
3109 }
3110 
// DEFINE_CONSTRUCTOR1: defines a free function NAME(arg0) returning
// ExprP<TRET>, built by applying CLASS to the argument via app<CLASS>.
#define DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0) \
    ExprP<TRET> NAME(const ExprP<T0> &arg0)        \
    {                                              \
        return app<CLASS>(arg0);                   \
    }

// DEFINE_DERIVED1: defines CLASS, a DerivedFunc<Signature<TRET, T0>> that
// - reports the stringified NAME from getName(), and
// - in doExpand() binds its argument expression to the identifier ARG0 and
//   returns EXPANSION (which may refer to ARG0).
// The matching NAME() constructor is then defined via DEFINE_CONSTRUCTOR1.
#define DEFINE_DERIVED1(CLASS, TRET, NAME, T0, ARG0, EXPANSION)                   \
    class CLASS : public DerivedFunc<Signature<TRET, T0>> /* NOLINT(CLASS) */     \
    {                                                                             \
    public:                                                                       \
        string getName(void) const                                                \
        {                                                                         \
            return #NAME;                                                         \
        }                                                                         \
                                                                                  \
    protected:                                                                    \
        ExprP<TRET> doExpand(ExpandContext &, const CLASS::ArgExprs &args_) const \
        {                                                                         \
            const ExprP<T0> &ARG0 = args_.a;                                      \
            return EXPANSION;                                                     \
        }                                                                         \
    };                                                                            \
    DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1 with return and argument type fixed to double.
#define DEFINE_DERIVED_DOUBLE1(CLASS, NAME, ARG0, EXPANSION) \
    DEFINE_DERIVED1(CLASS, double, NAME, double, ARG0, EXPANSION)

// DEFINE_DERIVED1 with return and argument type fixed to float.
#define DEFINE_DERIVED_FLOAT1(CLASS, NAME, ARG0, EXPANSION) DEFINE_DERIVED1(CLASS, float, NAME, float, ARG0, EXPANSION)

// DEFINE_DERIVED1_INPUTRANGE: like DEFINE_DERIVED1, but CLASS additionally
// defines getInputRange(bool), which returns INTERVAL (the is16bit flag is ignored).
#define DEFINE_DERIVED1_INPUTRANGE(CLASS, TRET, NAME, T0, ARG0, EXPANSION, INTERVAL) \
    class CLASS : public DerivedFunc<Signature<TRET, T0>> /* NOLINT(CLASS) */        \
    {                                                                                \
    public:                                                                          \
        string getName(void) const                                                   \
        {                                                                            \
            return #NAME;                                                            \
        }                                                                            \
                                                                                     \
    protected:                                                                       \
        ExprP<TRET> doExpand(ExpandContext &, const CLASS::ArgExprs &args_) const    \
        {                                                                            \
            const ExprP<T0> &ARG0 = args_.a;                                         \
            return EXPANSION;                                                        \
        }                                                                            \
        Interval getInputRange(const bool /*is16bit*/) const                         \
        {                                                                            \
            return INTERVAL;                                                         \
        }                                                                            \
    };                                                                               \
    DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)

// DEFINE_DERIVED1_INPUTRANGE with float return/argument types.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
    DEFINE_DERIVED1_INPUTRANGE(CLASS, float, NAME, float, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1_INPUTRANGE with double return/argument types.
#define DEFINE_DERIVED_DOUBLE1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
    DEFINE_DERIVED1_INPUTRANGE(CLASS, double, NAME, double, ARG0, EXPANSION, INTERVAL)

// DEFINE_DERIVED1 with deFloat16 return/argument types.
#define DEFINE_DERIVED_FLOAT1_16BIT(CLASS, NAME, ARG0, EXPANSION) \
    DEFINE_DERIVED1(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION)

// DEFINE_DERIVED1_INPUTRANGE with deFloat16 return/argument types.
#define DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
    DEFINE_DERIVED1_INPUTRANGE(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION, INTERVAL)
3173 
// DEFINE_CONSTRUCTOR2: defines a free function NAME(arg0, arg1) returning
// ExprP<TRET>, built via app<CLASS>(arg0, arg1).
#define DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)             \
    ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1) \
    {                                                              \
        return app<CLASS>(arg0, arg1);                             \
    }

// DEFINE_CASED_DERIVED2: defines CLASS, a DerivedFunc<Signature<TRET, T0, T1>>
// that
// - reports the stringified NAME from getName(),
// - returns SPIRVCASE from getSpirvCase(), and
// - in doExpand() binds its argument expressions to Arg0/Arg1 and returns
//   EXPANSION.
// The NAME() constructor is then defined via DEFINE_CONSTRUCTOR2.
#define DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRVCASE) \
    class CLASS : public DerivedFunc<Signature<TRET, T0, T1>> /* NOLINT(CLASS) */          \
    {                                                                                      \
    public:                                                                                \
        string getName(void) const                                                         \
        {                                                                                  \
            return #NAME;                                                                  \
        }                                                                                  \
                                                                                           \
        SpirVCaseT getSpirvCase(void) const                                                \
        {                                                                                  \
            return SPIRVCASE;                                                              \
        }                                                                                  \
                                                                                           \
    protected:                                                                             \
        ExprP<TRET> doExpand(ExpandContext &, const ArgExprs &args_) const                 \
        {                                                                                  \
            const ExprP<T0> &Arg0 = args_.a;                                               \
            const ExprP<T1> &Arg1 = args_.b;                                               \
            return EXPANSION;                                                              \
        }                                                                                  \
    };                                                                                     \
    DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)

// DEFINE_CASED_DERIVED2 with SPIRV_CASETYPE_NONE as the SPIR-V case.
#define DEFINE_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION) \
    DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRV_CASETYPE_NONE)

// DEFINE_DERIVED2 with all types fixed to double.
#define DEFINE_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
    DEFINE_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types fixed to float.
#define DEFINE_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
    DEFINE_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION)

// DEFINE_DERIVED2 with all types fixed to deFloat16.
#define DEFINE_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION) \
    DEFINE_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION)

// DEFINE_CASED_DERIVED2 with all types fixed to float.
#define DEFINE_CASED_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
    DEFINE_CASED_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types fixed to deFloat16.
#define DEFINE_CASED_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
    DEFINE_CASED_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION, SPIRVCASE)

// DEFINE_CASED_DERIVED2 with all types fixed to double.
#define DEFINE_CASED_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
    DEFINE_CASED_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION, SPIRVCASE)
3224 
3225 #define DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)                                \
3226     ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1, const ExprP<T2> &arg2) \
3227     {                                                                                     \
3228         return app<CLASS>(arg0, arg1, arg2);                                              \
3229     }
3230 
3231 #define DEFINE_DERIVED3(CLASS, TRET, NAME, T0, ARG0, T1, ARG1, T2, ARG2, EXPANSION)   \
3232     class CLASS : public DerivedFunc<Signature<TRET, T0, T1, T2>> /* NOLINT(CLASS) */ \
3233     {                                                                                 \
3234     public:                                                                           \
3235         string getName(void) const                                                    \
3236         {                                                                             \
3237             return #NAME;                                                             \
3238         }                                                                             \
3239                                                                                       \
3240     protected:                                                                        \
3241         ExprP<TRET> doExpand(ExpandContext &, const ArgExprs &args_) const            \
3242         {                                                                             \
3243             const ExprP<T0> &ARG0 = args_.a;                                          \
3244             const ExprP<T1> &ARG1 = args_.b;                                          \
3245             const ExprP<T2> &ARG2 = args_.c;                                          \
3246             return EXPANSION;                                                         \
3247         }                                                                             \
3248     };                                                                                \
3249     DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)
3250 
// Convenience wrappers around DEFINE_DERIVED3 for functions whose return type and all
// three argument types are the same floating-point type: double, float and deFloat16.
#define DEFINE_DERIVED_DOUBLE3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
    DEFINE_DERIVED3(CLASS, double, NAME, double, ARG0, double, ARG1, double, ARG2, EXPANSION)

#define DEFINE_DERIVED_FLOAT3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
    DEFINE_DERIVED3(CLASS, float, NAME, float, ARG0, float, ARG1, float, ARG2, EXPANSION)

#define DEFINE_DERIVED_FLOAT3_16BIT(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
    DEFINE_DERIVED3(CLASS, deFloat16, NAME, deFloat16, ARG0, deFloat16, ARG1, deFloat16, ARG2, EXPANSION)
3259 
// Defines the free function NAME(arg0, arg1, arg2, arg3), which builds an expression
// applying the function class CLASS to the four argument expressions.
#define DEFINE_CONSTRUCTOR4(CLASS, TRET, NAME, T0, T1, T2, T3)                                                   \
    ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1, const ExprP<T2> &arg2, const ExprP<T3> &arg3) \
    {                                                                                                            \
        return app<CLASS>(arg0, arg1, arg2, arg3);                                                               \
    }
3265 
3266 typedef InverseSqrt<Signature<deFloat16, deFloat16>> InverseSqrt16Bit;
3267 typedef InverseSqrt<Signature<float, float>> InverseSqrt32Bit;
3268 typedef InverseSqrt<Signature<double, double>> InverseSqrt64Bit;
3269 
3270 DEFINE_DERIVED_FLOAT1(Sqrt32Bit, sqrt, x, constant(1.0f) / app<InverseSqrt32Bit>(x))
3271 DEFINE_DERIVED_FLOAT1_16BIT(Sqrt16Bit, sqrt, x, constant((deFloat16)FLOAT16_1_0) / app<InverseSqrt16Bit>(x))
3272 DEFINE_DERIVED_DOUBLE1(Sqrt64Bit, sqrt, x, constant(1.0) / app<InverseSqrt64Bit>(x))
3273 DEFINE_DERIVED_FLOAT2(Pow, pow, x, y, exp2<float>(y *log2(x)))
3274 DEFINE_DERIVED_FLOAT2_16BIT(Pow16, pow, x, y, exp2<deFloat16>(y *log2(x)))
3275 DEFINE_DERIVED_FLOAT1(Radians, radians, d, f32Constant(DE_PI_DOUBLE / 180.0f) * d)
3276 DEFINE_DERIVED_FLOAT1_16BIT(Radians16, radians, d, f16Constant(DE_PI_DOUBLE / 180.0f) * d)
3277 DEFINE_DERIVED_FLOAT1(Degrees, degrees, r, f32Constant(180.0 / DE_PI_DOUBLE) * r)
3278 DEFINE_DERIVED_FLOAT1_16BIT(Degrees16, degrees, r, f16Constant(180.0 / DE_PI_DOUBLE) * r)
3279 
/* Common base class for sin() and cos().
 * Proper parameters for template T:
 *     Signature<float, float>            32bit tests
 *     Signature<deFloat16, deFloat16>    16bit tests
 *     Signature<double, double>          64bit tests
 * loEx / hiEx are the intervals holding the function's minimum and maximum values.
 * Derived classes report the sign of the derivative at an angle through doGetSlope(). */
template <class T>
class TrigFunc : public CFloatFunc1<T>
{
public:
    TrigFunc(const string &name, DoubleFunc1 &func, const Interval &loEx, const Interval &hiEx)
        : CFloatFunc1<T>(name, func)
        , m_loExtremum(loEx)
        , m_hiExtremum(hiEx)
    {
    }

protected:
    // Returns the extremum interval(s) (m_loExtremum and/or m_hiExtremum) that the function
    // may reach between the endpoints of `angle`, or an empty Interval when neither can be
    // reached. Turning points are detected by comparing the slope signs at the two endpoints.
    Interval innerExtrema(const EvalContext &, const Interval &angle) const
    {
        const double lo   = angle.lo();
        const double hi   = angle.hi();
        const int loSlope = doGetSlope(lo);
        const int hiSlope = doGetSlope(hi);

        // Detect the high and low values the function can take between the
        // interval endpoints.
        if (angle.length() >= 2.0 * DE_PI_DOUBLE)
        {
            // The interval is longer than a full cycle, so it must get all possible values.
            return m_hiExtremum | m_loExtremum;
        }
        else if (loSlope == 1 && hiSlope == -1)
        {
            // The slope can change from positive to negative only at the maximum value.
            return m_hiExtremum;
        }
        else if (loSlope == -1 && hiSlope == 1)
        {
            // The slope can change from negative to positive only at the minimum value.
            return m_loExtremum;
        }
        else if (loSlope == hiSlope &&
                 deIntSign(CFloatFunc1<T>::applyExact(hi) - CFloatFunc1<T>::applyExact(lo)) * loSlope == -1)
        {
            // The slope has changed twice between the endpoints, so both extrema are included.
            // (The net change between the endpoints has the opposite sign of the shared
            // endpoint slope, which is only possible after passing a maximum and a minimum.)
            return m_hiExtremum | m_loExtremum;
        }

        return Interval();
    }

    Interval getCodomain(const EvalContext &) const
    {
        // Ensure that result is always within [-1, 1], or NaN (for +-inf)
        return Interval(-1.0, 1.0) | TCU_NAN;
    }

    // Allowed absolute error for result `ret` at input `arg`; defined out of class by an
    // explicit specialization for each supported Signature.
    double precision(const EvalContext &ctx, double ret, double arg) const;

    // Input range used for testing; also defined per Signature by explicit specialization.
    Interval getInputRange(const bool is16bit) const;
    // Sign (-1, 0 or 1) of the function's derivative at `angle`.
    virtual int doGetSlope(double angle) const = 0;

    Interval m_loExtremum;
    Interval m_hiExtremum;
};
3343 
3344 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3345 template <>
getInputRange(const bool is16bit) const3346 Interval TrigFunc<Signature<float, float>>::getInputRange(const bool is16bit) const
3347 {
3348     DE_UNREF(is16bit);
3349     return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3350 }
3351 
3352 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3353 template <>
getInputRange(const bool is16bit) const3354 Interval TrigFunc<Signature<deFloat16, deFloat16>>::getInputRange(const bool is16bit) const
3355 {
3356     DE_UNREF(is16bit);
3357     return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3358 }
3359 
3360 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3361 template <>
getInputRange(const bool is16bit) const3362 Interval TrigFunc<Signature<double, double>>::getInputRange(const bool is16bit) const
3363 {
3364     DE_UNREF(is16bit);
3365     return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3366 }
3367 
3368 template <>
precision(const EvalContext & ctx,double ret,double arg) const3369 double TrigFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double arg) const
3370 {
3371     DE_UNREF(ret);
3372     if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3373     {
3374         if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3375             return deLdExp(1.0, -11);
3376         else
3377         {
3378             // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3379             // 2^-11 at x == pi.
3380             return deLdExp(deAbs(arg), -12);
3381         }
3382     }
3383     else
3384     {
3385         DE_ASSERT(ctx.floatPrecision == glu::PRECISION_MEDIUMP || ctx.floatPrecision == glu::PRECISION_LAST);
3386 
3387         if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3388             return deLdExp(1.0, -7);
3389         else
3390         {
3391             // |x| * 2^-8, slightly larger than 2^-7 at x == pi
3392             return deLdExp(deAbs(arg), -8);
3393         }
3394     }
3395 }
3396 //
3397 /*
3398  * Half tests
3399  * From Spec:
3400  * Absolute error 2^{-7} inside the range [-pi, pi].
3401  */
3402 template <>
precision(const EvalContext & ctx,double ret,double arg) const3403 double TrigFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double arg) const
3404 {
3405     DE_UNREF(ctx);
3406     DE_UNREF(ret);
3407     DE_UNREF(arg);
3408     DE_ASSERT(-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE && ctx.floatPrecision == glu::PRECISION_LAST);
3409     return deLdExp(1.0, -7);
3410 }
3411 
3412 // Spec: "The precision of double-precision instructions is at least that of single precision."
3413 // Lets pick float high precision as a reference.
3414 template <>
precision(const EvalContext & ctx,double ret,double arg) const3415 double TrigFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double arg) const
3416 {
3417     DE_UNREF(ctx);
3418     DE_UNREF(ret);
3419     if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3420         return deLdExp(1.0, -11);
3421     else
3422     {
3423         // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3424         // 2^-11 at x == pi.
3425         return deLdExp(deAbs(arg), -12);
3426     }
3427 }
3428 
3429 /*Proper parameters for template T
3430     Signature<float, float>        32bit tests
3431     Signature<float, deFloat16>    16bit tests*/
3432 template <class T>
3433 class Sin : public TrigFunc<T>
3434 {
3435 public:
Sin(void)3436     Sin(void) : TrigFunc<T>("sin", deSin, -1.0, 1.0)
3437     {
3438     }
3439 
3440 protected:
doGetSlope(double angle) const3441     int doGetSlope(double angle) const
3442     {
3443         return deIntSign(deCos(angle));
3444     }
3445 };
3446 
sin(const ExprP<float> & x)3447 ExprP<float> sin(const ExprP<float> &x)
3448 {
3449     return app<Sin<Signature<float, float>>>(x);
3450 }
sin(const ExprP<deFloat16> & x)3451 ExprP<deFloat16> sin(const ExprP<deFloat16> &x)
3452 {
3453     return app<Sin<Signature<deFloat16, deFloat16>>>(x);
3454 }
sin(const ExprP<double> & x)3455 ExprP<double> sin(const ExprP<double> &x)
3456 {
3457     return app<Sin<Signature<double, double>>>(x);
3458 }
3459 
3460 template <class T>
3461 class Cos : public TrigFunc<T>
3462 {
3463 public:
Cos(void)3464     Cos(void) : TrigFunc<T>("cos", deCos, -1.0, 1.0)
3465     {
3466     }
3467 
3468 protected:
doGetSlope(double angle) const3469     int doGetSlope(double angle) const
3470     {
3471         return -deIntSign(deSin(angle));
3472     }
3473 };
3474 
cos(const ExprP<float> & x)3475 ExprP<float> cos(const ExprP<float> &x)
3476 {
3477     return app<Cos<Signature<float, float>>>(x);
3478 }
cos(const ExprP<deFloat16> & x)3479 ExprP<deFloat16> cos(const ExprP<deFloat16> &x)
3480 {
3481     return app<Cos<Signature<deFloat16, deFloat16>>>(x);
3482 }
cos(const ExprP<double> & x)3483 ExprP<double> cos(const ExprP<double> &x)
3484 {
3485     return app<Cos<Signature<double, double>>>(x);
3486 }
3487 
// tan(x) is expanded as sin(x) * (1 / cos(x)). The extra Interval argument of the
// _INPUTRANGE macro variants presumably restricts test inputs to [-pi, pi] (as for sin/cos).
DEFINE_DERIVED_FLOAT1_INPUTRANGE(Tan, tan, x, sin(x) * (constant(1.0f) / cos(x)),
                                 Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(Tan16Bit, tan, x, sin(x) * (constant((deFloat16)FLOAT16_1_0) / cos(x)),
                                       Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3492 
3493 template <class T>
3494 class ATan : public CFloatFunc1<T>
3495 {
3496 public:
ATan(void)3497     ATan(void) : CFloatFunc1<T>("atan", deAtanOver)
3498     {
3499     }
3500 
3501 protected:
precision(const EvalContext & ctx,double ret,double) const3502     double precision(const EvalContext &ctx, double ret, double) const
3503     {
3504         if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3505             return ctx.format.ulp(ret, 4096.0);
3506         else
3507             return ctx.format.ulp(ret, 5.0);
3508     }
3509 
getCodomain(const EvalContext & ctx) const3510     Interval getCodomain(const EvalContext &ctx) const
3511     {
3512         return ctx.format.roundOut(Interval(-0.5 * DE_PI_DOUBLE, 0.5 * DE_PI_DOUBLE), true);
3513     }
3514 };
3515 
3516 template <class T>
3517 class ATan2 : public CFloatFunc2<T>
3518 {
3519 public:
ATan2(void)3520     ATan2(void) : CFloatFunc2<T>("atan", deAtan2)
3521     {
3522     }
3523 
3524 protected:
innerExtrema(const EvalContext & ctx,const Interval & yi,const Interval & xi) const3525     Interval innerExtrema(const EvalContext &ctx, const Interval &yi, const Interval &xi) const
3526     {
3527         Interval ret;
3528 
3529         if (yi.contains(0.0))
3530         {
3531             if (xi.contains(0.0))
3532                 ret |= TCU_NAN;
3533             if (xi.intersects(Interval(-TCU_INFINITY, 0.0)))
3534                 ret |= ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3535         }
3536 
3537         if (!yi.isFinite(ctx.format.getMaxValue()) || !xi.isFinite(ctx.format.getMaxValue()))
3538         {
3539             // Infinities may not be supported, allow anything, including NaN
3540             ret |= TCU_NAN;
3541         }
3542 
3543         return ret;
3544     }
3545 
precision(const EvalContext & ctx,double ret,double,double) const3546     double precision(const EvalContext &ctx, double ret, double, double) const
3547     {
3548         if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3549             return ctx.format.ulp(ret, 4096.0);
3550         else
3551             return ctx.format.ulp(ret, 5.0);
3552     }
3553 
getCodomain(const EvalContext & ctx) const3554     Interval getCodomain(const EvalContext &ctx) const
3555     {
3556         return ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3557     }
3558 };
3559 
atan2(const ExprP<float> & x,const ExprP<float> & y)3560 ExprP<float> atan2(const ExprP<float> &x, const ExprP<float> &y)
3561 {
3562     return app<ATan2<Signature<float, float, float>>>(x, y);
3563 }
3564 
atan2(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)3565 ExprP<deFloat16> atan2(const ExprP<deFloat16> &x, const ExprP<deFloat16> &y)
3566 {
3567     return app<ATan2<Signature<deFloat16, deFloat16, deFloat16>>>(x, y);
3568 }
3569 
atan2(const ExprP<double> & x,const ExprP<double> & y)3570 ExprP<double> atan2(const ExprP<double> &x, const ExprP<double> &y)
3571 {
3572     return app<ATan2<Signature<double, double, double>>>(x, y);
3573 }
3574 
// Hyperbolic functions expanded from exp(): sinh = (e^x - e^-x) / 2,
// cosh = (e^x + e^-x) / 2, tanh = sinh / cosh.
DEFINE_DERIVED_FLOAT1(Sinh, sinh, x, (exp<float>(x) - exp<float>(-x)) / constant(2.0f))
DEFINE_DERIVED_FLOAT1(Cosh, cosh, x, (exp<float>(x) + exp<float>(-x)) / constant(2.0f))
DEFINE_DERIVED_FLOAT1(Tanh, tanh, x, sinh(x) / cosh(x))

// 16-bit variants of the same hyperbolic expansions.
DEFINE_DERIVED_FLOAT1_16BIT(Sinh16Bit, sinh, x, (exp(x) - exp(-x)) / constant((deFloat16)FLOAT16_2_0))
DEFINE_DERIVED_FLOAT1_16BIT(Cosh16Bit, cosh, x, (exp(x) + exp(-x)) / constant((deFloat16)FLOAT16_2_0))
DEFINE_DERIVED_FLOAT1_16BIT(Tanh16Bit, tanh, x, sinh(x) / cosh(x))

// Inverse trigonometric and hyperbolic functions:
//   asin(x)  = atan2(x, sqrt(1 - x^2)),  acos(x) = atan2(sqrt(1 - x^2), x)  (atan2 takes y first),
//   asinh(x) = log(x + sqrt(x^2 + 1)),
//   acosh(x) = log(x + sqrt(x^2 - 1)), where `alternatives` offers both (x+1)(x-1) and x*x-1
//              as the expansion of x^2 - 1,
//   atanh(x) = 0.5 * log((1 + x) / (1 - x)).
DEFINE_DERIVED_FLOAT1(ASin, asin, x, atan2(x, sqrt(constant(1.0f) - x * x)))
DEFINE_DERIVED_FLOAT1(ACos, acos, x, atan2(sqrt(constant(1.0f) - x * x), x))
DEFINE_DERIVED_FLOAT1(ASinh, asinh, x, log(x + sqrt(x * x + constant(1.0f))))
DEFINE_DERIVED_FLOAT1(ACosh, acosh, x,
                      log(x +
                          sqrt(alternatives((x + constant(1.0f)) * (x - constant(1.0f)), (x * x - constant(1.0f))))))
DEFINE_DERIVED_FLOAT1(ATanh, atanh, x, constant(0.5f) * log((constant(1.0f) + x) / (constant(1.0f) - x)))

// 16-bit variants of the same inverse-function expansions.
DEFINE_DERIVED_FLOAT1_16BIT(ASin16Bit, asin, x, atan2(x, sqrt(constant((deFloat16)FLOAT16_1_0) - x * x)))
DEFINE_DERIVED_FLOAT1_16BIT(ACos16Bit, acos, x, atan2(sqrt(constant((deFloat16)FLOAT16_1_0) - x * x), x))
DEFINE_DERIVED_FLOAT1_16BIT(ASinh16Bit, asinh, x, log(x + sqrt(x * x + constant((deFloat16)FLOAT16_1_0))))
DEFINE_DERIVED_FLOAT1_16BIT(ACosh16Bit, acosh, x,
                            log(x + sqrt(alternatives((x + constant((deFloat16)FLOAT16_1_0)) *
                                                          (x - constant((deFloat16)FLOAT16_1_0)),
                                                      (x * x - constant((deFloat16)FLOAT16_1_0))))))
DEFINE_DERIVED_FLOAT1_16BIT(ATanh16Bit, atanh, x,
                            constant((deFloat16)FLOAT16_0_5) *
                                log((constant((deFloat16)FLOAT16_1_0) + x) / (constant((deFloat16)FLOAT16_1_0) - x)))
3601 
3602 template <typename T>
3603 class GetComponent : public PrimitiveFunc<Signature<typename T::Element, T, int>>
3604 {
3605 public:
3606     typedef typename GetComponent::IRet IRet;
3607 
getName(void) const3608     string getName(void) const
3609     {
3610         return "_getComponent";
3611     }
3612 
print(ostream & os,const BaseArgExprs & args) const3613     void print(ostream &os, const BaseArgExprs &args) const
3614     {
3615         os << *args[0] << "[" << *args[1] << "]";
3616     }
3617 
3618 protected:
doApply(const EvalContext &,const typename GetComponent::IArgs & iargs) const3619     IRet doApply(const EvalContext &, const typename GetComponent::IArgs &iargs) const
3620     {
3621         IRet ret;
3622 
3623         for (int compNdx = 0; compNdx < T::SIZE; ++compNdx)
3624         {
3625             if (iargs.b.contains(compNdx))
3626                 ret = unionIVal<typename T::Element>(ret, iargs.a[compNdx]);
3627         }
3628 
3629         return ret;
3630     }
3631 };
3632 
3633 template <typename T>
getComponent(const ExprP<T> & container,int ndx)3634 ExprP<typename T::Element> getComponent(const ExprP<T> &container, int ndx)
3635 {
3636     DE_ASSERT(0 <= ndx && ndx < T::SIZE);
3637     return app<GetComponent<T>>(container, constant(ndx));
3638 }
3639 
3640 template <typename T>
3641 string vecNamePrefix(void);
3642 template <>
vecNamePrefix(void)3643 string vecNamePrefix<float>(void)
3644 {
3645     return "";
3646 }
3647 template <>
vecNamePrefix(void)3648 string vecNamePrefix<deFloat16>(void)
3649 {
3650     return "";
3651 }
3652 template <>
vecNamePrefix(void)3653 string vecNamePrefix<double>(void)
3654 {
3655     return "d";
3656 }
3657 template <>
vecNamePrefix(void)3658 string vecNamePrefix<int>(void)
3659 {
3660     return "i";
3661 }
3662 template <>
vecNamePrefix(void)3663 string vecNamePrefix<bool>(void)
3664 {
3665     return "b";
3666 }
3667 
3668 template <typename T, int Size>
vecName(void)3669 string vecName(void)
3670 {
3671     return vecNamePrefix<T>() + "vec" + de::toString(Size);
3672 }
3673 
3674 template <typename T, int Size>
3675 class GenVec;
3676 
3677 template <typename T>
3678 class GenVec<T, 1> : public DerivedFunc<Signature<T, T>>
3679 {
3680 public:
3681     typedef typename GenVec<T, 1>::ArgExprs ArgExprs;
3682 
getName(void) const3683     string getName(void) const
3684     {
3685         return "_" + vecName<T, 1>();
3686     }
3687 
3688 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3689     ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
3690     {
3691         return args.a;
3692     }
3693 };
3694 
3695 template <typename T>
3696 class GenVec<T, 2> : public PrimitiveFunc<Signature<Vector<T, 2>, T, T>>
3697 {
3698 public:
3699     typedef typename GenVec::IRet IRet;
3700     typedef typename GenVec::IArgs IArgs;
3701 
getName(void) const3702     string getName(void) const
3703     {
3704         return vecName<T, 2>();
3705     }
3706 
3707 protected:
doApply(const EvalContext &,const IArgs & iargs) const3708     IRet doApply(const EvalContext &, const IArgs &iargs) const
3709     {
3710         return IRet(iargs.a, iargs.b);
3711     }
3712 };
3713 
3714 template <typename T>
3715 class GenVec<T, 3> : public PrimitiveFunc<Signature<Vector<T, 3>, T, T, T>>
3716 {
3717 public:
3718     typedef typename GenVec::IRet IRet;
3719     typedef typename GenVec::IArgs IArgs;
3720 
getName(void) const3721     string getName(void) const
3722     {
3723         return vecName<T, 3>();
3724     }
3725 
3726 protected:
doApply(const EvalContext &,const IArgs & iargs) const3727     IRet doApply(const EvalContext &, const IArgs &iargs) const
3728     {
3729         return IRet(iargs.a, iargs.b, iargs.c);
3730     }
3731 };
3732 
3733 template <typename T>
3734 class GenVec<T, 4> : public PrimitiveFunc<Signature<Vector<T, 4>, T, T, T, T>>
3735 {
3736 public:
3737     typedef typename GenVec::IRet IRet;
3738     typedef typename GenVec::IArgs IArgs;
3739 
getName(void) const3740     string getName(void) const
3741     {
3742         return vecName<T, 4>();
3743     }
3744 
3745 protected:
doApply(const EvalContext &,const IArgs & iargs) const3746     IRet doApply(const EvalContext &, const IArgs &iargs) const
3747     {
3748         return IRet(iargs.a, iargs.b, iargs.c, iargs.d);
3749     }
3750 };
3751 
3752 template <typename T, int Rows, int Columns>
3753 class GenMat;
3754 
3755 template <typename T, int Rows>
3756 class GenMat<T, Rows, 2> : public PrimitiveFunc<Signature<Matrix<T, Rows, 2>, Vector<T, Rows>, Vector<T, Rows>>>
3757 {
3758 public:
3759     typedef typename GenMat::Ret Ret;
3760     typedef typename GenMat::IRet IRet;
3761     typedef typename GenMat::IArgs IArgs;
3762 
getName(void) const3763     string getName(void) const
3764     {
3765         return dataTypeNameOf<Ret>();
3766     }
3767 
3768 protected:
doApply(const EvalContext &,const IArgs & iargs) const3769     IRet doApply(const EvalContext &, const IArgs &iargs) const
3770     {
3771         IRet ret;
3772         ret[0] = iargs.a;
3773         ret[1] = iargs.b;
3774         return ret;
3775     }
3776 };
3777 
3778 template <typename T, int Rows>
3779 class GenMat<T, Rows, 3>
3780     : public PrimitiveFunc<Signature<Matrix<T, Rows, 3>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>>>
3781 {
3782 public:
3783     typedef typename GenMat::Ret Ret;
3784     typedef typename GenMat::IRet IRet;
3785     typedef typename GenMat::IArgs IArgs;
3786 
getName(void) const3787     string getName(void) const
3788     {
3789         return dataTypeNameOf<Ret>();
3790     }
3791 
3792 protected:
doApply(const EvalContext &,const IArgs & iargs) const3793     IRet doApply(const EvalContext &, const IArgs &iargs) const
3794     {
3795         IRet ret;
3796         ret[0] = iargs.a;
3797         ret[1] = iargs.b;
3798         ret[2] = iargs.c;
3799         return ret;
3800     }
3801 };
3802 
3803 template <typename T, int Rows>
3804 class GenMat<T, Rows, 4>
3805     : public PrimitiveFunc<
3806           Signature<Matrix<T, Rows, 4>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>>>
3807 {
3808 public:
3809     typedef typename GenMat::Ret Ret;
3810     typedef typename GenMat::IRet IRet;
3811     typedef typename GenMat::IArgs IArgs;
3812 
getName(void) const3813     string getName(void) const
3814     {
3815         return dataTypeNameOf<Ret>();
3816     }
3817 
3818 protected:
doApply(const EvalContext &,const IArgs & iargs) const3819     IRet doApply(const EvalContext &, const IArgs &iargs) const
3820     {
3821         IRet ret;
3822         ret[0] = iargs.a;
3823         ret[1] = iargs.b;
3824         ret[2] = iargs.c;
3825         ret[3] = iargs.d;
3826         return ret;
3827     }
3828 };
3829 
3830 template <typename T, int Rows>
mat2(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1)3831 ExprP<Matrix<T, Rows, 2>> mat2(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1)
3832 {
3833     return app<GenMat<T, Rows, 2>>(arg0, arg1);
3834 }
3835 
3836 template <typename T, int Rows>
mat3(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2)3837 ExprP<Matrix<T, Rows, 3>> mat3(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1,
3838                                const ExprP<Vector<T, Rows>> &arg2)
3839 {
3840     return app<GenMat<T, Rows, 3>>(arg0, arg1, arg2);
3841 }
3842 
3843 template <typename T, int Rows>
mat4(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2,const ExprP<Vector<T,Rows>> & arg3)3844 ExprP<Matrix<T, Rows, 4>> mat4(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1,
3845                                const ExprP<Vector<T, Rows>> &arg2, const ExprP<Vector<T, Rows>> &arg3)
3846 {
3847     return app<GenMat<T, Rows, 4>>(arg0, arg1, arg2, arg3);
3848 }
3849 
3850 template <typename T, int Rows, int Cols>
3851 class MatNeg : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>>>
3852 {
3853 public:
3854     typedef typename MatNeg::IRet IRet;
3855     typedef typename MatNeg::IArgs IArgs;
3856 
getName(void) const3857     string getName(void) const
3858     {
3859         return "_matNeg";
3860     }
3861 
3862 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3863     void doPrint(ostream &os, const BaseArgExprs &args) const
3864     {
3865         os << "-(" << *args[0] << ")";
3866     }
3867 
doApply(const EvalContext &,const IArgs & iargs) const3868     IRet doApply(const EvalContext &, const IArgs &iargs) const
3869     {
3870         IRet ret;
3871 
3872         for (int col = 0; col < Cols; ++col)
3873         {
3874             for (int row = 0; row < Rows; ++row)
3875                 ret[col][row] = -iargs.a[col][row];
3876         }
3877 
3878         return ret;
3879     }
3880 };
3881 
3882 template <typename T, typename Sig>
3883 class CompWiseFunc : public PrimitiveFunc<Sig>
3884 {
3885 public:
3886     typedef Func<Signature<T, T, T>> ScalarFunc;
3887 
getName(void) const3888     string getName(void) const
3889     {
3890         return doGetScalarFunc().getName();
3891     }
3892 
3893 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3894     void doPrint(ostream &os, const BaseArgExprs &args) const
3895     {
3896         doGetScalarFunc().print(os, args);
3897     }
3898 
3899     virtual const ScalarFunc &doGetScalarFunc(void) const = 0;
3900 };
3901 
3902 template <typename T, int Rows, int Cols>
3903 class CompMatFuncBase
3904     : public CompWiseFunc<T, Signature<Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>>>
3905 {
3906 public:
3907     typedef typename CompMatFuncBase::IRet IRet;
3908     typedef typename CompMatFuncBase::IArgs IArgs;
3909 
3910 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const3911     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
3912     {
3913         IRet ret;
3914 
3915         for (int col = 0; col < Cols; ++col)
3916         {
3917             for (int row = 0; row < Rows; ++row)
3918                 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b[col][row]);
3919         }
3920 
3921         return ret;
3922     }
3923 };
3924 
3925 template <typename F, typename T, int Rows, int Cols>
3926 class CompMatFunc : public CompMatFuncBase<T, Rows, Cols>
3927 {
3928 protected:
doGetScalarFunc(void) const3929     const typename CompMatFunc::ScalarFunc &doGetScalarFunc(void) const
3930     {
3931         return instance<F>();
3932     }
3933 };
3934 
3935 template <class T>
3936 class ScalarMatrixCompMult : public Mul<Signature<T, T, T>>
3937 {
3938 public:
getName(void) const3939     string getName(void) const
3940     {
3941         return "matrixCompMult";
3942     }
3943 
doPrint(ostream & os,const BaseArgExprs & args) const3944     void doPrint(ostream &os, const BaseArgExprs &args) const
3945     {
3946         Func<Signature<T, T, T>>::doPrint(os, args);
3947     }
3948 };
3949 
3950 template <int Rows, int Cols, class T>
3951 class MatrixCompMult : public CompMatFunc<ScalarMatrixCompMult<T>, T, Rows, Cols>
3952 {
3953 };
3954 
3955 template <int Rows, int Cols>
3956 class ScalarMatFuncBase
3957     : public CompWiseFunc<float, Signature<Matrix<float, Rows, Cols>, Matrix<float, Rows, Cols>, float>>
3958 {
3959 public:
3960     typedef typename ScalarMatFuncBase::IRet IRet;
3961     typedef typename ScalarMatFuncBase::IArgs IArgs;
3962 
3963 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const3964     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
3965     {
3966         IRet ret;
3967 
3968         for (int col = 0; col < Cols; ++col)
3969         {
3970             for (int row = 0; row < Rows; ++row)
3971                 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b);
3972         }
3973 
3974         return ret;
3975     }
3976 };
3977 
3978 template <typename F, int Rows, int Cols>
3979 class ScalarMatFunc : public ScalarMatFuncBase<Rows, Cols>
3980 {
3981 protected:
doGetScalarFunc(void) const3982     const typename ScalarMatFunc::ScalarFunc &doGetScalarFunc(void) const
3983     {
3984         return instance<F>();
3985     }
3986 };
3987 
3988 template <typename T, int Size>
3989 struct GenXType;
3990 
3991 template <typename T>
3992 struct GenXType<T, 1>
3993 {
genXTypevkt::shaderexecutor::Functions::GenXType3994     static ExprP<T> genXType(const ExprP<T> &x)
3995     {
3996         return x;
3997     }
3998 };
3999 
4000 template <typename T>
4001 struct GenXType<T, 2>
4002 {
genXTypevkt::shaderexecutor::Functions::GenXType4003     static ExprP<Vector<T, 2>> genXType(const ExprP<T> &x)
4004     {
4005         return app<GenVec<T, 2>>(x, x);
4006     }
4007 };
4008 
4009 template <typename T>
4010 struct GenXType<T, 3>
4011 {
genXTypevkt::shaderexecutor::Functions::GenXType4012     static ExprP<Vector<T, 3>> genXType(const ExprP<T> &x)
4013     {
4014         return app<GenVec<T, 3>>(x, x, x);
4015     }
4016 };
4017 
4018 template <typename T>
4019 struct GenXType<T, 4>
4020 {
genXTypevkt::shaderexecutor::Functions::GenXType4021     static ExprP<Vector<T, 4>> genXType(const ExprP<T> &x)
4022     {
4023         return app<GenVec<T, 4>>(x, x, x, x);
4024     }
4025 };
4026 
4027 //! Returns an expression of vector of size `Size` (or scalar if Size == 1),
4028 //! with each element initialized with the expression `x`.
4029 template <typename T, int Size>
genXType(const ExprP<T> & x)4030 ExprP<typename ContainerOf<T, Size>::Container> genXType(const ExprP<T> &x)
4031 {
4032     return GenXType<T, Size>::genXType(x);
4033 }
4034 
4035 typedef GenVec<float, 2> FloatVec2;
4036 DEFINE_CONSTRUCTOR2(FloatVec2, Vec2, vec2, float, float)
4037 
4038 typedef GenVec<deFloat16, 2> FloatVec2_16bit;
4039 DEFINE_CONSTRUCTOR2(FloatVec2_16bit, Vec2_16Bit, vec2, deFloat16, deFloat16)
4040 
4041 typedef GenVec<double, 2> DoubleVec2;
4042 DEFINE_CONSTRUCTOR2(DoubleVec2, Vec2_64Bit, vec2, double, double)
4043 
4044 typedef GenVec<float, 3> FloatVec3;
4045 DEFINE_CONSTRUCTOR3(FloatVec3, Vec3, vec3, float, float, float)
4046 
4047 typedef GenVec<deFloat16, 3> FloatVec3_16bit;
4048 DEFINE_CONSTRUCTOR3(FloatVec3_16bit, Vec3_16Bit, vec3, deFloat16, deFloat16, deFloat16)
4049 
4050 typedef GenVec<double, 3> DoubleVec3;
4051 DEFINE_CONSTRUCTOR3(DoubleVec3, Vec3_64Bit, vec3, double, double, double)
4052 
4053 typedef GenVec<float, 4> FloatVec4;
4054 DEFINE_CONSTRUCTOR4(FloatVec4, Vec4, vec4, float, float, float, float)
4055 
4056 typedef GenVec<deFloat16, 4> FloatVec4_16bit;
4057 DEFINE_CONSTRUCTOR4(FloatVec4_16bit, Vec4_16Bit, vec4, deFloat16, deFloat16, deFloat16, deFloat16)
4058 
4059 typedef GenVec<double, 4> DoubleVec4;
4060 DEFINE_CONSTRUCTOR4(DoubleVec4, Vec4_64Bit, vec4, double, double, double, double)
4061 
4062 template <class T>
4063 const ExprP<T> getConstZero(void);
4064 template <class T>
4065 const ExprP<T> getConstOne(void);
4066 template <class T>
4067 const ExprP<T> getConstTwo(void);
4068 
4069 template <>
getConstZero(void)4070 const ExprP<float> getConstZero<float>(void)
4071 {
4072     return constant(0.0f);
4073 }
4074 
4075 template <>
getConstZero(void)4076 const ExprP<deFloat16> getConstZero<deFloat16>(void)
4077 {
4078     return constant((deFloat16)FLOAT16_0_0);
4079 }
4080 
4081 template <>
getConstZero(void)4082 const ExprP<double> getConstZero<double>(void)
4083 {
4084     return constant(0.0);
4085 }
4086 
4087 template <>
getConstOne(void)4088 const ExprP<float> getConstOne<float>(void)
4089 {
4090     return constant(1.0f);
4091 }
4092 
4093 template <>
getConstOne(void)4094 const ExprP<deFloat16> getConstOne<deFloat16>(void)
4095 {
4096     return constant((deFloat16)FLOAT16_1_0);
4097 }
4098 
4099 template <>
getConstOne(void)4100 const ExprP<double> getConstOne<double>(void)
4101 {
4102     return constant(1.0);
4103 }
4104 
4105 template <>
getConstTwo(void)4106 const ExprP<float> getConstTwo<float>(void)
4107 {
4108     return constant(2.0f);
4109 }
4110 
4111 template <>
getConstTwo(void)4112 const ExprP<deFloat16> getConstTwo<deFloat16>(void)
4113 {
4114     return constant((deFloat16)FLOAT16_2_0);
4115 }
4116 
4117 template <>
getConstTwo(void)4118 const ExprP<double> getConstTwo<double>(void)
4119 {
4120     return constant(2.0);
4121 }
4122 
//! dot(a, b) for Size-component vectors, expanded into per-component products and sums.
//! The expansion is an `alternatives` of several summation orders of the products,
//! presumably so that an implementation summing in any order is accepted.
template <int Size, class T>
class Dot : public DerivedFunc<Signature<T, Vector<T, Size>, Vector<T, Size>>>
{
public:
    typedef typename Dot::ArgExprs ArgExprs;

    string getName(void) const
    {
        return "dot";
    }

protected:
    ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
    {
        ExprP<T> op[Size];
        // Precompute all products.
        for (int ndx = 0; ndx < Size; ++ndx)
            op[ndx] = args.a[ndx] * args.b[ndx];

        int idx[Size];
        //Prepare an array of indices.
        for (int ndx = 0; ndx < Size; ++ndx)
            idx[ndx] = ndx;

        ExprP<T> res = op[0];
        // Compute the first dot alternative: SUM(a[i]*b[i]), i = 0 .. Size-1
        for (int ndx = 1; ndx < Size; ++ndx)
            res = res + op[ndx];

        // Generate all permutations of indices and
        // using a permutation compute a dot alternative.
        // Generates all possible variants of summation of products in the dot product expansion expression.
        // Each alternative is a left-to-right sum seeded with a zero constant, so only the order
        // of the addends varies, not the shape of the addition tree. idx starts sorted, so
        // std::next_permutation visits all Size! orders before returning false.
        // NOTE(review): std::next_permutation is declared in <algorithm>, which is not among the
        // standard headers included directly at the top of this file; presumably it arrives
        // transitively - confirm.
        do
        {
            ExprP<T> alt = getConstZero<T>();
            for (int ndx = 0; ndx < Size; ++ndx)
                alt = alt + op[idx[ndx]];
            res = alternatives(res, alt);
        } while (std::next_permutation(idx, idx + Size));

        return res;
    }
};
4166 
4167 template <class T>
4168 class Dot<1, T> : public DerivedFunc<Signature<T, T, T>>
4169 {
4170 public:
4171     typedef typename DerivedFunc<Signature<T, T, T>>::ArgExprs TArgExprs;
4172 
getName(void) const4173     string getName(void) const
4174     {
4175         return "dot";
4176     }
4177 
doExpand(ExpandContext &,const TArgExprs & args) const4178     ExprP<T> doExpand(ExpandContext &, const TArgExprs &args) const
4179     {
4180         return args.a * args.b;
4181     }
4182 };
4183 
4184 template <int Size>
dot(const ExprP<Vector<deFloat16,Size>> & x,const ExprP<Vector<deFloat16,Size>> & y)4185 ExprP<deFloat16> dot(const ExprP<Vector<deFloat16, Size>> &x, const ExprP<Vector<deFloat16, Size>> &y)
4186 {
4187     return app<Dot<Size, deFloat16>>(x, y);
4188 }
4189 
dot(const ExprP<deFloat16> & x,const ExprP<deFloat16> & y)4190 ExprP<deFloat16> dot(const ExprP<deFloat16> &x, const ExprP<deFloat16> &y)
4191 {
4192     return app<Dot<1, deFloat16>>(x, y);
4193 }
4194 
4195 template <int Size>
dot(const ExprP<Vector<float,Size>> & x,const ExprP<Vector<float,Size>> & y)4196 ExprP<float> dot(const ExprP<Vector<float, Size>> &x, const ExprP<Vector<float, Size>> &y)
4197 {
4198     return app<Dot<Size, float>>(x, y);
4199 }
4200 
dot(const ExprP<float> & x,const ExprP<float> & y)4201 ExprP<float> dot(const ExprP<float> &x, const ExprP<float> &y)
4202 {
4203     return app<Dot<1, float>>(x, y);
4204 }
4205 
4206 template <int Size>
dot(const ExprP<Vector<double,Size>> & x,const ExprP<Vector<double,Size>> & y)4207 ExprP<double> dot(const ExprP<Vector<double, Size>> &x, const ExprP<Vector<double, Size>> &y)
4208 {
4209     return app<Dot<Size, double>>(x, y);
4210 }
4211 
dot(const ExprP<double> & x,const ExprP<double> & y)4212 ExprP<double> dot(const ExprP<double> &x, const ExprP<double> &y)
4213 {
4214     return app<Dot<1, double>>(x, y);
4215 }
4216 
4217 template <int Size, class T>
4218 class Length : public DerivedFunc<Signature<T, typename ContainerOf<T, Size>::Container>>
4219 {
4220 public:
4221     typedef typename Length::ArgExprs ArgExprs;
4222 
getName(void) const4223     string getName(void) const
4224     {
4225         return "length";
4226     }
4227 
4228 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4229     ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
4230     {
4231         return sqrt(dot(args.a, args.a));
4232     }
4233 };
4234 
// length() helpers. The first form applies Length<1, T> to a scalar expression; the second
// takes the component count explicitly, e.g. length<Size, T, Ret>(v) as used by Distance.
// In both, the returned ExprP<T> is converted to ExprP<TRet>, so TRet is expected to equal T.
template <class T, class TRet>
ExprP<TRet> length(const ExprP<T> &x)
{
    return app<Length<1, T>>(x);
}

template <int Size, class T, class TRet>
ExprP<TRet> length(const ExprP<typename ContainerOf<T, Size>::Container> &x)
{
    return app<Length<Size, T>>(x);
}
4246 
// distance(a, b) expanded as length(a - b).
template <int Size, class T>
class Distance : public DerivedFunc<
                     Signature<T, typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
{
public:
    typedef typename Distance::Ret Ret;
    typedef typename Distance::ArgExprs ArgExprs;

    string getName(void) const
    {
        return "distance";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
    {
        return length<Size, T, Ret>(args.a - args.b);
    }
};
4266 
4267 // cross
4268 
// cross(a, b) for 32-bit vec3, expanded component-wise:
// (a1*b2 - b1*a2, a2*b0 - b2*a0, a0*b1 - b0*a1). Cross16Bit and Cross64Bit below repeat this
// expansion for the other precisions.
class Cross : public DerivedFunc<Signature<Vec3, Vec3, Vec3>>
{
public:
    string getName(void) const
    {
        return "cross";
    }

protected:
    ExprP<Vec3> doExpand(ExpandContext &, const ArgExprs &x) const
    {
        return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
                    x.a[0] * x.b[1] - x.b[0] * x.a[1]);
    }
};
4284 
// cross(a, b) for 16-bit vec3, expanded component-wise:
// (a1*b2 - b1*a2, a2*b0 - b2*a0, a0*b1 - b0*a1).
class Cross16Bit : public DerivedFunc<Signature<Vec3_16Bit, Vec3_16Bit, Vec3_16Bit>>
{
public:
    string getName(void) const
    {
        return "cross";
    }

protected:
    ExprP<Vec3_16Bit> doExpand(ExpandContext &, const ArgExprs &x) const
    {
        return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
                    x.a[0] * x.b[1] - x.b[0] * x.a[1]);
    }
};
4300 
// cross(a, b) for 64-bit vec3, expanded component-wise:
// (a1*b2 - b1*a2, a2*b0 - b2*a0, a0*b1 - b0*a1).
class Cross64Bit : public DerivedFunc<Signature<Vec3_64Bit, Vec3_64Bit, Vec3_64Bit>>
{
public:
    string getName(void) const
    {
        return "cross";
    }

protected:
    ExprP<Vec3_64Bit> doExpand(ExpandContext &, const ArgExprs &x) const
    {
        return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
                    x.a[0] * x.b[1] - x.b[0] * x.a[1]);
    }
};
4316 
// Expression-level cross() helpers for each precision, built from the Cross* classes above
// via the file's DEFINE_CONSTRUCTOR2 macro.
DEFINE_CONSTRUCTOR2(Cross, Vec3, cross, Vec3, Vec3)
DEFINE_CONSTRUCTOR2(Cross16Bit, Vec3_16Bit, cross, Vec3_16Bit, Vec3_16Bit)
DEFINE_CONSTRUCTOR2(Cross64Bit, Vec3_64Bit, cross, Vec3_64Bit, Vec3_64Bit)
4320 
4321 template <int Size, class T>
4322 class Normalize
4323     : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
4324 {
4325 public:
4326     typedef typename Normalize::Ret Ret;
4327     typedef typename Normalize::ArgExprs ArgExprs;
4328 
getName(void) const4329     string getName(void) const
4330     {
4331         return "normalize";
4332     }
4333 
4334 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4335     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
4336     {
4337         return args.a * app<InverseSqrt<Signature<T, T>>>(dot(args.a, args.a));
4338     }
4339 };
4340 
4341 template <int Size, class T>
4342 class FaceForward
4343     : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
4344                                    typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
4345 {
4346 public:
4347     typedef typename FaceForward::Ret Ret;
4348     typedef typename FaceForward::ArgExprs ArgExprs;
4349 
getName(void) const4350     string getName(void) const
4351     {
4352         return "faceforward";
4353     }
4354 
4355 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4356     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
4357     {
4358         return cond(dot(args.c, args.b) < getConstZero<T>(), args.a, -args.a);
4359     }
4360 };
4361 
// reflect(I, N) = I - 2 * dot(N, I) * N for Size-component vectors.
// Implementations may associate the multiplication by 2 differently, so several
// algebraically equal forms are accepted via nested alternatives(). dot(N, I) is bound to a
// named sub-expression ("dotNI") for the forms that reuse it.
template <int Size, class T>
class Reflect
    : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
                                   typename ContainerOf<T, Size>::Container>>
{
public:
    typedef typename Reflect::Ret Ret;
    typedef typename Reflect::Arg0 Arg0;
    typedef typename Reflect::Arg1 Arg1;
    typedef typename Reflect::ArgExprs ArgExprs;

    string getName(void) const
    {
        return "reflect";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        const ExprP<Arg0> &i = args.a;
        const ExprP<Arg1> &n = args.b;
        const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));

        // Accepted forms: (N*dotNI)*2, N*(dotNI*2), N*dot(I*2, N), N*dot(I, N*2).
        return i - alternatives(
                       (n * dotNI) * getConstTwo<T>(),
                       alternatives(n * (dotNI * getConstTwo<T>()),
                                    alternatives(n * dot(i * getConstTwo<T>(), n), n * dot(i, n * getConstTwo<T>()))));
    }
};
4391 
// Scalar specialization of reflect(I, N) = I - 2 * N * I * N. It accepts the same
// association variants as the vector form, plus dot(N*N, I*2), which only makes sense when
// both operands are scalars.
template <class T>
class Reflect<1, T> : public DerivedFunc<Signature<T, T, T>>
{
public:
    typedef typename Reflect::Ret Ret;
    typedef typename Reflect::Arg0 Arg0;
    typedef typename Reflect::Arg1 Arg1;
    typedef typename Reflect::ArgExprs ArgExprs;

    string getName(void) const
    {
        return "reflect";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        const ExprP<Arg0> &i = args.a;
        const ExprP<Arg1> &n = args.b;
        const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));

        return i - alternatives((n * dotNI) * getConstTwo<T>(),
                                alternatives(n * (dotNI * getConstTwo<T>()),
                                             alternatives(n * dot(i * getConstTwo<T>(), n),
                                                          alternatives(n * dot(i, n * getConstTwo<T>()),
                                                                       dot(n * n, i * getConstTwo<T>())))));
    }
};
4420 
// refract(I, N, eta):
//   k = 1 - eta^2 * (1 - dot(N, I)^2)
//   result = 0-vector if k < 0, else eta * I - (eta * dot(N, I) + sqrt(k)) * N
// dot(N, I) and k are bound to named sub-expressions so they are shared between uses.
template <int Size, class T>
class Refract
    : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
                                   typename ContainerOf<T, Size>::Container, T>>
{
public:
    typedef typename Refract::Ret Ret;
    typedef typename Refract::Arg0 Arg0;
    typedef typename Refract::Arg1 Arg1;
    typedef typename Refract::Arg2 Arg2;
    typedef typename Refract::ArgExprs ArgExprs;

    string getName(void) const
    {
        return "refract";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        const ExprP<Arg0> &i   = args.a;
        const ExprP<Arg1> &n   = args.b;
        const ExprP<Arg2> &eta = args.c;
        const ExprP<T> dotNI   = bindExpression("dotNI", ctx, dot(n, i));
        const ExprP<T> k = bindExpression("k", ctx, getConstOne<T>() - eta * eta * (getConstOne<T>() - dotNI * dotNI));
        return cond(k < getConstZero<T>(), genXType<T, Size>(getConstZero<T>()), i * eta - n * (eta * dotNI + sqrt(k)));
    }
};
4449 
4450 template <class T>
4451 class PreciseFunc1 : public CFloatFunc1<T>
4452 {
4453 public:
PreciseFunc1(const string & name,DoubleFunc1 & func)4454     PreciseFunc1(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
4455     {
4456     }
4457 
4458 protected:
precision(const EvalContext &,double,double) const4459     double precision(const EvalContext &, double, double) const
4460     {
4461         return 0.0;
4462     }
4463 };
4464 
4465 template <class T>
4466 class Abs : public PreciseFunc1<T>
4467 {
4468 public:
Abs(void)4469     Abs(void) : PreciseFunc1<T>("abs", deAbs)
4470     {
4471     }
4472 };
4473 
4474 template <class T>
4475 class Sign : public PreciseFunc1<T>
4476 {
4477 public:
Sign(void)4478     Sign(void) : PreciseFunc1<T>("sign", deSign)
4479     {
4480     }
4481 };
4482 
4483 template <class T>
4484 class Floor : public PreciseFunc1<T>
4485 {
4486 public:
Floor(void)4487     Floor(void) : PreciseFunc1<T>("floor", deFloor)
4488     {
4489     }
4490 };
4491 
4492 template <class T>
4493 class Trunc : public PreciseFunc1<T>
4494 {
4495 public:
Trunc(void)4496     Trunc(void) : PreciseFunc1<T>("trunc", deTrunc)
4497     {
4498     }
4499 };
4500 
4501 template <class T>
4502 class Round : public FloatFunc1<T>
4503 {
4504 public:
getName(void) const4505     string getName(void) const
4506     {
4507         return "round";
4508     }
4509 
4510 protected:
applyPoint(const EvalContext &,double x) const4511     Interval applyPoint(const EvalContext &, double x) const
4512     {
4513         double truncated   = 0.0;
4514         const double fract = deModf(x, &truncated);
4515         Interval ret;
4516 
4517         if (fabs(fract) <= 0.5)
4518             ret |= truncated;
4519         if (fabs(fract) >= 0.5)
4520             ret |= truncated + deSign(fract);
4521 
4522         return ret;
4523     }
4524 
precision(const EvalContext &,double,double) const4525     double precision(const EvalContext &, double, double) const
4526     {
4527         return 0.0;
4528     }
4529 };
4530 
4531 template <class T>
4532 class RoundEven : public PreciseFunc1<T>
4533 {
4534 public:
RoundEven(void)4535     RoundEven(void) : PreciseFunc1<T>("roundEven", deRoundEven)
4536     {
4537     }
4538 };
4539 
4540 template <class T>
4541 class Ceil : public PreciseFunc1<T>
4542 {
4543 public:
Ceil(void)4544     Ceil(void) : PreciseFunc1<T>("ceil", deCeil)
4545     {
4546     }
4547 };
4548 
4549 typedef Floor<Signature<float, float>> Floor32Bit;
4550 typedef Floor<Signature<deFloat16, deFloat16>> Floor16Bit;
4551 typedef Floor<Signature<double, double>> Floor64Bit;
4552 
4553 typedef Trunc<Signature<float, float>> Trunc32Bit;
4554 typedef Trunc<Signature<deFloat16, deFloat16>> Trunc16Bit;
4555 typedef Trunc<Signature<double, double>> Trunc64Bit;
4556 
4557 typedef Trunc<Signature<float, float>> Trunc32Bit;
4558 typedef Trunc<Signature<deFloat16, deFloat16>> Trunc16Bit;
4559 
// fract(x) = x - floor(x), derived for each precision.
DEFINE_DERIVED_FLOAT1(Fract, fract, x, x - app<Floor32Bit>(x))
DEFINE_DERIVED_FLOAT1_16BIT(Fract16Bit, fract, x, x - app<Floor16Bit>(x))
DEFINE_DERIVED_DOUBLE1(Fract64Bit, fract, x, x - app<Floor64Bit>(x))
4563 
4564 template <class T>
4565 class PreciseFunc2 : public CFloatFunc2<T>
4566 {
4567 public:
PreciseFunc2(const string & name,DoubleFunc2 & func)4568     PreciseFunc2(const string &name, DoubleFunc2 &func) : CFloatFunc2<T>(name, func)
4569     {
4570     }
4571 
4572 protected:
precision(const EvalContext &,double,double,double) const4573     double precision(const EvalContext &, double, double, double) const
4574     {
4575         return 0.0;
4576     }
4577 };
4578 
// mod(x, y) = x - y * floor(x / y) (GLSL mod, result takes the sign of y), derived per precision.
DEFINE_DERIVED_FLOAT2(Mod32Bit, mod, x, y, x - y * app<Floor32Bit>(x / y))
DEFINE_DERIVED_FLOAT2_16BIT(Mod16Bit, mod, x, y, x - y * app<Floor16Bit>(x / y))
DEFINE_DERIVED_DOUBLE2(Mod64Bit, mod, x, y, x - y * app<Floor64Bit>(x / y))

// frem(x, y) = x - y * trunc(x / y) (result takes the sign of x), tagged as the SPIR-V FRem case.
DEFINE_CASED_DERIVED_FLOAT2(FRem32Bit, frem, x, y, x - y * app<Trunc32Bit>(x / y), SPIRV_CASETYPE_FREM)
DEFINE_CASED_DERIVED_FLOAT2_16BIT(FRem16Bit, frem, x, y, x - y * app<Trunc16Bit>(x / y), SPIRV_CASETYPE_FREM)
DEFINE_CASED_DERIVED_DOUBLE2(FRem64Bit, frem, x, y, x - y * app<Trunc64Bit>(x / y), SPIRV_CASETYPE_FREM)
4586 
// modf(x, out whole): returns the fractional part of x and writes the whole part into the
// second argument, which is an output parameter (see getOutParamIndex()).
template <class T>
class Modf : public PrimitiveFunc<T>
{
public:
    typedef typename Modf<T>::IArgs TIArgs;
    typedef typename Modf<T>::IRet TIRet;
    string getName(void) const
    {
        return "modf";
    }

protected:
    TIRet doApply(const EvalContext &ctx, const TIArgs &iargs) const
    {
        Interval fracIV;
        // The out parameter arrives as a const interval; it is written through const_cast.
        Interval &wholeIV = const_cast<Interval &>(iargs.b);
        double intPart    = 0;

        // NOTE(review): modf is not monotone across integer boundaries, so applying the
        // "monotone" macro is only tight when iargs.a does not span an integer — presumably
        // inputs are point intervals here; confirm before relying on it for wide intervals.
        TCU_INTERVAL_APPLY_MONOTONE1(fracIV, x, iargs.a, frac, frac = deModf(x, &intPart));
        TCU_INTERVAL_APPLY_MONOTONE1(wholeIV, x, iargs.a, whole, deModf(x, &intPart); whole = intPart);

        if (!iargs.a.isFinite(ctx.format.getMaxValue()))
        {
            // Behavior on modf(Inf) not well-defined, allow anything as a fractional part
            // See Khronos bug 13907
            fracIV |= TCU_NAN;
        }

        return fracIV;
    }

    // Index of the argument that receives the whole part.
    int getOutParamIndex(void) const
    {
        return 1;
    }
};
typedef Modf<Signature<float, float, float>> Modf32Bit;
typedef Modf<Signature<deFloat16, deFloat16, deFloat16>> Modf16Bit;
typedef Modf<Signature<double, double, double>> Modf64Bit;
4626 
// Same evaluation as Modf, but reported under the name "modfstruct" and tagged with the
// SPIRV_CASETYPE_MODFSTRUCT case (presumably exercising the struct-returning ModfStruct form).
template <class T>
class ModfStruct : public Modf<T>
{
public:
    virtual string getName(void) const
    {
        return "modfstruct";
    }
    virtual SpirVCaseT getSpirvCase(void) const
    {
        return SPIRV_CASETYPE_MODFSTRUCT;
    }
};
typedef ModfStruct<Signature<float, float, float>> ModfStruct32Bit;
typedef ModfStruct<Signature<deFloat16, deFloat16, deFloat16>> ModfStruct16Bit;
typedef ModfStruct<Signature<double, double, double>> ModfStruct64Bit;
4643 
4644 template <class T>
4645 class Min : public PreciseFunc2<T>
4646 {
4647 public:
Min(void)4648     Min(void) : PreciseFunc2<T>("min", deMin)
4649     {
4650     }
4651 };
4652 template <class T>
4653 class Max : public PreciseFunc2<T>
4654 {
4655 public:
Max(void)4656     Max(void) : PreciseFunc2<T>("max", deMax)
4657     {
4658     }
4659 };
4660 
4661 template <class T>
4662 class Clamp : public FloatFunc3<T>
4663 {
4664 public:
getName(void) const4665     string getName(void) const
4666     {
4667         return "clamp";
4668     }
4669 
applyExact(double x,double minVal,double maxVal) const4670     double applyExact(double x, double minVal, double maxVal) const
4671     {
4672         return de::min(de::max(x, minVal), maxVal);
4673     }
4674 
precision(const EvalContext &,double,double,double minVal,double maxVal) const4675     double precision(const EvalContext &, double, double, double minVal, double maxVal) const
4676     {
4677         return minVal > maxVal ? TCU_NAN : 0.0;
4678     }
4679 };
4680 
clamp(const ExprP<deFloat16> & x,const ExprP<deFloat16> & minVal,const ExprP<deFloat16> & maxVal)4681 ExprP<deFloat16> clamp(const ExprP<deFloat16> &x, const ExprP<deFloat16> &minVal, const ExprP<deFloat16> &maxVal)
4682 {
4683     return app<Clamp<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(x, minVal, maxVal);
4684 }
4685 
clamp(const ExprP<float> & x,const ExprP<float> & minVal,const ExprP<float> & maxVal)4686 ExprP<float> clamp(const ExprP<float> &x, const ExprP<float> &minVal, const ExprP<float> &maxVal)
4687 {
4688     return app<Clamp<Signature<float, float, float, float>>>(x, minVal, maxVal);
4689 }
4690 
clamp(const ExprP<double> & x,const ExprP<double> & minVal,const ExprP<double> & maxVal)4691 ExprP<double> clamp(const ExprP<double> &x, const ExprP<double> &minVal, const ExprP<double> &maxVal)
4692 {
4693     return app<Clamp<Signature<double, double, double, double>>>(x, minVal, maxVal);
4694 }
4695 
4696 template <class T>
4697 class NanIfGreaterOrEqual : public FloatFunc2<T>
4698 {
4699 public:
getName(void) const4700     string getName(void) const
4701     {
4702         return "nanIfGreaterOrEqual";
4703     }
4704 
applyExact(double edge0,double edge1) const4705     double applyExact(double edge0, double edge1) const
4706     {
4707         return (edge0 >= edge1) ? TCU_NAN : 0.0;
4708     }
4709 
precision(const EvalContext &,double,double edge0,double edge1) const4710     double precision(const EvalContext &, double, double edge0, double edge1) const
4711     {
4712         return (edge0 >= edge1) ? TCU_NAN : 0.0;
4713     }
4714 };
4715 
nanIfGreaterOrEqual(const ExprP<deFloat16> & edge0,const ExprP<deFloat16> & edge1)4716 ExprP<deFloat16> nanIfGreaterOrEqual(const ExprP<deFloat16> &edge0, const ExprP<deFloat16> &edge1)
4717 {
4718     return app<NanIfGreaterOrEqual<Signature<deFloat16, deFloat16, deFloat16>>>(edge0, edge1);
4719 }
4720 
nanIfGreaterOrEqual(const ExprP<float> & edge0,const ExprP<float> & edge1)4721 ExprP<float> nanIfGreaterOrEqual(const ExprP<float> &edge0, const ExprP<float> &edge1)
4722 {
4723     return app<NanIfGreaterOrEqual<Signature<float, float, float>>>(edge0, edge1);
4724 }
4725 
nanIfGreaterOrEqual(const ExprP<double> & edge0,const ExprP<double> & edge1)4726 ExprP<double> nanIfGreaterOrEqual(const ExprP<double> &edge0, const ExprP<double> &edge1)
4727 {
4728     return app<NanIfGreaterOrEqual<Signature<double, double, double>>>(edge0, edge1);
4729 }
4730 
// mix(x, y, a): both common formulations, x * (1 - a) + y * a and x + (y - x) * a, are
// accepted via alternatives(), for each precision.
DEFINE_DERIVED_FLOAT3(Mix, mix, x, y, a, alternatives((x * (constant(1.0f) - a)) + y * a, x + (y - x) * a))

DEFINE_DERIVED_FLOAT3_16BIT(Mix16Bit, mix, x, y, a,
                            alternatives((x * (constant((deFloat16)FLOAT16_1_0) - a)) + y * a, x + (y - x) * a))

DEFINE_DERIVED_DOUBLE3(Mix64Bit, mix, x, y, a, alternatives((x * (constant(1.0) - a)) + y * a, x + (y - x) * a))
4737 
step(double edge,double x)4738 static double step(double edge, double x)
4739 {
4740     return x < edge ? 0.0 : 1.0;
4741 }
4742 
4743 template <class T>
4744 class Step : public PreciseFunc2<T>
4745 {
4746 public:
Step(void)4747     Step(void) : PreciseFunc2<T>("step", step)
4748     {
4749     }
4750 };
4751 
// smoothstep(edge0, edge1, x). doExpand is only declared here; it is explicitly specialized
// below for the float, deFloat16 and double signatures.
template <class T>
class SmoothStep : public DerivedFunc<T>
{
public:
    typedef typename SmoothStep<T>::ArgExprs TArgExprs;
    typedef typename SmoothStep<T>::Ret TRet;
    string getName(void) const
    {
        return "smoothstep";
    }

protected:
    ExprP<TRet> doExpand(ExpandContext &ctx, const TArgExprs &args) const;
};
4766 
// 32-bit smoothstep: t = clamp((x - edge0) / (edge1 - edge0), 0, 1), result t * t * (3 - 2 * t).
// t is bound to a named sub-expression so it is shared by all three uses.
template <>
ExprP<SmoothStep<Signature<float, float, float, float>>::Ret> SmoothStep<Signature<float, float, float, float>>::
    doExpand(ExpandContext &ctx, const SmoothStep<Signature<float, float, float, float>>::ArgExprs &args) const
{
    const ExprP<float> &edge0 = args.a;
    const ExprP<float> &edge1 = args.b;
    const ExprP<float> &x     = args.c;
    const ExprP<float> tExpr =
        clamp((x - edge0) / (edge1 - edge0), constant(0.0f), constant(1.0f)) +
        nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
    const ExprP<float> t = bindExpression("t", ctx, tExpr);

    return (t * t * (constant(3.0f) - constant(2.0f) * t));
}
4781 
// 16-bit smoothstep: same expansion as the 32-bit version, with half-precision constants.
template <>
ExprP<SmoothStep<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>::TRet> SmoothStep<
    Signature<deFloat16, deFloat16, deFloat16, deFloat16>>::doExpand(ExpandContext &ctx, const TArgExprs &args) const
{
    const ExprP<deFloat16> &edge0 = args.a;
    const ExprP<deFloat16> &edge1 = args.b;
    const ExprP<deFloat16> &x     = args.c;
    const ExprP<deFloat16> tExpr =
        clamp((x - edge0) / (edge1 - edge0), constant((deFloat16)FLOAT16_0_0), constant((deFloat16)FLOAT16_1_0)) +
        nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
    const ExprP<deFloat16> t = bindExpression("t", ctx, tExpr);

    return (t * t * (constant((deFloat16)FLOAT16_3_0) - constant((deFloat16)FLOAT16_2_0) * t));
}
4796 
// 64-bit smoothstep: same expansion as the 32-bit version, with double constants.
template <>
ExprP<SmoothStep<Signature<double, double, double, double>>::Ret> SmoothStep<
    Signature<double, double, double, double>>::
    doExpand(ExpandContext &ctx, const SmoothStep<Signature<double, double, double, double>>::ArgExprs &args) const
{
    const ExprP<double> &edge0 = args.a;
    const ExprP<double> &edge1 = args.b;
    const ExprP<double> &x     = args.c;
    const ExprP<double> tExpr =
        clamp((x - edge0) / (edge1 - edge0), constant(0.0), constant(1.0)) +
        nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
    const ExprP<double> t = bindExpression("t", ctx, tExpr);

    return (t * t * (constant(3.0) - constant(2.0) * t));
}
4812 
4813 //Signature<float, float, int>
//Signature<deFloat16, deFloat16, int>
4815 //Signature<double, double, int>
4816 template <class T>
4817 class FrExp : public PrimitiveFunc<T>
4818 {
4819 public:
getName(void) const4820     string getName(void) const
4821     {
4822         return "frexp";
4823     }
4824 
4825     typedef typename FrExp::IRet IRet;
4826     typedef typename FrExp::IArgs IArgs;
4827     typedef typename FrExp::IArg0 IArg0;
4828     typedef typename FrExp::IArg1 IArg1;
4829 
4830 protected:
doApply(const EvalContext &,const IArgs & iargs) const4831     IRet doApply(const EvalContext &, const IArgs &iargs) const
4832     {
4833         IRet ret;
4834         const IArg0 &x  = iargs.a;
4835         IArg1 &exponent = const_cast<IArg1 &>(iargs.b);
4836 
4837         if (x.hasNaN() || x.contains(TCU_INFINITY) || x.contains(-TCU_INFINITY))
4838         {
4839             // GLSL (in contrast to IEEE) says that result of applying frexp
4840             // to infinity is undefined
4841             ret      = Interval::unbounded() | TCU_NAN;
4842             exponent = Interval(-deLdExp(1.0, 31), deLdExp(1.0, 31) - 1);
4843         }
4844         else if (!x.empty())
4845         {
4846             int loExp           = 0;
4847             const double loFrac = deFrExp(x.lo(), &loExp);
4848             int hiExp           = 0;
4849             const double hiFrac = deFrExp(x.hi(), &hiExp);
4850 
4851             if (deSign(loFrac) != deSign(hiFrac))
4852             {
4853                 exponent = Interval(-TCU_INFINITY, de::max(loExp, hiExp));
4854                 ret      = Interval();
4855                 if (deSign(loFrac) < 0)
4856                     ret |= Interval(-1.0 + DBL_EPSILON * 0.5, 0.0);
4857                 if (deSign(hiFrac) > 0)
4858                     ret |= Interval(0.0, 1.0 - DBL_EPSILON * 0.5);
4859             }
4860             else
4861             {
4862                 exponent = Interval(loExp, hiExp);
4863                 if (loExp == hiExp)
4864                     ret = Interval(loFrac, hiFrac);
4865                 else
4866                     ret = deSign(loFrac) * Interval(0.5, 1.0 - DBL_EPSILON * 0.5);
4867             }
4868         }
4869 
4870         return ret;
4871     }
4872 
getOutParamIndex(void) const4873     int getOutParamIndex(void) const
4874     {
4875         return 1;
4876     }
4877 };
4878 typedef FrExp<Signature<float, float, int>> Frexp32Bit;
4879 typedef FrExp<Signature<deFloat16, deFloat16, int>> Frexp16Bit;
4880 typedef FrExp<Signature<double, double, int>> Frexp64Bit;
4881 
// Same evaluation as FrExp, but reported under the name "frexpstruct" and tagged with the
// SPIRV_CASETYPE_FREXPSTRUCT case (presumably exercising the struct-returning FrexpStruct form).
template <class T>
class FrexpStruct : public FrExp<T>
{
public:
    virtual string getName(void) const
    {
        return "frexpstruct";
    }
    virtual SpirVCaseT getSpirvCase(void) const
    {
        return SPIRV_CASETYPE_FREXPSTRUCT;
    }
};
typedef FrexpStruct<Signature<float, float, int>> FrexpStruct32Bit;
typedef FrexpStruct<Signature<deFloat16, deFloat16, int>> FrexpStruct16Bit;
typedef FrexpStruct<Signature<double, double, int>> FrexpStruct64Bit;
4898 
4899 //Signature<float, float, int>
4900 //Signature<deFloat16, deFloat16, int>
4901 //Signature<double, double, int>
4902 template <class T>
4903 class LdExp : public PrimitiveFunc<T>
4904 {
4905 public:
4906     typedef typename LdExp::IRet IRet;
4907     typedef typename LdExp::IArgs IArgs;
4908 
getName(void) const4909     string getName(void) const
4910     {
4911         return "ldexp";
4912     }
4913 
4914 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4915     Interval doApply(const EvalContext &ctx, const IArgs &iargs) const
4916     {
4917         const int minExp = ctx.format.getMinExp();
4918         const int maxExp = ctx.format.getMaxExp();
4919         // Restrictions from the GLSL.std.450 instruction set.
4920         // See Khronos bugzilla 11180 for rationale.
4921         bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4922         Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4923         if (iargs.b.lo() < minExp)
4924             ret |= 0.0;
4925         if (!ret.isFinite(ctx.format.getMaxValue()))
4926             ret |= TCU_NAN;
4927         return ctx.format.convert(ret);
4928     }
4929 };
4930 
4931 template <>
doApply(const EvalContext & ctx,const IArgs & iargs) const4932 Interval LdExp<Signature<double, double, int>>::doApply(const EvalContext &ctx, const IArgs &iargs) const
4933 {
4934     const int minExp = ctx.format.getMinExp();
4935     const int maxExp = ctx.format.getMaxExp();
4936     // Restrictions from the GLSL.std.450 instruction set.
4937     // See Khronos bugzilla 11180 for rationale.
4938     bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4939     Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4940     // Add 1ULP precision tolerance to account for differing rounding modes between the GPU and deLdExp.
4941     ret += Interval(-ctx.format.ulp(ret.lo()), ctx.format.ulp(ret.hi()));
4942     if (iargs.b.lo() < minExp)
4943         ret |= 0.0;
4944     if (!ret.isFinite(ctx.format.getMaxValue()))
4945         ret |= TCU_NAN;
4946     return ctx.format.convert(ret);
4947 }
4948 
4949 template <int Rows, int Columns, class T>
4950 class Transpose : public PrimitiveFunc<Signature<Matrix<T, Rows, Columns>, Matrix<T, Columns, Rows>>>
4951 {
4952 public:
4953     typedef typename Transpose::IRet IRet;
4954     typedef typename Transpose::IArgs IArgs;
4955 
getName(void) const4956     string getName(void) const
4957     {
4958         return "transpose";
4959     }
4960 
4961 protected:
doApply(const EvalContext &,const IArgs & iargs) const4962     IRet doApply(const EvalContext &, const IArgs &iargs) const
4963     {
4964         IRet ret;
4965 
4966         for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
4967         {
4968             for (int colNdx = 0; colNdx < Columns; ++colNdx)
4969                 ret(rowNdx, colNdx) = iargs.a(colNdx, rowNdx);
4970         }
4971 
4972         return ret;
4973     }
4974 };
4975 
4976 template <typename Ret, typename Arg0, typename Arg1>
4977 class MulFunc : public PrimitiveFunc<Signature<Ret, Arg0, Arg1>>
4978 {
4979 public:
getName(void) const4980     string getName(void) const
4981     {
4982         return "mul";
4983     }
4984 
4985 protected:
doPrint(ostream & os,const BaseArgExprs & args) const4986     void doPrint(ostream &os, const BaseArgExprs &args) const
4987     {
4988         os << "(" << *args[0] << " * " << *args[1] << ")";
4989     }
4990 };
4991 
4992 template <typename T, int LeftRows, int Middle, int RightCols>
4993 class MatMul : public MulFunc<Matrix<T, LeftRows, RightCols>, Matrix<T, LeftRows, Middle>, Matrix<T, Middle, RightCols>>
4994 {
4995 protected:
4996     typedef typename MatMul::IRet IRet;
4997     typedef typename MatMul::IArgs IArgs;
4998     typedef typename MatMul::IArg0 IArg0;
4999     typedef typename MatMul::IArg1 IArg1;
5000 
doApply(const EvalContext & ctx,const IArgs & iargs) const5001     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5002     {
5003         const IArg0 &left  = iargs.a;
5004         const IArg1 &right = iargs.b;
5005         IRet ret;
5006 
5007         for (int row = 0; row < LeftRows; ++row)
5008         {
5009             for (int col = 0; col < RightCols; ++col)
5010             {
5011                 Interval element(0.0);
5012 
5013                 for (int ndx = 0; ndx < Middle; ++ndx)
5014                     element = call<Add<Signature<T, T, T>>>(
5015                         ctx, element, call<Mul<Signature<T, T, T>>>(ctx, left[ndx][row], right[col][ndx]));
5016 
5017                 ret[col][row] = element;
5018             }
5019         }
5020 
5021         return ret;
5022     }
5023 };
5024 
5025 template <typename T, int Rows, int Cols>
5026 class VecMatMul : public MulFunc<Vector<T, Cols>, Vector<T, Rows>, Matrix<T, Rows, Cols>>
5027 {
5028 public:
5029     typedef typename VecMatMul::IRet IRet;
5030     typedef typename VecMatMul::IArgs IArgs;
5031     typedef typename VecMatMul::IArg0 IArg0;
5032     typedef typename VecMatMul::IArg1 IArg1;
5033 
5034 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5035     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5036     {
5037         const IArg0 &left  = iargs.a;
5038         const IArg1 &right = iargs.b;
5039         IRet ret;
5040 
5041         for (int col = 0; col < Cols; ++col)
5042         {
5043             Interval element(0.0);
5044 
5045             for (int row = 0; row < Rows; ++row)
5046                 element = call<Add<Signature<T, T, T>>>(ctx, element,
5047                                                         call<Mul<Signature<T, T, T>>>(ctx, left[row], right[col][row]));
5048 
5049             ret[col] = element;
5050         }
5051 
5052         return ret;
5053     }
5054 };
5055 
5056 template <int Rows, int Cols, class T>
5057 class MatVecMul : public MulFunc<Vector<T, Rows>, Matrix<T, Rows, Cols>, Vector<T, Cols>>
5058 {
5059 public:
5060     typedef typename MatVecMul::IRet IRet;
5061     typedef typename MatVecMul::IArgs IArgs;
5062     typedef typename MatVecMul::IArg0 IArg0;
5063     typedef typename MatVecMul::IArg1 IArg1;
5064 
5065 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5066     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5067     {
5068         const IArg0 &left  = iargs.a;
5069         const IArg1 &right = iargs.b;
5070 
5071         return call<VecMatMul<T, Cols, Rows>>(ctx, right, call<Transpose<Rows, Cols, T>>(ctx, left));
5072     }
5073 };
5074 
5075 template <int Rows, int Cols, class T>
5076 class OuterProduct : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>, Vector<T, Rows>, Vector<T, Cols>>>
5077 {
5078 public:
5079     typedef typename OuterProduct::IRet IRet;
5080     typedef typename OuterProduct::IArgs IArgs;
5081 
getName(void) const5082     string getName(void) const
5083     {
5084         return "outerProduct";
5085     }
5086 
5087 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5088     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5089     {
5090         IRet ret;
5091 
5092         for (int row = 0; row < Rows; ++row)
5093         {
5094             for (int col = 0; col < Cols; ++col)
5095                 ret[col][row] = call<Mul<Signature<T, T, T>>>(ctx, iargs.a[row], iargs.b[col]);
5096         }
5097 
5098         return ret;
5099     }
5100 };
5101 
5102 template <int Rows, int Cols, class T>
outerProduct(const ExprP<Vector<T,Rows>> & left,const ExprP<Vector<T,Cols>> & right)5103 ExprP<Matrix<T, Rows, Cols>> outerProduct(const ExprP<Vector<T, Rows>> &left, const ExprP<Vector<T, Cols>> &right)
5104 {
5105     return app<OuterProduct<Rows, Cols, T>>(left, right);
5106 }
5107 
5108 template <class T>
5109 class DeterminantBase : public DerivedFunc<T>
5110 {
5111 public:
getName(void) const5112     string getName(void) const
5113     {
5114         return "determinant";
5115     }
5116 };
5117 
5118 template <int Size>
5119 class Determinant;
5120 template <int Size>
5121 class Determinant16bit;
5122 template <int Size>
5123 class Determinant64bit;
5124 
5125 template <int Size>
determinant(ExprP<Matrix<float,Size,Size>> mat)5126 ExprP<float> determinant(ExprP<Matrix<float, Size, Size>> mat)
5127 {
5128     return app<Determinant<Size>>(mat);
5129 }
5130 
5131 template <int Size>
determinant(ExprP<Matrix<deFloat16,Size,Size>> mat)5132 ExprP<deFloat16> determinant(ExprP<Matrix<deFloat16, Size, Size>> mat)
5133 {
5134     return app<Determinant16bit<Size>>(mat);
5135 }
5136 
5137 template <int Size>
determinant(ExprP<Matrix<double,Size,Size>> mat)5138 ExprP<double> determinant(ExprP<Matrix<double, Size, Size>> mat)
5139 {
5140     return app<Determinant64bit<Size>>(mat);
5141 }
5142 
5143 template <>
5144 class Determinant<2> : public DeterminantBase<Signature<float, Matrix<float, 2, 2>>>
5145 {
5146 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5147     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5148     {
5149         ExprP<Mat2> mat = args.a;
5150 
5151         return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5152     }
5153 };
5154 
5155 template <>
5156 class Determinant<3> : public DeterminantBase<Signature<float, Matrix<float, 3, 3>>>
5157 {
5158 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5159     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5160     {
5161         ExprP<Mat3> mat = args.a;
5162 
5163         return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5164                 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5165                 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5166     }
5167 };
5168 
5169 template <>
5170 class Determinant<4> : public DeterminantBase<Signature<float, Matrix<float, 4, 4>>>
5171 {
5172 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5173     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5174     {
5175         ExprP<Mat4> mat = args.a;
5176         ExprP<Mat3> minors[4];
5177 
5178         for (int ndx = 0; ndx < 4; ++ndx)
5179         {
5180             ExprP<Vec4> minorColumns[3];
5181             ExprP<Vec3> columns[3];
5182 
5183             for (int col = 0; col < 3; ++col)
5184                 minorColumns[col] = mat[col < ndx ? col : col + 1];
5185 
5186             for (int col = 0; col < 3; ++col)
5187                 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5188 
5189             minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5190         }
5191 
5192         return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5193                 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5194     }
5195 };
5196 
5197 template <>
5198 class Determinant16bit<2> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 2, 2>>>
5199 {
5200 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5201     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5202     {
5203         ExprP<Mat2_16b> mat = args.a;
5204 
5205         return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5206     }
5207 };
5208 
5209 template <>
5210 class Determinant16bit<3> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 3, 3>>>
5211 {
5212 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5213     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5214     {
5215         ExprP<Mat3_16b> mat = args.a;
5216 
5217         return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5218                 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5219                 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5220     }
5221 };
5222 
5223 template <>
5224 class Determinant16bit<4> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 4, 4>>>
5225 {
5226 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5227     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5228     {
5229         ExprP<Mat4_16b> mat = args.a;
5230         ExprP<Mat3_16b> minors[4];
5231 
5232         for (int ndx = 0; ndx < 4; ++ndx)
5233         {
5234             ExprP<Vec4_16Bit> minorColumns[3];
5235             ExprP<Vec3_16Bit> columns[3];
5236 
5237             for (int col = 0; col < 3; ++col)
5238                 minorColumns[col] = mat[col < ndx ? col : col + 1];
5239 
5240             for (int col = 0; col < 3; ++col)
5241                 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5242 
5243             minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5244         }
5245 
5246         return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5247                 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5248     }
5249 };
5250 
5251 template <>
5252 class Determinant64bit<2> : public DeterminantBase<Signature<double, Matrix<double, 2, 2>>>
5253 {
5254 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5255     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5256     {
5257         ExprP<Matrix2d> mat = args.a;
5258 
5259         return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5260     }
5261 };
5262 
5263 template <>
5264 class Determinant64bit<3> : public DeterminantBase<Signature<double, Matrix<double, 3, 3>>>
5265 {
5266 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5267     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5268     {
5269         ExprP<Matrix3d> mat = args.a;
5270 
5271         return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5272                 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5273                 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5274     }
5275 };
5276 
5277 template <>
5278 class Determinant64bit<4> : public DeterminantBase<Signature<double, Matrix<double, 4, 4>>>
5279 {
5280 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5281     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5282     {
5283         ExprP<Matrix4d> mat = args.a;
5284         ExprP<Matrix3d> minors[4];
5285 
5286         for (int ndx = 0; ndx < 4; ++ndx)
5287         {
5288             ExprP<Vec4_64Bit> minorColumns[3];
5289             ExprP<Vec3_64Bit> columns[3];
5290 
5291             for (int col = 0; col < 3; ++col)
5292                 minorColumns[col] = mat[col < ndx ? col : col + 1];
5293 
5294             for (int col = 0; col < 3; ++col)
5295                 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5296 
5297             minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5298         }
5299 
5300         return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5301                 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5302     }
5303 };
5304 
5305 template <int Size>
5306 class Inverse;
5307 
5308 template <int Size>
inverse(ExprP<Matrix<float,Size,Size>> mat)5309 ExprP<Matrix<float, Size, Size>> inverse(ExprP<Matrix<float, Size, Size>> mat)
5310 {
5311     return app<Inverse<Size>>(mat);
5312 }
5313 
5314 template <int Size>
5315 class Inverse16bit;
5316 
5317 template <int Size>
inverse(ExprP<Matrix<deFloat16,Size,Size>> mat)5318 ExprP<Matrix<deFloat16, Size, Size>> inverse(ExprP<Matrix<deFloat16, Size, Size>> mat)
5319 {
5320     return app<Inverse16bit<Size>>(mat);
5321 }
5322 
5323 template <int Size>
5324 class Inverse64bit;
5325 
5326 template <int Size>
inverse(ExprP<Matrix<double,Size,Size>> mat)5327 ExprP<Matrix<double, Size, Size>> inverse(ExprP<Matrix<double, Size, Size>> mat)
5328 {
5329     return app<Inverse64bit<Size>>(mat);
5330 }
5331 
5332 template <>
5333 class Inverse<2> : public DerivedFunc<Signature<Mat2, Mat2>>
5334 {
5335 public:
getName(void) const5336     string getName(void) const
5337     {
5338         return "inverse";
5339     }
5340 
5341 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5342     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5343     {
5344         ExprP<Mat2> mat  = args.a;
5345         ExprP<float> det = bindExpression("det", ctx, determinant(mat));
5346 
5347         return mat2(vec2(mat[1][1] / det, -mat[0][1] / det), vec2(-mat[1][0] / det, mat[0][0] / det));
5348     }
5349 };
5350 
5351 template <>
5352 class Inverse<3> : public DerivedFunc<Signature<Mat3, Mat3>>
5353 {
5354 public:
getName(void) const5355     string getName(void) const
5356     {
5357         return "inverse";
5358     }
5359 
5360 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5361     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5362     {
5363         ExprP<Mat3> mat = args.a;
5364         ExprP<Mat2> invA =
5365             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5366 
5367         ExprP<Vec2> matB  = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5368         ExprP<Vec2> matC  = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5369         ExprP<float> matD = bindExpression("matD", ctx, mat[2][2]);
5370 
5371         ExprP<float> schur = bindExpression("schur", ctx, constant(1.0f) / (matD - dot(matC * invA, matB)));
5372 
5373         ExprP<Vec2> t1     = invA * matB;
5374         ExprP<Vec2> t2     = t1 * schur;
5375         ExprP<Mat2> t3     = outerProduct(t2, matC);
5376         ExprP<Mat2> t4     = t3 * invA;
5377         ExprP<Mat2> t5     = invA + t4;
5378         ExprP<Mat2> blockA = bindExpression("blockA", ctx, t5);
5379         ExprP<Vec2> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
5380         ExprP<Vec2> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);
5381 
5382         return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
5383                     vec3(blockB[0], blockB[1], schur));
5384     }
5385 };
5386 
5387 template <>
5388 class Inverse<4> : public DerivedFunc<Signature<Mat4, Mat4>>
5389 {
5390 public:
getName(void) const5391     string getName(void) const
5392     {
5393         return "inverse";
5394     }
5395 
5396 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5397     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5398     {
5399         ExprP<Mat4> mat = args.a;
5400         ExprP<Mat2> invA =
5401             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5402         ExprP<Mat2> matB   = bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
5403         ExprP<Mat2> matC   = bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
5404         ExprP<Mat2> matD   = bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
5405         ExprP<Mat2> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
5406         ExprP<Mat2> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
5407         ExprP<Mat2> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
5408         ExprP<Mat2> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);
5409 
5410         return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5411                     vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5412                     vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5413                     vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5414     }
5415 };
5416 
5417 template <>
5418 class Inverse16bit<2> : public DerivedFunc<Signature<Mat2_16b, Mat2_16b>>
5419 {
5420 public:
getName(void) const5421     string getName(void) const
5422     {
5423         return "inverse";
5424     }
5425 
5426 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5427     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5428     {
5429         ExprP<Mat2_16b> mat  = args.a;
5430         ExprP<deFloat16> det = bindExpression("det", ctx, determinant(mat));
5431 
5432         return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)), vec2((-mat[1][0] / det), (mat[0][0] / det)));
5433     }
5434 };
5435 
5436 template <>
5437 class Inverse16bit<3> : public DerivedFunc<Signature<Mat3_16b, Mat3_16b>>
5438 {
5439 public:
getName(void) const5440     string getName(void) const
5441     {
5442         return "inverse";
5443     }
5444 
5445 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5446     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5447     {
5448         ExprP<Mat3_16b> mat = args.a;
5449         ExprP<Mat2_16b> invA =
5450             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5451 
5452         ExprP<Vec2_16Bit> matB       = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5453         ExprP<Vec2_16Bit> matC       = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5454         ExprP<Mat3_16b::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5455 
5456         ExprP<Mat3_16b::Scalar> schur =
5457             bindExpression("schur", ctx, constant((deFloat16)FLOAT16_1_0) / (matD - dot(matC * invA, matB)));
5458 
5459         ExprP<Vec2_16Bit> t1     = invA * matB;
5460         ExprP<Vec2_16Bit> t2     = t1 * schur;
5461         ExprP<Mat2_16b> t3       = outerProduct(t2, matC);
5462         ExprP<Mat2_16b> t4       = t3 * invA;
5463         ExprP<Mat2_16b> t5       = invA + t4;
5464         ExprP<Mat2_16b> blockA   = bindExpression("blockA", ctx, t5);
5465         ExprP<Vec2_16Bit> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
5466         ExprP<Vec2_16Bit> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);
5467 
5468         return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
5469                     vec3(blockB[0], blockB[1], schur));
5470     }
5471 };
5472 
5473 template <>
5474 class Inverse16bit<4> : public DerivedFunc<Signature<Mat4_16b, Mat4_16b>>
5475 {
5476 public:
getName(void) const5477     string getName(void) const
5478     {
5479         return "inverse";
5480     }
5481 
5482 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5483     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5484     {
5485         ExprP<Mat4_16b> mat = args.a;
5486         ExprP<Mat2_16b> invA =
5487             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5488         ExprP<Mat2_16b> matB =
5489             bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
5490         ExprP<Mat2_16b> matC =
5491             bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
5492         ExprP<Mat2_16b> matD =
5493             bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
5494         ExprP<Mat2_16b> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
5495         ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
5496         ExprP<Mat2_16b> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
5497         ExprP<Mat2_16b> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);
5498 
5499         return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5500                     vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5501                     vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5502                     vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5503     }
5504 };
5505 
5506 template <>
5507 class Inverse64bit<2> : public DerivedFunc<Signature<Matrix2d, Matrix2d>>
5508 {
5509 public:
getName(void) const5510     string getName(void) const
5511     {
5512         return "inverse";
5513     }
5514 
5515 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5516     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5517     {
5518         ExprP<Matrix2d> mat = args.a;
5519         ExprP<double> det   = bindExpression("det", ctx, determinant(mat));
5520 
5521         return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)), vec2((-mat[1][0] / det), (mat[0][0] / det)));
5522     }
5523 };
5524 
5525 template <>
5526 class Inverse64bit<3> : public DerivedFunc<Signature<Matrix3d, Matrix3d>>
5527 {
5528 public:
getName(void) const5529     string getName(void) const
5530     {
5531         return "inverse";
5532     }
5533 
5534 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5535     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5536     {
5537         ExprP<Matrix3d> mat = args.a;
5538         ExprP<Matrix2d> invA =
5539             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5540 
5541         ExprP<Vec2_64Bit> matB       = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
5542         ExprP<Vec2_64Bit> matC       = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
5543         ExprP<Matrix3d::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);
5544 
5545         ExprP<Matrix3d::Scalar> schur = bindExpression("schur", ctx, constant(1.0) / (matD - dot(matC * invA, matB)));
5546 
5547         ExprP<Vec2_64Bit> t1     = invA * matB;
5548         ExprP<Vec2_64Bit> t2     = t1 * schur;
5549         ExprP<Matrix2d> t3       = outerProduct(t2, matC);
5550         ExprP<Matrix2d> t4       = t3 * invA;
5551         ExprP<Matrix2d> t5       = invA + t4;
5552         ExprP<Matrix2d> blockA   = bindExpression("blockA", ctx, t5);
5553         ExprP<Vec2_64Bit> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
5554         ExprP<Vec2_64Bit> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);
5555 
5556         return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
5557                     vec3(blockB[0], blockB[1], schur));
5558     }
5559 };
5560 
5561 template <>
5562 class Inverse64bit<4> : public DerivedFunc<Signature<Matrix4d, Matrix4d>>
5563 {
5564 public:
getName(void) const5565     string getName(void) const
5566     {
5567         return "inverse";
5568     }
5569 
5570 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5571     ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5572     {
5573         ExprP<Matrix4d> mat = args.a;
5574         ExprP<Matrix2d> invA =
5575             bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
5576         ExprP<Matrix2d> matB =
5577             bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
5578         ExprP<Matrix2d> matC =
5579             bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
5580         ExprP<Matrix2d> matD =
5581             bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
5582         ExprP<Matrix2d> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
5583         ExprP<Matrix2d> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
5584         ExprP<Matrix2d> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
5585         ExprP<Matrix2d> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);
5586 
5587         return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
5588                     vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
5589                     vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
5590                     vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
5591     }
5592 };
5593 
5594 //Signature<float, float, float, float>
5595 //Signature<deFloat16, deFloat16, deFloat16, deFloat16>
5596 //Signature<double, double, double, double>
5597 template <class T>
5598 class Fma : public DerivedFunc<T>
5599 {
5600 public:
5601     typedef typename Fma::ArgExprs ArgExprs;
5602     typedef typename Fma::Ret Ret;
5603 
getName(void) const5604     string getName(void) const
5605     {
5606         return "fma";
5607     }
5608 
5609 protected:
doExpand(ExpandContext &,const ArgExprs & x) const5610     ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &x) const
5611     {
5612         return x.a * x.b + x.c;
5613     }
5614 };
5615 
5616 } // namespace Functions
5617 
5618 using namespace Functions;
5619 
5620 template <typename T>
operator [](int i) const5621 ExprP<typename T::Element> ContainerExprPBase<T>::operator[](int i) const
5622 {
5623     return Functions::getComponent(exprP<T>(*this), i);
5624 }
5625 
// Scalar addition of expressions; each overload builds an Add node of the matching precision.
ExprP<float> operator+(const ExprP<float> &arg0, const ExprP<float> &arg1)
{
    typedef Add<Signature<float, float, float>> AddFunc;
    return app<AddFunc>(arg0, arg1);
}

ExprP<deFloat16> operator+(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1)
{
    typedef Add<Signature<deFloat16, deFloat16, deFloat16>> AddFunc;
    return app<AddFunc>(arg0, arg1);
}

ExprP<double> operator+(const ExprP<double> &arg0, const ExprP<double> &arg1)
{
    typedef Add<Signature<double, double, double>> AddFunc;
    return app<AddFunc>(arg0, arg1);
}
5640 
5641 template <typename T>
operator -(const ExprP<T> & arg0,const ExprP<T> & arg1)5642 ExprP<T> operator-(const ExprP<T> &arg0, const ExprP<T> &arg1)
5643 {
5644     return app<Sub<Signature<T, T, T>>>(arg0, arg1);
5645 }
5646 
5647 template <typename T>
operator -(const ExprP<T> & arg0)5648 ExprP<T> operator-(const ExprP<T> &arg0)
5649 {
5650     return app<Negate<Signature<T, T>>>(arg0);
5651 }
5652 
// Scalar multiplication of expressions; each overload builds a Mul node of the matching precision.
ExprP<float> operator*(const ExprP<float> &arg0, const ExprP<float> &arg1)
{
    typedef Mul<Signature<float, float, float>> MulFunc32;
    return app<MulFunc32>(arg0, arg1);
}

ExprP<deFloat16> operator*(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1)
{
    typedef Mul<Signature<deFloat16, deFloat16, deFloat16>> MulFunc16;
    return app<MulFunc16>(arg0, arg1);
}

ExprP<double> operator*(const ExprP<double> &arg0, const ExprP<double> &arg1)
{
    typedef Mul<Signature<double, double, double>> MulFunc64;
    return app<MulFunc64>(arg0, arg1);
}
5667 
5668 template <typename T>
operator /(const ExprP<T> & arg0,const ExprP<T> & arg1)5669 ExprP<T> operator/(const ExprP<T> &arg0, const ExprP<T> &arg1)
5670 {
5671     return app<Div<Signature<T, T, T>>>(arg0, arg1);
5672 }
5673 
5674 template <typename Sig_, int Size>
5675 class GenFunc : public PrimitiveFunc<Signature<typename ContainerOf<typename Sig_::Ret, Size>::Container,
5676                                                typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5677                                                typename ContainerOf<typename Sig_::Arg1, Size>::Container,
5678                                                typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5679                                                typename ContainerOf<typename Sig_::Arg3, Size>::Container>>
5680 {
5681 public:
5682     typedef typename GenFunc::IArgs IArgs;
5683     typedef typename GenFunc::IRet IRet;
5684 
GenFunc(const Func<Sig_> & scalarFunc)5685     GenFunc(const Func<Sig_> &scalarFunc) : m_func(scalarFunc)
5686     {
5687     }
5688 
getSpirvCase(void) const5689     SpirVCaseT getSpirvCase(void) const
5690     {
5691         return m_func.getSpirvCase();
5692     }
5693 
getName(void) const5694     string getName(void) const
5695     {
5696         return m_func.getName();
5697     }
5698 
getOutParamIndex(void) const5699     int getOutParamIndex(void) const
5700     {
5701         return m_func.getOutParamIndex();
5702     }
5703 
getRequiredExtension(void) const5704     string getRequiredExtension(void) const
5705     {
5706         return m_func.getRequiredExtension();
5707     }
5708 
getInputRange(const bool is16bit) const5709     Interval getInputRange(const bool is16bit) const
5710     {
5711         return m_func.getInputRange(is16bit);
5712     }
5713 
5714 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5715     void doPrint(ostream &os, const BaseArgExprs &args) const
5716     {
5717         m_func.print(os, args);
5718     }
5719 
doApply(const EvalContext & ctx,const IArgs & iargs) const5720     IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5721     {
5722         IRet ret;
5723 
5724         for (int ndx = 0; ndx < Size; ++ndx)
5725         {
5726             ret[ndx] = m_func.apply(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5727         }
5728 
5729         return ret;
5730     }
5731 
doFail(const EvalContext & ctx,const IArgs & iargs) const5732     IRet doFail(const EvalContext &ctx, const IArgs &iargs) const
5733     {
5734         IRet ret;
5735 
5736         for (int ndx = 0; ndx < Size; ++ndx)
5737         {
5738             ret[ndx] = m_func.fail(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5739         }
5740 
5741         return ret;
5742     }
5743 
doGetUsedFuncs(FuncSet & dst) const5744     void doGetUsedFuncs(FuncSet &dst) const
5745     {
5746         m_func.getUsedFuncs(dst);
5747     }
5748 
5749     const Func<Sig_> &m_func;
5750 };
5751 
5752 template <typename F, int Size>
5753 class VectorizedFunc : public GenFunc<typename F::Sig, Size>
5754 {
5755 public:
VectorizedFunc(void)5756     VectorizedFunc(void) : GenFunc<typename F::Sig, Size>(instance<F>())
5757     {
5758     }
5759 };
5760 
//! Vector function of width Size built from a scalar function, where argument 1
//! is NOT vectorized: the return type and arguments 0, 2 and 3 become Size-wide
//! containers, but Arg1 stays scalar and is passed unchanged to every component
//! evaluation (e.g. vector / scalar).
//! Subclasses supply the scalar function through doGetScalarFunc().
template <typename Sig_, int Size>
class FixedGenFunc
    : public PrimitiveFunc<Signature<typename ContainerOf<typename Sig_::Ret, Size>::Container,
                                     typename ContainerOf<typename Sig_::Arg0, Size>::Container, typename Sig_::Arg1,
                                     typename ContainerOf<typename Sig_::Arg2, Size>::Container,
                                     typename ContainerOf<typename Sig_::Arg3, Size>::Container>>
{
public:
    typedef typename FixedGenFunc::IArgs IArgs;
    typedef typename FixedGenFunc::IRet IRet;

    //! Same name as the scalar function.
    string getName(void) const
    {
        return this->doGetScalarFunc().getName();
    }

    //! Same SPIR-V case as the scalar function.
    SpirVCaseT getSpirvCase(void) const
    {
        return this->doGetScalarFunc().getSpirvCase();
    }

protected:
    void doPrint(ostream &os, const BaseArgExprs &args) const
    {
        this->doGetScalarFunc().print(os, args);
    }

    //! Evaluate the scalar function per component; iargs.b is the shared scalar
    //! argument and is therefore not indexed.
    IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
    {
        IRet ret;
        const Func<Sig_> &func = this->doGetScalarFunc();

        for (int ndx = 0; ndx < Size; ++ndx)
            ret[ndx] = func.apply(ctx, iargs.a[ndx], iargs.b, iargs.c[ndx], iargs.d[ndx]);

        return ret;
    }

    //! The scalar function applied to each component.
    virtual const Func<Sig_> &doGetScalarFunc(void) const = 0;
};
5801 
//! FixedGenFunc whose scalar function is the object returned by instance<F>().
template <typename F, int Size>
class FixedVecFunc : public FixedGenFunc<typename F::Sig, Size>
{
protected:
    const Func<typename F::Sig> &doGetScalarFunc(void) const
    {
        return instance<F>();
    }
};
5811 
//! Bundle of a scalar function and its 2-, 3- and 4-component variants.
//! Only references are stored, so every referenced function object must outlive
//! the bundle.
template <typename Sig>
struct GenFuncs
{
    GenFuncs(const Func<Sig> &func_, const GenFunc<Sig, 2> &func2_, const GenFunc<Sig, 3> &func3_,
             const GenFunc<Sig, 4> &func4_)
        : func(func_)
        , func2(func2_)
        , func3(func3_)
        , func4(func4_)
    {
    }

    // Scalar version.
    const Func<Sig> &func;
    // Vector versions of width 2, 3 and 4.
    const GenFunc<Sig, 2> &func2;
    const GenFunc<Sig, 3> &func3;
    const GenFunc<Sig, 4> &func4;
};
5829 
5830 template <typename F>
makeVectorizedFuncs(void)5831 GenFuncs<typename F::Sig> makeVectorizedFuncs(void)
5832 {
5833     return GenFuncs<typename F::Sig>(instance<F>(), instance<VectorizedFunc<F, 2>>(), instance<VectorizedFunc<F, 3>>(),
5834                                      instance<VectorizedFunc<F, 4>>());
5835 }
5836 
// ---------------------------------------------------------------------------
// Arithmetic operators on expression trees. Each operator builds an application
// node (app<...>) of the corresponding primitive function; nothing is evaluated here.
// ---------------------------------------------------------------------------

//! vector / scalar: the scalar divisor is shared by all components (FixedVecFunc).
template <typename T, int Size>
ExprP<Vector<T, Size>> operator/(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1)
{
    return app<FixedVecFunc<Div<Signature<T, T, T>>, Size>>(arg0, arg1);
}

//! Componentwise negation of a vector.
template <typename T, int Size>
ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0)
{
    return app<VectorizedFunc<Negate<Signature<T, T>>, Size>>(arg0);
}

//! Componentwise vector - vector.
template <typename T, int Size>
ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1)
{
    return app<VectorizedFunc<Sub<Signature<T, T, T>>, Size>>(arg0, arg1);
}

//! vector * scalar: the scalar factor is shared by all components (FixedVecFunc).
template <int Size, typename T>
ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1)
{
    return app<FixedVecFunc<Mul<Signature<T, T, T>>, Size>>(arg0, arg1);
}

//! Componentwise vector * vector.
template <typename T, int Size>
ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1)
{
    return app<VectorizedFunc<Mul<Signature<T, T, T>>, Size>>(arg0, arg1);
}

//! Matrix product; the inner dimension Middle must agree between the operands.
template <int LeftRows, int Middle, int RightCols, typename T>
ExprP<Matrix<T, LeftRows, RightCols>> operator*(const ExprP<Matrix<T, LeftRows, Middle>> &left,
                                                const ExprP<Matrix<T, Middle, RightCols>> &right)
{
    return app<MatMul<T, LeftRows, Middle, RightCols>>(left, right);
}

//! vector * matrix. NOTE(review): the Rows/Cols roles here must match the
//! signature of VecMatMul (defined elsewhere); confirm against that definition
//! before touching the template parameters.
template <int Rows, int Cols, typename T>
ExprP<Vector<T, Rows>> operator*(const ExprP<Vector<T, Cols>> &left, const ExprP<Matrix<T, Rows, Cols>> &right)
{
    return app<VecMatMul<T, Rows, Cols>>(left, right);
}

//! matrix * vector. NOTE(review): the Rows/Cols roles here must match the
//! signature of MatVecMul (defined elsewhere); confirm against that definition.
template <int Rows, int Cols, class T>
ExprP<Vector<T, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<Vector<T, Rows>> &right)
{
    return app<MatVecMul<Rows, Cols, T>>(left, right);
}

//! matrix * scalar, via ScalarMatFunc.
template <int Rows, int Cols, typename T>
ExprP<Matrix<T, Rows, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<T> &right)
{
    return app<ScalarMatFunc<Mul<Signature<T, T, T>>, Rows, Cols>>(left, right);
}

//! Componentwise matrix + matrix (fp32).
template <int Rows, int Cols>
ExprP<Matrix<float, Rows, Cols>> operator+(const ExprP<Matrix<float, Rows, Cols>> &left,
                                           const ExprP<Matrix<float, Rows, Cols>> &right)
{
    return app<CompMatFunc<Add<Signature<float, float, float>>, float, Rows, Cols>>(left, right);
}

//! Componentwise matrix + matrix (fp16).
template <int Rows, int Cols>
ExprP<Matrix<deFloat16, Rows, Cols>> operator+(const ExprP<Matrix<deFloat16, Rows, Cols>> &left,
                                               const ExprP<Matrix<deFloat16, Rows, Cols>> &right)
{
    return app<CompMatFunc<Add<Signature<deFloat16, deFloat16, deFloat16>>, deFloat16, Rows, Cols>>(left, right);
}

//! Componentwise matrix + matrix (fp64).
template <int Rows, int Cols>
ExprP<Matrix<double, Rows, Cols>> operator+(const ExprP<Matrix<double, Rows, Cols>> &left,
                                            const ExprP<Matrix<double, Rows, Cols>> &right)
{
    return app<CompMatFunc<Add<Signature<double, double, double>>, double, Rows, Cols>>(left, right);
}

//! Matrix negation.
template <typename T, int Rows, int Cols>
ExprP<Matrix<T, Rows, Cols>> operator-(const ExprP<Matrix<T, Rows, Cols>> &mat)
{
    return app<MatNeg<T, Rows, Cols>>(mat);
}
5918 
//! Interface for generating test input values of type T.
//! The base implementations generate no fixed values, return a value-initialized
//! T from genRandom() and filter nothing in removeNotInRange().
template <typename T>
class Sampling
{
public:
    virtual ~Sampling()
    {
    }

    //! Append values that should always be tested to the output vector.
    virtual void genFixeds(const FloatFormat &, const Precision, vector<T> &, const Interval &) const
    {
    }
    //! Return one random value.
    virtual T genRandom(const FloatFormat &, const Precision, Random &, const Interval &) const
    {
        return T();
    }
    //! Remove values that must not be used as inputs.
    virtual void removeNotInRange(vector<T> &, const Interval &, const Precision) const
    {
    }
};
5938 
//! Void is the placeholder for unused argument slots (see the InTypes defaults):
//! exactly one fixed placeholder value is generated.
template <>
class DefaultSampling<Void> : public Sampling<Void>
{
public:
    void genFixeds(const FloatFormat &, const Precision, vector<Void> &dst, const Interval &) const
    {
        dst.push_back(Void());
    }
};

//! Booleans: both values are always tested. genRandom() is inherited from the
//! base class (returns bool()).
template <>
class DefaultSampling<bool> : public Sampling<bool>
{
public:
    void genFixeds(const FloatFormat &, const Precision, vector<bool> &dst, const Interval &) const
    {
        dst.push_back(true);
        dst.push_back(false);
    }
};
5959 
5960 template <>
5961 class DefaultSampling<int> : public Sampling<int>
5962 {
5963 public:
genRandom(const FloatFormat &,const Precision prec,Random & rnd,const Interval &) const5964     int genRandom(const FloatFormat &, const Precision prec, Random &rnd, const Interval &) const
5965     {
5966         const int exp  = rnd.getInt(0, getNumBits(prec) - 2);
5967         const int sign = rnd.getBool() ? -1 : 1;
5968 
5969         return sign * rnd.getInt(0, (int32_t)1 << exp);
5970     }
5971 
genFixeds(const FloatFormat &,const Precision,vector<int> & dst,const Interval &) const5972     void genFixeds(const FloatFormat &, const Precision, vector<int> &dst, const Interval &) const
5973     {
5974         dst.push_back(0);
5975         dst.push_back(-1);
5976         dst.push_back(1);
5977     }
5978 
5979 private:
getNumBits(Precision prec)5980     static inline int getNumBits(Precision prec)
5981     {
5982         switch (prec)
5983         {
5984         case glu::PRECISION_LAST:
5985         case glu::PRECISION_MEDIUMP:
5986             return 16;
5987         case glu::PRECISION_HIGHP:
5988             return 32;
5989         default:
5990             DE_ASSERT(false);
5991             return 0;
5992         }
5993     }
5994 };
5995 
//! fp32 input sampling; the member functions are defined out of line below.
template <>
class DefaultSampling<float> : public Sampling<float>
{
public:
    float genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
    void genFixeds(const FloatFormat &format, const Precision prec, vector<float> &dst,
                   const Interval &inputRange) const;
    void removeNotInRange(vector<float> &dst, const Interval &inputRange, const Precision prec) const;
};

//! fp64 input sampling; the member functions are defined out of line below.
template <>
class DefaultSampling<double> : public Sampling<double>
{
public:
    double genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
    void genFixeds(const FloatFormat &format, const Precision prec, vector<double> &dst,
                   const Interval &inputRange) const;
    void removeNotInRange(vector<double> &dst, const Interval &inputRange, const Precision prec) const;
};
6015 
isDenorm16(deFloat16 v)6016 static bool isDenorm16(deFloat16 v)
6017 {
6018     const uint16_t mantissa = 0x03FF;
6019     const uint16_t exponent = 0x7C00;
6020     return ((exponent & v) == 0 && (mantissa & v) != 0);
6021 }
6022 
6023 //! Generate a random double from a reasonable general-purpose distribution.
randomDouble(const FloatFormat & format,Random & rnd,const Interval & inputRange)6024 double randomDouble(const FloatFormat &format, Random &rnd, const Interval &inputRange)
6025 {
6026     // No testing of subnormals. TODO: Could integrate float controls for some operations.
6027     const int minExp         = format.getMinExp();
6028     const int maxExp         = format.getMaxExp();
6029     const bool haveSubnormal = false;
6030     const double midpoint    = inputRange.midpoint();
6031 
6032     // Choose exponent so that the cumulative distribution is cubic.
6033     // This makes the probability distribution quadratic, with the peak centered on zero.
6034     const double minRoot   = deCbrt(minExp - 0.5 - (haveSubnormal ? 1.0 : 0.0));
6035     const double maxRoot   = deCbrt(maxExp + 0.5);
6036     const int fractionBits = format.getFractionBits();
6037     const int exp          = int(deRoundEven(dePow(rnd.getDouble(minRoot, maxRoot), 3.0)));
6038 
6039     // Generate some occasional special numbers
6040     switch (rnd.getInt(0, 64))
6041     {
6042     case 0:
6043         return inputRange.contains(0) ? 0 : midpoint;
6044     case 1:
6045         return inputRange.contains(TCU_INFINITY) ? TCU_INFINITY : midpoint;
6046     case 2:
6047         return inputRange.contains(-TCU_INFINITY) ? -TCU_INFINITY : midpoint;
6048     case 3:
6049         return inputRange.contains(TCU_NAN) ? TCU_NAN : midpoint;
6050     default:
6051         break;
6052     }
6053 
6054     DE_ASSERT(fractionBits < std::numeric_limits<double>::digits);
6055 
6056     // Normal number
6057     double base        = deLdExp(1.0, exp);
6058     double quantum     = deLdExp(1.0, exp - fractionBits); // smallest representable difference in the binade
6059     double significand = 0.0;
6060     switch (rnd.getInt(0, 16))
6061     {
6062     case 0: // The highest number in this binade, significand is all bits one.
6063         significand = base - quantum;
6064         break;
6065     case 1: // Significand is one.
6066         significand = quantum;
6067         break;
6068     case 2: // Significand is zero.
6069         significand = 0.0;
6070         break;
6071     default: // Random (evenly distributed) significand.
6072     {
6073         uint64_t intFraction = rnd.getUint64() & ((1 << fractionBits) - 1);
6074         significand          = double(intFraction) * quantum;
6075     }
6076     }
6077 
6078     // Produce positive numbers more often than negative.
6079     double value = (rnd.getInt(0, 3) == 0 ? -1.0 : 1.0) * (base + significand);
6080     return inputRange.contains(value) ? value : midpoint;
6081 }
6082 
6083 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const6084 float DefaultSampling<float>::genRandom(const FloatFormat &format, Precision prec, Random &rnd,
6085                                         const Interval &inputRange) const
6086 {
6087     DE_UNREF(prec);
6088     return (float)randomDouble(format, rnd, inputRange);
6089 }
6090 
6091 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<float> & dst,const Interval & inputRange) const6092 void DefaultSampling<float>::genFixeds(const FloatFormat &format, const Precision prec, vector<float> &dst,
6093                                        const Interval &inputRange) const
6094 {
6095     const int minExp          = format.getMinExp();
6096     const int maxExp          = format.getMaxExp();
6097     const int fractionBits    = format.getFractionBits();
6098     const float minQuantum    = deFloatLdExp(1.0f, minExp - fractionBits);
6099     const float minNormalized = deFloatLdExp(1.0f, minExp);
6100     const float maxQuantum    = deFloatLdExp(1.0f, maxExp - fractionBits);
6101 
6102     // NaN
6103     dst.push_back(TCU_NAN);
6104     // Zero
6105     dst.push_back(0.0f);
6106 
6107     for (int sign = -1; sign <= 1; sign += 2)
6108     {
6109         // Smallest normalized
6110         dst.push_back((float)sign * minNormalized);
6111 
6112         // Next smallest normalized
6113         dst.push_back((float)sign * (minNormalized + minQuantum));
6114 
6115         dst.push_back((float)sign * 0.5f);
6116         dst.push_back((float)sign * 1.0f);
6117         dst.push_back((float)sign * 2.0f);
6118 
6119         // Largest number
6120         dst.push_back((float)sign * (deFloatLdExp(1.0f, maxExp) + (deFloatLdExp(1.0f, maxExp) - maxQuantum)));
6121 
6122         dst.push_back((float)sign * TCU_INFINITY);
6123     }
6124     removeNotInRange(dst, inputRange, prec);
6125 }
6126 
removeNotInRange(vector<float> & dst,const Interval & inputRange,const Precision prec) const6127 void DefaultSampling<float>::removeNotInRange(vector<float> &dst, const Interval &inputRange,
6128                                               const Precision prec) const
6129 {
6130     for (vector<float>::iterator it = dst.begin(); it < dst.end();)
6131     {
6132         // Remove out of range values. PRECISION_LAST means this is an FP16 test so remove any values that
6133         // will be denorms when converted to FP16. (This is used in the precision_fp16_storage32b test group).
6134         if (!inputRange.contains(static_cast<double>(*it)) ||
6135             (prec == glu::PRECISION_LAST && isDenorm16(deFloat32To16Round(*it, DE_ROUNDINGMODE_TO_ZERO))))
6136             it = dst.erase(it);
6137         else
6138             ++it;
6139     }
6140 }
6141 
6142 //! Generate a random double from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const6143 double DefaultSampling<double>::genRandom(const FloatFormat &format, Precision prec, Random &rnd,
6144                                           const Interval &inputRange) const
6145 {
6146     DE_UNREF(prec);
6147     return randomDouble(format, rnd, inputRange);
6148 }
6149 
6150 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<double> & dst,const Interval & inputRange) const6151 void DefaultSampling<double>::genFixeds(const FloatFormat &format, const Precision prec, vector<double> &dst,
6152                                         const Interval &inputRange) const
6153 {
6154     const int minExp           = format.getMinExp();
6155     const int maxExp           = format.getMaxExp();
6156     const int fractionBits     = format.getFractionBits();
6157     const double minQuantum    = deLdExp(1.0, minExp - fractionBits);
6158     const double minNormalized = deLdExp(1.0, minExp);
6159     const double maxQuantum    = deLdExp(1.0, maxExp - fractionBits);
6160 
6161     // NaN
6162     dst.push_back(TCU_NAN);
6163     // Zero
6164     dst.push_back(0.0);
6165 
6166     for (int sign = -1; sign <= 1; sign += 2)
6167     {
6168         // Smallest normalized
6169         dst.push_back((double)sign * minNormalized);
6170 
6171         // Next smallest normalized
6172         dst.push_back((double)sign * (minNormalized + minQuantum));
6173 
6174         dst.push_back((double)sign * 0.5);
6175         dst.push_back((double)sign * 1.0);
6176         dst.push_back((double)sign * 2.0);
6177 
6178         // Largest number
6179         dst.push_back((double)sign * (deLdExp(1.0, maxExp) + (deLdExp(1.0, maxExp) - maxQuantum)));
6180 
6181         dst.push_back((double)sign * TCU_INFINITY);
6182     }
6183     removeNotInRange(dst, inputRange, prec);
6184 }
6185 
removeNotInRange(vector<double> & dst,const Interval & inputRange,const Precision) const6186 void DefaultSampling<double>::removeNotInRange(vector<double> &dst, const Interval &inputRange, const Precision) const
6187 {
6188     for (vector<double>::iterator it = dst.begin(); it < dst.end();)
6189     {
6190         if (!inputRange.contains(*it))
6191             it = dst.erase(it);
6192         else
6193             ++it;
6194     }
6195 }
6196 
//! fp16 input sampling. Values are handled as raw deFloat16 encodings (the
//! out-of-line genFixeds() pushes bit patterns such as 0x3E00 for 1.5).
template <>
class DefaultSampling<deFloat16> : public Sampling<deFloat16>
{
public:
    deFloat16 genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
    void genFixeds(const FloatFormat &format, const Precision prec, vector<deFloat16> &dst,
                   const Interval &inputRange) const;

private:
    // NOTE(review): unlike the float/double specializations this override is
    // private. Calls through a Sampling<deFloat16>& still dispatch to it virtually.
    void removeNotInRange(vector<deFloat16> &dst, const Interval &inputRange, const Precision prec) const;
};
6208 
6209 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,const Precision prec,Random & rnd,const Interval & inputRange) const6210 deFloat16 DefaultSampling<deFloat16>::genRandom(const FloatFormat &format, const Precision prec, Random &rnd,
6211                                                 const Interval &inputRange) const
6212 {
6213     DE_UNREF(prec);
6214     return deFloat64To16Round(randomDouble(format, rnd, inputRange), DE_ROUNDINGMODE_TO_NEAREST_EVEN);
6215 }
6216 
6217 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<deFloat16> & dst,const Interval & inputRange) const6218 void DefaultSampling<deFloat16>::genFixeds(const FloatFormat &format, const Precision prec, vector<deFloat16> &dst,
6219                                            const Interval &inputRange) const
6220 {
6221     dst.push_back(uint16_t(0x3E00)); //1.5
6222     dst.push_back(uint16_t(0x3D00)); //1.25
6223     dst.push_back(uint16_t(0x3F00)); //1.75
6224     // Zero
6225     dst.push_back(uint16_t(0x0000));
6226     dst.push_back(uint16_t(0x8000));
6227     // Infinity
6228     dst.push_back(uint16_t(0x7c00));
6229     dst.push_back(uint16_t(0xfc00));
6230     // SNaN
6231     dst.push_back(uint16_t(0x7c0f));
6232     dst.push_back(uint16_t(0xfc0f));
6233     // QNaN
6234     dst.push_back(uint16_t(0x7cf0));
6235     dst.push_back(uint16_t(0xfcf0));
6236     // Normalized
6237     dst.push_back(uint16_t(0x0401));
6238     dst.push_back(uint16_t(0x8401));
6239     // Some normal number
6240     dst.push_back(uint16_t(0x14cb));
6241     dst.push_back(uint16_t(0x94cb));
6242 
6243     const int minExp          = format.getMinExp();
6244     const int maxExp          = format.getMaxExp();
6245     const int fractionBits    = format.getFractionBits();
6246     const float minQuantum    = deFloatLdExp(1.0f, minExp - fractionBits);
6247     const float minNormalized = deFloatLdExp(1.0f, minExp);
6248     const float maxQuantum    = deFloatLdExp(1.0f, maxExp - fractionBits);
6249 
6250     for (float sign = -1.0; sign <= 1.0f; sign += 2.0f)
6251     {
6252         // Smallest normalized
6253         dst.push_back(deFloat32To16Round(sign * minNormalized, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6254 
6255         // Next smallest normalized
6256         dst.push_back(deFloat32To16Round(sign * (minNormalized + minQuantum), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6257 
6258         dst.push_back(deFloat32To16Round(sign * 0.5f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6259         dst.push_back(deFloat32To16Round(sign * 1.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6260         dst.push_back(deFloat32To16Round(sign * 2.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6261 
6262         // Largest number
6263         dst.push_back(
6264             deFloat32To16Round(sign * (deFloatLdExp(1.0f, maxExp) + (deFloatLdExp(1.0f, maxExp) - maxQuantum)),
6265                                DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6266 
6267         dst.push_back(deFloat32To16Round(sign * TCU_INFINITY, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6268     }
6269     removeNotInRange(dst, inputRange, prec);
6270 }
6271 
removeNotInRange(vector<deFloat16> & dst,const Interval & inputRange,const Precision) const6272 void DefaultSampling<deFloat16>::removeNotInRange(vector<deFloat16> &dst, const Interval &inputRange,
6273                                                   const Precision) const
6274 {
6275     for (vector<deFloat16>::iterator it = dst.begin(); it < dst.end();)
6276     {
6277         if (inputRange.contains(static_cast<double>(*it)))
6278             ++it;
6279         else
6280             it = dst.erase(it);
6281     }
6282 }
6283 
6284 template <typename T, int Size>
6285 class DefaultSampling<Vector<T, Size>> : public Sampling<Vector<T, Size>>
6286 {
6287 public:
6288     typedef Vector<T, Size> Value;
6289 
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6290     Value genRandom(const FloatFormat &fmt, const Precision prec, Random &rnd, const Interval &inputRange) const
6291     {
6292         Value ret;
6293 
6294         for (int ndx = 0; ndx < Size; ++ndx)
6295             ret[ndx] = instance<DefaultSampling<T>>().genRandom(fmt, prec, rnd, inputRange);
6296 
6297         return ret;
6298     }
6299 
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6300     void genFixeds(const FloatFormat &fmt, const Precision prec, vector<Value> &dst, const Interval &inputRange) const
6301     {
6302         vector<T> scalars;
6303 
6304         instance<DefaultSampling<T>>().genFixeds(fmt, prec, scalars, inputRange);
6305 
6306         for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6307             dst.push_back(Value(scalars[scalarNdx]));
6308     }
6309 };
6310 
6311 template <typename T, int Rows, int Columns>
6312 class DefaultSampling<Matrix<T, Rows, Columns>> : public Sampling<Matrix<T, Rows, Columns>>
6313 {
6314 public:
6315     typedef Matrix<T, Rows, Columns> Value;
6316 
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6317     Value genRandom(const FloatFormat &fmt, const Precision prec, Random &rnd, const Interval &inputRange) const
6318     {
6319         Value ret;
6320 
6321         for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
6322             for (int colNdx = 0; colNdx < Columns; ++colNdx)
6323                 ret(rowNdx, colNdx) = instance<DefaultSampling<T>>().genRandom(fmt, prec, rnd, inputRange);
6324 
6325         return ret;
6326     }
6327 
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6328     void genFixeds(const FloatFormat &fmt, const Precision prec, vector<Value> &dst, const Interval &inputRange) const
6329     {
6330         vector<T> scalars;
6331 
6332         instance<DefaultSampling<T>>().genFixeds(fmt, prec, scalars, inputRange);
6333 
6334         for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6335             dst.push_back(Value(scalars[scalarNdx]));
6336 
6337         if (Columns == Rows)
6338         {
6339             Value mat(T(0.0));
6340             T x       = T(1.0f);
6341             mat[0][0] = x;
6342             for (int ndx = 0; ndx < Columns; ++ndx)
6343             {
6344                 mat[Columns - 1 - ndx][ndx] = x;
6345                 x                           = static_cast<T>(x * static_cast<T>(2.0f));
6346             }
6347             dst.push_back(mat);
6348         }
6349     }
6350 };
6351 
//! Per-case configuration shared by the precision test case and its instance.
struct CaseContext
{
    CaseContext(const string &name_, TestContext &testContext_, const FloatFormat &floatFormat_,
                const FloatFormat &highpFormat_, const Precision precision_, const ShaderType shaderType_,
                const size_t numRandoms_,
                const PrecisionTestFeatures precisionTestFeatures_ = PRECISION_TEST_FEATURES_NONE,
                const bool isPackFloat16b_ = false, const bool isFloat64b_ = false)
        : name(name_)
        , testContext(testContext_)
        , floatFormat(floatFormat_)
        , highpFormat(highpFormat_)
        , precision(precision_)
        , shaderType(shaderType_)
        , numRandoms(numRandoms_)
        , inputRange(-TCU_INFINITY, TCU_INFINITY)
        , precisionTestFeatures(precisionTestFeatures_)
        , isPackFloat16b(isPackFloat16b_)
        , isFloat64b(isFloat64b_)
    {
    }

    // Test case name.
    string name;
    // Referenced test context; its command line provides the base random seed.
    TestContext &testContext;
    // Floating-point format used to generate inputs.
    FloatFormat floatFormat;
    // Highp format used alongside floatFormat when checking results.
    FloatFormat highpFormat;
    // Precision qualifier of the case (PRECISION_LAST is used for FP16 tests).
    Precision precision;
    // Shader stage the executor is created for.
    ShaderType shaderType;
    // Number of random inputs generated in addition to the fixed ones.
    size_t numRandoms;
    // Range that generated inputs are restricted to; starts as (-inf, inf).
    Interval inputRange;
    // Device features the case requires.
    PrecisionTestFeatures precisionTestFeatures;
    // NOTE(review): flags set by specific case factories (packHalf2x16 and fp64
    // cases, judging by the names) - confirm at their use sites.
    bool isPackFloat16b;
    bool isFloat64b;
};
6385 
//! Type list of up to four input types; unused slots default to Void.
template <typename In0_ = Void, typename In1_ = Void, typename In2_ = Void, typename In3_ = Void>
struct InTypes
{
    typedef In0_ In0;
    typedef In1_ In1;
    typedef In2_ In2;
    typedef In3_ In3;
};
6394 
6395 template <typename In>
numInputs(void)6396 int numInputs(void)
6397 {
6398     return (!isTypeValid<typename In::In0>() ? 0 :
6399             !isTypeValid<typename In::In1>() ? 1 :
6400             !isTypeValid<typename In::In2>() ? 2 :
6401             !isTypeValid<typename In::In3>() ? 3 :
6402                                                4);
6403 }
6404 
//! Type list of one or two output types; the second slot defaults to Void.
template <typename Out0_, typename Out1_ = Void>
struct OutTypes
{
    typedef Out0_ Out0;
    typedef Out1_ Out1;
};
6411 
6412 template <typename Out>
numOutputs(void)6413 int numOutputs(void)
6414 {
6415     return (!isTypeValid<typename Out::Out0>() ? 0 : !isTypeValid<typename Out::Out1>() ? 1 : 2);
6416 }
6417 
//! Generated input values, one vector per input slot.
template <typename In>
struct Inputs
{
    vector<typename In::In0> in0;
    vector<typename In::In1> in1;
    vector<typename In::In2> in2;
    vector<typename In::In3> in3;
};

//! Output buffers; both vectors are created with `size` value-initialized elements.
template <typename Out>
struct Outputs
{
    Outputs(size_t size) : out0(size), out1(size)
    {
    }

    vector<typename Out::Out0> out0;
    vector<typename Out::Out1> out1;
};

//! Expression variables bound to each input and output slot.
template <typename In, typename Out>
struct Variables
{
    VariableP<typename In::In0> in0;
    VariableP<typename In::In1> in1;
    VariableP<typename In::In2> in2;
    VariableP<typename In::In3> in3;
    VariableP<typename Out::Out0> out0;
    VariableP<typename Out::Out1> out1;
};

//! One sampler per input slot. Only references are stored, so the samplers must
//! outlive this object.
template <typename In>
struct Samplings
{
    Samplings(const Sampling<typename In::In0> &in0_, const Sampling<typename In::In1> &in1_,
              const Sampling<typename In::In2> &in2_, const Sampling<typename In::In3> &in3_)
        : in0(in0_)
        , in1(in1_)
        , in2(in2_)
        , in3(in3_)
    {
    }

    const Sampling<typename In::In0> &in0;
    const Sampling<typename In::In1> &in1;
    const Sampling<typename In::In2> &in2;
    const Sampling<typename In::In3> &in3;
};

//! Samplings that use DefaultSampling<> (via instance<>()) for every slot.
template <typename In>
struct DefaultSamplings : Samplings<In>
{
    DefaultSamplings(void)
        : Samplings<In>(instance<DefaultSampling<typename In::In0>>(), instance<DefaultSampling<typename In::In1>>(),
                        instance<DefaultSampling<typename In::In2>>(), instance<DefaultSampling<typename In::In3>>())
    {
    }
};
6476 
//! Test instance that generates inputs, runs the statement in a shader executor
//! and compares the results against the interval reference.
//! \param modularOp  When true, iterate() asserts the operation has exactly two
//!                   inputs and one output.
//! NOTE(review): `samplings` is stored by reference, so the caller must keep it
//! alive for the lifetime of the instance.
template <typename In, typename Out>
class BuiltinPrecisionCaseTestInstance : public TestInstance
{
public:
    BuiltinPrecisionCaseTestInstance(Context &context, const CaseContext caseCtx, const ShaderSpec &shaderSpec,
                                     const Variables<In, Out> variables, const Samplings<In> &samplings,
                                     const StatementP stmt, bool modularOp = false)
        : TestInstance(context)
        , m_caseCtx(caseCtx)
        , m_variables(variables)
        , m_samplings(samplings)
        , m_stmt(stmt)
        , m_executor(createExecutor(context, caseCtx.shaderType, shaderSpec))
        , m_modularOp(modularOp)
    {
    }
    virtual tcu::TestStatus iterate(void);

protected:
    CaseContext m_caseCtx;
    Variables<In, Out> m_variables;
    const Samplings<In> &m_samplings;
    StatementP m_stmt;
    // Executor created for caseCtx.shaderType; owned by this instance.
    de::UniquePtr<ShaderExecutor> m_executor;
    bool m_modularOp;
};
6503 
// Runs one precision case:
//  1. generates the input samples (fixed values plus random ones);
//  2. executes the shader on all of them in a single batch;
//  3. for each sample, evaluates m_stmt on the host with interval values (Traits<>::IVal)
//     and checks that every shader output lies inside the resulting reference range.
// At most maxMsgs failing samples are logged in detail (all samples if GLS_LOG_ALL_RESULTS).
// Returns fail with the number of failing samples if any sample failed, otherwise pass.
template <class In, class Out>
tcu::TestStatus BuiltinPrecisionCaseTestInstance<In, Out>::iterate(void)
{
    typedef typename In::In0 In0;
    typedef typename In::In1 In1;
    typedef typename In::In2 In2;
    typedef typename In::In3 In3;
    typedef typename Out::Out0 Out0;
    typedef typename Out::Out1 Out1;

    // NOTE(review): the return value is not checked here; presumably this throws
    // (e.g. NotSupported) when a required feature is missing — confirm.
    areFeaturesSupported(m_context, m_caseCtx.precisionTestFeatures);
    // The seed is derived from the command-line base seed, so a given seed reproduces
    // the same sample set.
    Inputs<In> inputs =
        generateInputs(m_samplings, m_caseCtx.floatFormat, m_caseCtx.precision, m_caseCtx.numRandoms,
                       0xdeadbeefu + m_caseCtx.testContext.getCommandLine().getBaseSeed(), m_caseCtx.inputRange);
    const FloatFormat &fmt = m_caseCtx.floatFormat;
    const int inCount      = numInputs<In>();
    const int outCount     = numOutputs<Out>();
    // With no inputs the shader is still run (and the statement evaluated) once.
    const size_t numValues = (inCount > 0) ? inputs.in0.size() : 1;
    Outputs<Out> outputs(numValues);
    const FloatFormat highpFmt = m_caseCtx.highpFormat;
    const int maxMsgs          = 100;
    int numErrors              = 0;
    Environment env; // Hoisted out of the inner loop for optimization.
    // Collects check results; note that the final verdict below is based on numErrors only.
    ResultCollector status;
    TestLog &testLog = m_context.getTestContext().getLog();

    // Module operations need exactly two inputs and have exactly one output.
    if (m_modularOp)
    {
        DE_ASSERT(inCount == 2);
        DE_ASSERT(outCount == 1);
    }

    // Base pointers for all four input and both output slots; all of them are handed
    // to the executor regardless of inCount/outCount.
    const void *inputArr[] = {
        inputs.in0.data(),
        inputs.in1.data(),
        inputs.in2.data(),
        inputs.in3.data(),
    };
    void *outputArr[] = {
        outputs.out0.data(),
        outputs.out1.data(),
    };

    // Print out the statement and its definitions
    testLog << TestLog::Message << "Statement: " << m_stmt << TestLog::EndMessage;
    {
        ostringstream oss;
        FuncSet funcs;

        m_stmt->getUsedFuncs(funcs);
        for (FuncSet::const_iterator it = funcs.begin(); it != funcs.end(); ++it)
        {
            (*it)->printDefinition(oss);
        }
        if (!funcs.empty())
            testLog << TestLog::Message << "Reference definitions:\n" << oss.str() << TestLog::EndMessage;
    }
    // Sanity check: every used input column holds exactly numValues entries.
    switch (inCount)
    {
    case 4:
        DE_ASSERT(inputs.in3.size() == numValues);
    // Fallthrough
    case 3:
        DE_ASSERT(inputs.in2.size() == numValues);
    // Fallthrough
    case 2:
        DE_ASSERT(inputs.in1.size() == numValues);
    // Fallthrough
    case 1:
        DE_ASSERT(inputs.in0.size() == numValues);
    // Fallthrough
    default:
        break;
    }

    // Run the shader once over the whole batch of samples.
    m_executor->execute(int(numValues), inputArr, outputArr);

    // Initialize environment with unused values so we don't need to bind in inner loop.
    {
        const typename Traits<In0>::IVal in0;
        const typename Traits<In1>::IVal in1;
        const typename Traits<In2>::IVal in2;
        const typename Traits<In3>::IVal in3;
        const typename Traits<Out0>::IVal reference0;
        const typename Traits<Out1>::IVal reference1;

        env.bind(*m_variables.in0, in0);
        env.bind(*m_variables.in1, in1);
        env.bind(*m_variables.in2, in2);
        env.bind(*m_variables.in3, in3);
        env.bind(*m_variables.out0, reference0);
        env.bind(*m_variables.out1, reference1);
    }

    // For each input tuple, compute output reference interval and compare
    // shader output to the reference.
    for (size_t valueNdx = 0; valueNdx < numValues; valueNdx++)
    {
        bool result             = true;
        const bool isInput16Bit = m_executor->areInputs16Bit();
        const bool isInput64Bit = m_executor->areInputs64Bit();

        DE_ASSERT(!(isInput16Bit && isInput64Bit));

        typename Traits<Out0>::IVal reference0;
        typename Traits<Out1>::IVal reference1;

        // Keep the watchdog alive during long verification loops.
        if (valueNdx % (size_t)TOUCH_WATCHDOG_VALUE_FREQUENCY == 0)
            m_context.getTestContext().touchWatchdog();

        // Reference inputs: each sample value rounded to, and converted into, fmt.
        env.lookup(*m_variables.in0) = convert<In0>(fmt, round(fmt, inputs.in0[valueNdx]));
        env.lookup(*m_variables.in1) = convert<In1>(fmt, round(fmt, inputs.in1[valueNdx]));
        env.lookup(*m_variables.in2) = convert<In2>(fmt, round(fmt, inputs.in2[valueNdx]));
        env.lookup(*m_variables.in3) = convert<In3>(fmt, round(fmt, inputs.in3[valueNdx]));

        {
            EvalContext ctx(fmt, m_caseCtx.precision, env, 0);
            m_stmt->execute(ctx);

            // Compare used outputs, highest index first (cases fall through).
            switch (outCount)
            {
            case 2:
                // Output 1 has no failed() retry, unlike output 0 below.
                reference1 = convert<Out1>(highpFmt, env.lookup(*m_variables.out1));
                if (!status.check(contains(reference1, outputs.out1[valueNdx], m_caseCtx.isPackFloat16b),
                                  "Shader output 1 is outside acceptable range"))
                    result = false;
            // Fallthrough
            case 1:
            {
                // Pass b from mod(a, b) if we are in the modulo operation.
                const tcu::Maybe<In1> modularDivisor = (m_modularOp ? tcu::just(inputs.in1[valueNdx]) : tcu::Nothing);

                reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
                if (!status.check(
                        contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor),
                        "Shader output 0 is outside acceptable range"))
                {
                    // On mismatch, let the statement update the reference via failed()
                    // (presumably a relaxed/alternative evaluation — confirm) and re-check
                    // before counting the sample as failed.
                    m_stmt->failed(ctx);
                    reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
                    if (!status.check(
                            contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor),
                            "Shader output 0 is outside acceptable range"))
                        result = false;
                }
            }
            // Fallthrough
            default:
                break;
            }
        }
        if (!result)
            ++numErrors;

        // Log the first maxMsgs failures, or every sample when GLS_LOG_ALL_RESULTS is set.
        if ((!result && numErrors <= maxMsgs) || GLS_LOG_ALL_RESULTS)
        {
            MessageBuilder builder = testLog.message();

            builder << (result ? "Passed" : "Failed") << " sample:\n";

            if (inCount > 0)
            {
                builder << "\t" << m_variables.in0->getName() << " = "
                        << (isInput64Bit ? value64ToString(highpFmt, inputs.in0[valueNdx]) :
                                           (isInput16Bit ? value16ToString(highpFmt, inputs.in0[valueNdx]) :
                                                           value32ToString(highpFmt, inputs.in0[valueNdx])))
                        << "\n";
            }

            if (inCount > 1)
            {
                builder << "\t" << m_variables.in1->getName() << " = "
                        << (isInput64Bit ? value64ToString(highpFmt, inputs.in1[valueNdx]) :
                                           (isInput16Bit ? value16ToString(highpFmt, inputs.in1[valueNdx]) :
                                                           value32ToString(highpFmt, inputs.in1[valueNdx])))
                        << "\n";
            }

            if (inCount > 2)
            {
                builder << "\t" << m_variables.in2->getName() << " = "
                        << (isInput64Bit ? value64ToString(highpFmt, inputs.in2[valueNdx]) :
                                           (isInput16Bit ? value16ToString(highpFmt, inputs.in2[valueNdx]) :
                                                           value32ToString(highpFmt, inputs.in2[valueNdx])))
                        << "\n";
            }

            if (inCount > 3)
            {
                builder << "\t" << m_variables.in3->getName() << " = "
                        << (isInput64Bit ? value64ToString(highpFmt, inputs.in3[valueNdx]) :
                                           (isInput16Bit ? value16ToString(highpFmt, inputs.in3[valueNdx]) :
                                                           value32ToString(highpFmt, inputs.in3[valueNdx])))
                        << "\n";
            }

            if (outCount > 0)
            {
                // Comparison-type SPIR-V cases use dedicated message formatting.
                if (m_executor->spirvCase() == SPIRV_CASETYPE_COMPARE)
                {
                    builder << "Output:\n"
                            << comparisonMessage(outputs.out0[valueNdx]) << "Expected result:\n"
                            << comparisonMessageInterval<typename Out::Out0>(reference0) << "\n";
                }
                else
                {
                    builder << "\t" << m_variables.out0->getName() << " = "
                            << (m_executor->isOutput64Bit(0u) ?
                                    value64ToString(highpFmt, outputs.out0[valueNdx]) :
                                    (m_executor->isOutput16Bit(0u) || m_caseCtx.isPackFloat16b ?
                                         value16ToString(highpFmt, outputs.out0[valueNdx]) :
                                         value32ToString(highpFmt, outputs.out0[valueNdx])))
                            << "\n"
                            << "\tExpected range: " << intervalToString<typename Out::Out0>(highpFmt, reference0)
                            << "\n";
                }
            }

            if (outCount > 1)
            {
                builder << "\t" << m_variables.out1->getName() << " = "
                        << (m_executor->isOutput64Bit(1u) ? value64ToString(highpFmt, outputs.out1[valueNdx]) :
                                                            (m_executor->isOutput16Bit(1u) || m_caseCtx.isPackFloat16b ?
                                                                 value16ToString(highpFmt, outputs.out1[valueNdx]) :
                                                                 value32ToString(highpFmt, outputs.out1[valueNdx])))
                        << "\n"
                        << "\tExpected range: " << intervalToString<typename Out::Out1>(highpFmt, reference1) << "\n";
            }

            builder << TestLog::EndMessage;
        }
    }

    if (numErrors > maxMsgs)
    {
        testLog << TestLog::Message << "(Skipped " << (numErrors - maxMsgs) << " messages.)" << TestLog::EndMessage;
    }

    if (numErrors == 0)
    {
        testLog << TestLog::Message << "All " << numValues << " inputs passed." << TestLog::EndMessage;
    }
    else
    {
        testLog << TestLog::Message << numErrors << "/" << numValues << " inputs failed." << TestLog::EndMessage;
    }

    if (numErrors)
        return tcu::TestStatus::fail(de::toString(numErrors) + string(" test failed. Check log for the details"));
    else
        return tcu::TestStatus::pass("Pass");
}
6756 
6757 class PrecisionCase : public TestCase
6758 {
6759 protected:
PrecisionCase(const CaseContext & context,const string & name,const Interval & inputRange,const string & extension="")6760     PrecisionCase(const CaseContext &context, const string &name, const Interval &inputRange,
6761                   const string &extension = "")
6762         : TestCase(context.testContext, name.c_str())
6763         , m_ctx(context)
6764         , m_extension(extension)
6765     {
6766         m_ctx.inputRange      = inputRange;
6767         m_spec.packFloat16Bit = context.isPackFloat16b;
6768     }
6769 
initPrograms(vk::SourceCollections & programCollection) const6770     virtual void initPrograms(vk::SourceCollections &programCollection) const
6771     {
6772         generateSources(m_ctx.shaderType, m_spec, programCollection);
6773     }
6774 
getFormat(void) const6775     const FloatFormat &getFormat(void) const
6776     {
6777         return m_ctx.floatFormat;
6778     }
6779 
6780     template <typename In, typename Out>
6781     void testStatement(const Variables<In, Out> &variables, const Statement &stmt, SpirVCaseT spirvCase);
6782 
6783     template <typename T>
makeSymbol(const Variable<T> & variable)6784     Symbol makeSymbol(const Variable<T> &variable)
6785     {
6786         return Symbol(variable.getName(), getVarTypeOf<T>(m_ctx.precision));
6787     }
6788 
6789     CaseContext m_ctx;
6790     const string m_extension;
6791     ShaderSpec m_spec;
6792 };
6793 
6794 template <typename In, typename Out>
testStatement(const Variables<In,Out> & variables,const Statement & stmt,SpirVCaseT spirvCase)6795 void PrecisionCase::testStatement(const Variables<In, Out> &variables, const Statement &stmt, SpirVCaseT spirvCase)
6796 {
6797     const int inCount  = numInputs<In>();
6798     const int outCount = numOutputs<Out>();
6799     Environment env; // Hoisted out of the inner loop for optimization.
6800 
6801     // Initialize ShaderSpec from precision, variables and statement.
6802     if (m_ctx.precision != glu::PRECISION_LAST)
6803     {
6804         ostringstream os;
6805         os << "precision " << glu::getPrecisionName(m_ctx.precision) << " float;\n";
6806         m_spec.globalDeclarations = os.str();
6807     }
6808 
6809     if (!m_extension.empty())
6810         m_spec.globalDeclarations = "#extension " + m_extension + " : require\n";
6811 
6812     m_spec.inputs.resize(inCount);
6813 
6814     switch (inCount)
6815     {
6816     case 4:
6817         m_spec.inputs[3] = makeSymbol(*variables.in3);
6818     // Fallthrough
6819     case 3:
6820         m_spec.inputs[2] = makeSymbol(*variables.in2);
6821     // Fallthrough
6822     case 2:
6823         m_spec.inputs[1] = makeSymbol(*variables.in1);
6824     // Fallthrough
6825     case 1:
6826         m_spec.inputs[0] = makeSymbol(*variables.in0);
6827     // Fallthrough
6828     default:
6829         break;
6830     }
6831 
6832     bool inputs16Bit = false;
6833     for (vector<Symbol>::const_iterator symIter = m_spec.inputs.begin(); symIter != m_spec.inputs.end(); ++symIter)
6834         inputs16Bit = inputs16Bit || glu::isDataTypeFloat16OrVec(symIter->varType.getBasicType());
6835 
6836     if (inputs16Bit || m_spec.packFloat16Bit)
6837         m_spec.globalDeclarations += "#extension GL_EXT_shader_explicit_arithmetic_types: require\n";
6838 
6839     m_spec.outputs.resize(outCount);
6840 
6841     switch (outCount)
6842     {
6843     case 2:
6844         m_spec.outputs[1] = makeSymbol(*variables.out1);
6845     // Fallthrough
6846     case 1:
6847         m_spec.outputs[0] = makeSymbol(*variables.out0);
6848     // Fallthrough
6849     default:
6850         break;
6851     }
6852 
6853     m_spec.source    = de::toString(stmt);
6854     m_spec.spirvCase = spirvCase;
6855 }
6856 
// Generic ordering functor used to de-duplicate generated inputs; falls back to
// operator< of T. Specializations below handle floats (NaN), vectors, matrices and tuples.
template <typename T>
struct InputLess
{
    bool operator()(const T &lhs, const T &rhs) const
    {
        const bool isLess = lhs < rhs;
        return isLess;
    }
};
6865 
6866 template <typename T>
inputLess(const T & val1,const T & val2)6867 bool inputLess(const T &val1, const T &val2)
6868 {
6869     return InputLess<T>()(val1, val2);
6870 }
6871 
6872 template <>
6873 struct InputLess<float>
6874 {
operator ()vkt::shaderexecutor::InputLess6875     bool operator()(const float &val1, const float &val2) const
6876     {
6877         if (deIsNaN(val1))
6878             return false;
6879         if (deIsNaN(val2))
6880             return true;
6881         return val1 < val2;
6882     }
6883 };
6884 
6885 template <typename T, int Size>
6886 struct InputLess<Vector<T, Size>>
6887 {
operator ()vkt::shaderexecutor::InputLess6888     bool operator()(const Vector<T, Size> &vec1, const Vector<T, Size> &vec2) const
6889     {
6890         for (int ndx = 0; ndx < Size; ++ndx)
6891         {
6892             if (inputLess(vec1[ndx], vec2[ndx]))
6893                 return true;
6894             if (inputLess(vec2[ndx], vec1[ndx]))
6895                 return false;
6896         }
6897 
6898         return false;
6899     }
6900 };
6901 
6902 template <typename T, int Rows, int Cols>
6903 struct InputLess<Matrix<T, Rows, Cols>>
6904 {
operator ()vkt::shaderexecutor::InputLess6905     bool operator()(const Matrix<T, Rows, Cols> &mat1, const Matrix<T, Rows, Cols> &mat2) const
6906     {
6907         for (int col = 0; col < Cols; ++col)
6908         {
6909             if (inputLess(mat1[col], mat2[col]))
6910                 return true;
6911             if (inputLess(mat2[col], mat1[col]))
6912                 return false;
6913         }
6914 
6915         return false;
6916     }
6917 };
6918 
6919 template <typename In>
6920 struct InTuple : public Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6921 {
InTuplevkt::shaderexecutor::InTuple6922     InTuple(const typename In::In0 &in0, const typename In::In1 &in1, const typename In::In2 &in2,
6923             const typename In::In3 &in3)
6924         : Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>(in0, in1, in2, in3)
6925     {
6926     }
6927 };
6928 
6929 template <typename In>
6930 struct InputLess<InTuple<In>>
6931 {
operator ()vkt::shaderexecutor::InputLess6932     bool operator()(const InTuple<In> &in1, const InTuple<In> &in2) const
6933     {
6934         if (inputLess(in1.a, in2.a))
6935             return true;
6936         if (inputLess(in2.a, in1.a))
6937             return false;
6938         if (inputLess(in1.b, in2.b))
6939             return true;
6940         if (inputLess(in2.b, in1.b))
6941             return false;
6942         if (inputLess(in1.c, in2.c))
6943             return true;
6944         if (inputLess(in2.c, in1.c))
6945             return false;
6946         if (inputLess(in1.d, in2.d))
6947             return true;
6948         return false;
6949     }
6950 };
6951 
6952 template <typename In>
generateInputs(const Samplings<In> & samplings,const FloatFormat & floatFormat,Precision intPrecision,size_t numSamples,uint32_t seed,const Interval & inputRange)6953 Inputs<In> generateInputs(const Samplings<In> &samplings, const FloatFormat &floatFormat, Precision intPrecision,
6954                           size_t numSamples, uint32_t seed, const Interval &inputRange)
6955 {
6956     Random rnd(seed);
6957     Inputs<In> ret;
6958     Inputs<In> fixedInputs;
6959     set<InTuple<In>, InputLess<InTuple<In>>> seenInputs;
6960 
6961     samplings.in0.genFixeds(floatFormat, intPrecision, fixedInputs.in0, inputRange);
6962     samplings.in1.genFixeds(floatFormat, intPrecision, fixedInputs.in1, inputRange);
6963     samplings.in2.genFixeds(floatFormat, intPrecision, fixedInputs.in2, inputRange);
6964     samplings.in3.genFixeds(floatFormat, intPrecision, fixedInputs.in3, inputRange);
6965 
6966     for (size_t ndx0 = 0; ndx0 < fixedInputs.in0.size(); ++ndx0)
6967     {
6968         for (size_t ndx1 = 0; ndx1 < fixedInputs.in1.size(); ++ndx1)
6969         {
6970             for (size_t ndx2 = 0; ndx2 < fixedInputs.in2.size(); ++ndx2)
6971             {
6972                 for (size_t ndx3 = 0; ndx3 < fixedInputs.in3.size(); ++ndx3)
6973                 {
6974                     const InTuple<In> tuple(fixedInputs.in0[ndx0], fixedInputs.in1[ndx1], fixedInputs.in2[ndx2],
6975                                             fixedInputs.in3[ndx3]);
6976 
6977                     seenInputs.insert(tuple);
6978                     ret.in0.push_back(tuple.a);
6979                     ret.in1.push_back(tuple.b);
6980                     ret.in2.push_back(tuple.c);
6981                     ret.in3.push_back(tuple.d);
6982                 }
6983             }
6984         }
6985     }
6986 
6987     for (size_t ndx = 0; ndx < numSamples; ++ndx)
6988     {
6989         const typename In::In0 in0 = samplings.in0.genRandom(floatFormat, intPrecision, rnd, inputRange);
6990         const typename In::In1 in1 = samplings.in1.genRandom(floatFormat, intPrecision, rnd, inputRange);
6991         const typename In::In2 in2 = samplings.in2.genRandom(floatFormat, intPrecision, rnd, inputRange);
6992         const typename In::In3 in3 = samplings.in3.genRandom(floatFormat, intPrecision, rnd, inputRange);
6993         const InTuple<In> tuple(in0, in1, in2, in3);
6994 
6995         if (de::contains(seenInputs, tuple))
6996             continue;
6997 
6998         seenInputs.insert(tuple);
6999         ret.in0.push_back(in0);
7000         ret.in1.push_back(in1);
7001         ret.in2.push_back(in2);
7002         ret.in3.push_back(in3);
7003     }
7004 
7005     return ret;
7006 }
7007 
7008 class FuncCaseBase : public PrecisionCase
7009 {
7010 protected:
FuncCaseBase(const CaseContext & context,const string & name,const FuncBase & func)7011     FuncCaseBase(const CaseContext &context, const string &name, const FuncBase &func)
7012         : PrecisionCase(context, name,
7013                         func.getInputRange(!context.isFloat64b &&
7014                                            (context.precision == glu::PRECISION_LAST || context.isPackFloat16b)),
7015                         func.getRequiredExtension())
7016     {
7017     }
7018 
7019     StatementP m_stmt;
7020 };
7021 
7022 template <typename Sig>
7023 class FuncCase : public FuncCaseBase
7024 {
7025 public:
7026     typedef Func<Sig> CaseFunc;
7027     typedef typename Sig::Ret Ret;
7028     typedef typename Sig::Arg0 Arg0;
7029     typedef typename Sig::Arg1 Arg1;
7030     typedef typename Sig::Arg2 Arg2;
7031     typedef typename Sig::Arg3 Arg3;
7032     typedef InTypes<Arg0, Arg1, Arg2, Arg3> In;
7033     typedef OutTypes<Ret> Out;
7034 
FuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)7035     FuncCase(const CaseContext &context, const string &name, const CaseFunc &func, bool modularOp = false)
7036         : FuncCaseBase(context, name, func)
7037         , m_func(func)
7038         , m_modularOp(modularOp)
7039     {
7040         buildTest();
7041     }
7042 
createInstance(Context & context) const7043     virtual TestInstance *createInstance(Context &context) const
7044     {
7045         return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(),
7046                                                              m_stmt, m_modularOp);
7047     }
7048 
7049 protected:
7050     void buildTest(void);
getSamplings(void) const7051     virtual const Samplings<In> &getSamplings(void) const
7052     {
7053         return instance<DefaultSamplings<In>>();
7054     }
7055 
7056 private:
7057     const CaseFunc &m_func;
7058     Variables<In, Out> m_variables;
7059     bool m_modularOp;
7060 };
7061 
7062 template <typename Sig>
buildTest(void)7063 void FuncCase<Sig>::buildTest(void)
7064 {
7065     m_variables.out0 = variable<Ret>("out0");
7066     m_variables.out1 = variable<Void>("out1");
7067     m_variables.in0  = variable<Arg0>("in0");
7068     m_variables.in1  = variable<Arg1>("in1");
7069     m_variables.in2  = variable<Arg2>("in2");
7070     m_variables.in3  = variable<Arg3>("in3");
7071 
7072     {
7073         ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.in1, m_variables.in2, m_variables.in3);
7074         m_stmt          = variableAssignment(m_variables.out0, expr);
7075 
7076         this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
7077     }
7078 }
7079 
7080 template <typename Sig>
7081 class InOutFuncCase : public FuncCaseBase
7082 {
7083 public:
7084     typedef Func<Sig> CaseFunc;
7085     typedef typename Sig::Ret Ret;
7086     typedef typename Sig::Arg0 Arg0;
7087     typedef typename Sig::Arg1 Arg1;
7088     typedef typename Sig::Arg2 Arg2;
7089     typedef typename Sig::Arg3 Arg3;
7090     typedef InTypes<Arg0, Arg2, Arg3> In;
7091     typedef OutTypes<Ret, Arg1> Out;
7092 
InOutFuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)7093     InOutFuncCase(const CaseContext &context, const string &name, const CaseFunc &func, bool modularOp = false)
7094         : FuncCaseBase(context, name, func)
7095         , m_func(func)
7096         , m_modularOp(modularOp)
7097     {
7098         buildTest();
7099     }
createInstance(Context & context) const7100     virtual TestInstance *createInstance(Context &context) const
7101     {
7102         return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(),
7103                                                              m_stmt, m_modularOp);
7104     }
7105 
7106 protected:
7107     void buildTest(void);
getSamplings(void) const7108     virtual const Samplings<In> &getSamplings(void) const
7109     {
7110         return instance<DefaultSamplings<In>>();
7111     }
7112 
7113 private:
7114     const CaseFunc &m_func;
7115     Variables<In, Out> m_variables;
7116     bool m_modularOp;
7117 };
7118 
7119 template <typename Sig>
buildTest(void)7120 void InOutFuncCase<Sig>::buildTest(void)
7121 {
7122     m_variables.out0 = variable<Ret>("out0");
7123     m_variables.out1 = variable<Arg1>("out1");
7124     m_variables.in0  = variable<Arg0>("in0");
7125     m_variables.in1  = variable<Arg2>("in1");
7126     m_variables.in2  = variable<Arg3>("in2");
7127     m_variables.in3  = variable<Void>("in3");
7128 
7129     {
7130         ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.out1, m_variables.in1, m_variables.in2);
7131         m_stmt          = variableAssignment(m_variables.out0, expr);
7132 
7133         this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
7134     }
7135 }
7136 
7137 template <typename Sig>
createFuncCase(const CaseContext & context,const string & name,const Func<Sig> & func,bool modularOp=false)7138 PrecisionCase *createFuncCase(const CaseContext &context, const string &name, const Func<Sig> &func,
7139                               bool modularOp = false)
7140 {
7141     switch (func.getOutParamIndex())
7142     {
7143     case -1:
7144         return new FuncCase<Sig>(context, name, func, modularOp);
7145     case 1:
7146         return new InOutFuncCase<Sig>(context, name, func, modularOp);
7147     default:
7148         DE_FATAL("Impossible");
7149     }
7150     return DE_NULL;
7151 }
7152 
7153 class CaseFactory
7154 {
7155 public:
~CaseFactory(void)7156     virtual ~CaseFactory(void)
7157     {
7158     }
7159     virtual MovePtr<TestNode> createCase(const CaseContext &ctx) const = 0;
7160     virtual string getName(void) const                                 = 0;
7161     virtual string getDesc(void) const                                 = 0;
7162 };
7163 
7164 class FuncCaseFactory : public CaseFactory
7165 {
7166 public:
7167     virtual const FuncBase &getFunc(void) const = 0;
getName(void) const7168     string getName(void) const
7169     {
7170         return de::toLower(getFunc().getName());
7171     }
getDesc(void) const7172     string getDesc(void) const
7173     {
7174         return "Function '" + getFunc().getName() + "'";
7175     }
7176 };
7177 
7178 template <typename Sig>
7179 class GenFuncCaseFactory : public CaseFactory
7180 {
7181 public:
GenFuncCaseFactory(const GenFuncs<Sig> & funcs,const string & name,bool modularOp=false)7182     GenFuncCaseFactory(const GenFuncs<Sig> &funcs, const string &name, bool modularOp = false)
7183         : m_funcs(funcs)
7184         , m_name(de::toLower(name))
7185         , m_modularOp(modularOp)
7186     {
7187     }
7188 
createCase(const CaseContext & ctx) const7189     MovePtr<TestNode> createCase(const CaseContext &ctx) const
7190     {
7191         TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7192 
7193         group->addChild(createFuncCase(ctx, "scalar", m_funcs.func, m_modularOp));
7194         group->addChild(createFuncCase(ctx, "vec2", m_funcs.func2, m_modularOp));
7195         group->addChild(createFuncCase(ctx, "vec3", m_funcs.func3, m_modularOp));
7196         group->addChild(createFuncCase(ctx, "vec4", m_funcs.func4, m_modularOp));
7197         return MovePtr<TestNode>(group);
7198     }
7199 
getName(void) const7200     string getName(void) const
7201     {
7202         return m_name;
7203     }
getDesc(void) const7204     string getDesc(void) const
7205     {
7206         return "Function '" + m_funcs.func.getName() + "'";
7207     }
7208 
7209 private:
7210     const GenFuncs<Sig> m_funcs;
7211     string m_name;
7212     bool m_modularOp;
7213 };
7214 
7215 template <template <int, class> class GenF, typename T>
7216 class TemplateFuncCaseFactory : public FuncCaseFactory
7217 {
7218 public:
createCase(const CaseContext & ctx) const7219     MovePtr<TestNode> createCase(const CaseContext &ctx) const
7220     {
7221         TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7222 
7223         group->addChild(createFuncCase(ctx, "scalar", instance<GenF<1, T>>()));
7224         group->addChild(createFuncCase(ctx, "vec2", instance<GenF<2, T>>()));
7225         group->addChild(createFuncCase(ctx, "vec3", instance<GenF<3, T>>()));
7226         group->addChild(createFuncCase(ctx, "vec4", instance<GenF<4, T>>()));
7227 
7228         return MovePtr<TestNode>(group);
7229     }
7230 
getFunc(void) const7231     const FuncBase &getFunc(void) const
7232     {
7233         return instance<GenF<1, T>>();
7234     }
7235 };
7236 
7237 #ifndef CTS_USES_VULKANSC
7238 template <template <int> class GenF>
7239 class SquareMatrixFuncCaseFactory : public FuncCaseFactory
7240 {
7241 public:
createCase(const CaseContext & ctx) const7242     MovePtr<TestNode> createCase(const CaseContext &ctx) const
7243     {
7244         TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7245 
7246         group->addChild(createFuncCase(ctx, "mat2", instance<GenF<2>>()));
7247 
7248         // There is no defined precision for mediump/RelaxedPrecision in Vulkan
7249         if (ctx.name != "mediump")
7250         {
7251             static const char dataDir[] = "builtin/precision/square_matrix";
7252             std::string fileName        = getFunc().getName() + "_" + ctx.name;
7253             std::vector<std::string> requirements;
7254 
7255             if (ctx.name == "compute")
7256             {
7257                 if (ctx.isFloat64b)
7258                 {
7259                     requirements.push_back("Features.shaderFloat64");
7260                     fileName += "_fp64";
7261                 }
7262                 else
7263                 {
7264                     requirements.push_back("Float16Int8Features.shaderFloat16");
7265                     requirements.push_back("VK_KHR_16bit_storage");
7266                     requirements.push_back("VK_KHR_storage_buffer_storage_class");
7267                     fileName += "_fp16";
7268 
7269                     if (ctx.isPackFloat16b == true)
7270                     {
7271                         fileName += "_32bit";
7272                     }
7273                     else
7274                     {
7275                         requirements.push_back("Storage16BitFeatures.storageBuffer16BitAccess");
7276                     }
7277                 }
7278             }
7279 
7280             group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat3", "Square matrix 3x3 precision tests",
7281                                                            dataDir, fileName + "_mat_3x3.amber", requirements));
7282             group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat4", "Square matrix 4x4 precision tests",
7283                                                            dataDir, fileName + "_mat_4x4.amber", requirements));
7284         }
7285 
7286         return MovePtr<TestNode>(group);
7287     }
7288 
getFunc(void) const7289     const FuncBase &getFunc(void) const
7290     {
7291         return instance<GenF<2>>();
7292     }
7293 };
7294 #endif // CTS_USES_VULKANSC
7295 
7296 template <template <int, int, class> class GenF, typename T>
7297 class MatrixFuncCaseFactory : public FuncCaseFactory
7298 {
7299 public:
createCase(const CaseContext & ctx) const7300     MovePtr<TestNode> createCase(const CaseContext &ctx) const
7301     {
7302         TestCaseGroup *const group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7303 
7304         this->addCase<2, 2>(ctx, group);
7305         this->addCase<3, 2>(ctx, group);
7306         this->addCase<4, 2>(ctx, group);
7307         this->addCase<2, 3>(ctx, group);
7308         this->addCase<3, 3>(ctx, group);
7309         this->addCase<4, 3>(ctx, group);
7310         this->addCase<2, 4>(ctx, group);
7311         this->addCase<3, 4>(ctx, group);
7312         this->addCase<4, 4>(ctx, group);
7313 
7314         return MovePtr<TestNode>(group);
7315     }
7316 
getFunc(void) const7317     const FuncBase &getFunc(void) const
7318     {
7319         return instance<GenF<2, 2, T>>();
7320     }
7321 
7322 private:
7323     template <int Rows, int Cols>
addCase(const CaseContext & ctx,TestCaseGroup * group) const7324     void addCase(const CaseContext &ctx, TestCaseGroup *group) const
7325     {
7326         const char *const name = dataTypeNameOf<Matrix<float, Rows, Cols>>();
7327         group->addChild(createFuncCase(ctx, name, instance<GenF<Rows, Cols, T>>()));
7328     }
7329 };
7330 
7331 template <typename Sig>
7332 class SimpleFuncCaseFactory : public CaseFactory
7333 {
7334 public:
SimpleFuncCaseFactory(const Func<Sig> & func)7335     SimpleFuncCaseFactory(const Func<Sig> &func) : m_func(func)
7336     {
7337     }
7338 
createCase(const CaseContext & ctx) const7339     MovePtr<TestNode> createCase(const CaseContext &ctx) const
7340     {
7341         return MovePtr<TestNode>(createFuncCase(ctx, ctx.name.c_str(), m_func));
7342     }
getName(void) const7343     string getName(void) const
7344     {
7345         return de::toLower(m_func.getName());
7346     }
getDesc(void) const7347     string getDesc(void) const
7348     {
7349         return "Function '" + getName() + "'";
7350     }
7351 
7352 private:
7353     const Func<Sig> &m_func;
7354 };
7355 
7356 template <typename F>
createSimpleFuncCaseFactory(void)7357 SharedPtr<SimpleFuncCaseFactory<typename F::Sig>> createSimpleFuncCaseFactory(void)
7358 {
7359     return SharedPtr<SimpleFuncCaseFactory<typename F::Sig>>(new SimpleFuncCaseFactory<typename F::Sig>(instance<F>()));
7360 }
7361 
7362 class CaseFactories
7363 {
7364 public:
~CaseFactories(void)7365     virtual ~CaseFactories(void)
7366     {
7367     }
7368     virtual const std::vector<const CaseFactory *> getFactories(void) const = 0;
7369 };
7370 
7371 class BuiltinFuncs : public CaseFactories
7372 {
7373 public:
getFactories(void) const7374     const vector<const CaseFactory *> getFactories(void) const
7375     {
7376         vector<const CaseFactory *> ret;
7377 
7378         for (size_t ndx = 0; ndx < m_factories.size(); ++ndx)
7379             ret.push_back(m_factories[ndx].get());
7380 
7381         return ret;
7382     }
7383 
addFactory(SharedPtr<const CaseFactory> fact)7384     void addFactory(SharedPtr<const CaseFactory> fact)
7385     {
7386         m_factories.push_back(fact);
7387     }
7388 
7389 private:
7390     vector<SharedPtr<const CaseFactory>> m_factories;
7391 };
7392 
7393 template <typename F>
addScalarFactory(BuiltinFuncs & funcs,string name="",bool modularOp=false)7394 void addScalarFactory(BuiltinFuncs &funcs, string name = "", bool modularOp = false)
7395 {
7396     if (name.empty())
7397         name = instance<F>().getName();
7398 
7399     funcs.addFactory(SharedPtr<const CaseFactory>(
7400         new GenFuncCaseFactory<typename F::Sig>(makeVectorizedFuncs<F>(), name, modularOp)));
7401 }
7402 
createBuiltinCases()7403 MovePtr<const CaseFactories> createBuiltinCases()
7404 {
7405     MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7406 
7407     // Tests for ES3 builtins
7408     addScalarFactory<Comparison<Signature<int, float, float>>>(*funcs);
7409     addScalarFactory<Add<Signature<float, float, float>>>(*funcs);
7410     addScalarFactory<Sub<Signature<float, float, float>>>(*funcs);
7411     addScalarFactory<Mul<Signature<float, float, float>>>(*funcs);
7412     addScalarFactory<Div<Signature<float, float, float>>>(*funcs);
7413 
7414     addScalarFactory<Radians>(*funcs);
7415     addScalarFactory<Degrees>(*funcs);
7416     addScalarFactory<Sin<Signature<float, float>>>(*funcs);
7417     addScalarFactory<Cos<Signature<float, float>>>(*funcs);
7418     addScalarFactory<Tan>(*funcs);
7419 
7420     addScalarFactory<ASin>(*funcs);
7421     addScalarFactory<ACos>(*funcs);
7422     addScalarFactory<ATan2<Signature<float, float, float>>>(*funcs, "atan2");
7423     addScalarFactory<ATan<Signature<float, float>>>(*funcs);
7424     addScalarFactory<Sinh>(*funcs);
7425     addScalarFactory<Cosh>(*funcs);
7426     addScalarFactory<Tanh>(*funcs);
7427     addScalarFactory<ASinh>(*funcs);
7428     addScalarFactory<ACosh>(*funcs);
7429     addScalarFactory<ATanh>(*funcs);
7430 
7431     addScalarFactory<Pow>(*funcs);
7432     addScalarFactory<Exp<Signature<float, float>>>(*funcs);
7433     addScalarFactory<Log<Signature<float, float>>>(*funcs);
7434     addScalarFactory<Exp2<Signature<float, float>>>(*funcs);
7435     addScalarFactory<Log2<Signature<float, float>>>(*funcs);
7436     addScalarFactory<Sqrt32Bit>(*funcs);
7437     addScalarFactory<InverseSqrt<Signature<float, float>>>(*funcs);
7438 
7439     addScalarFactory<Abs<Signature<float, float>>>(*funcs);
7440     addScalarFactory<Sign<Signature<float, float>>>(*funcs);
7441     addScalarFactory<Floor32Bit>(*funcs);
7442     addScalarFactory<Trunc32Bit>(*funcs);
7443     addScalarFactory<Round<Signature<float, float>>>(*funcs);
7444     addScalarFactory<RoundEven<Signature<float, float>>>(*funcs);
7445     addScalarFactory<Ceil<Signature<float, float>>>(*funcs);
7446     addScalarFactory<Fract>(*funcs);
7447 
7448     addScalarFactory<Mod32Bit>(*funcs, "mod", true);
7449     addScalarFactory<FRem32Bit>(*funcs);
7450 
7451     addScalarFactory<Modf32Bit>(*funcs);
7452     addScalarFactory<ModfStruct32Bit>(*funcs);
7453     addScalarFactory<Min<Signature<float, float, float>>>(*funcs);
7454     addScalarFactory<Max<Signature<float, float, float>>>(*funcs);
7455     addScalarFactory<Clamp<Signature<float, float, float, float>>>(*funcs);
7456     addScalarFactory<Mix>(*funcs);
7457     addScalarFactory<Step<Signature<float, float, float>>>(*funcs);
7458     addScalarFactory<SmoothStep<Signature<float, float, float, float>>>(*funcs);
7459 
7460     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, float>()));
7461     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, float>()));
7462     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, float>()));
7463     funcs->addFactory(createSimpleFuncCaseFactory<Cross>());
7464     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, float>()));
7465     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, float>()));
7466     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, float>()));
7467     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, float>()));
7468 
7469     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, float>()));
7470     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, float>()));
7471     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, float>()));
7472 #ifndef CTS_USES_VULKANSC
7473     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant>()));
7474     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse>()));
7475 #endif // CTS_USES_VULKANSC
7476 
7477     addScalarFactory<Frexp32Bit>(*funcs);
7478     addScalarFactory<FrexpStruct32Bit>(*funcs);
7479     addScalarFactory<LdExp<Signature<float, float, int>>>(*funcs);
7480     addScalarFactory<Fma<Signature<float, float, float, float>>>(*funcs);
7481 
7482     return MovePtr<const CaseFactories>(funcs.release());
7483 }
7484 
createBuiltinDoubleCases()7485 MovePtr<const CaseFactories> createBuiltinDoubleCases()
7486 {
7487     MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7488 
7489     // Tests for ES3 builtins
7490     addScalarFactory<Comparison<Signature<int, double, double>>>(*funcs);
7491     addScalarFactory<Add<Signature<double, double, double>>>(*funcs);
7492     addScalarFactory<Sub<Signature<double, double, double>>>(*funcs);
7493     addScalarFactory<Mul<Signature<double, double, double>>>(*funcs);
7494     addScalarFactory<Div<Signature<double, double, double>>>(*funcs);
7495 
7496     // Radians, degrees, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, pow, exp, log, exp2 and log2
7497     // only work with 16-bit and 32-bit floating point types according to the spec.
7498 
7499     addScalarFactory<Sqrt64Bit>(*funcs);
7500     addScalarFactory<InverseSqrt<Signature<double, double>>>(*funcs);
7501 
7502     addScalarFactory<Abs<Signature<double, double>>>(*funcs);
7503     addScalarFactory<Sign<Signature<double, double>>>(*funcs);
7504     addScalarFactory<Floor64Bit>(*funcs);
7505     addScalarFactory<Trunc64Bit>(*funcs);
7506     addScalarFactory<Round<Signature<double, double>>>(*funcs);
7507     addScalarFactory<RoundEven<Signature<double, double>>>(*funcs);
7508     addScalarFactory<Ceil<Signature<double, double>>>(*funcs);
7509     addScalarFactory<Fract64Bit>(*funcs);
7510 
7511     addScalarFactory<Mod64Bit>(*funcs, "mod", true);
7512     addScalarFactory<FRem64Bit>(*funcs);
7513 
7514     addScalarFactory<Modf64Bit>(*funcs);
7515     addScalarFactory<ModfStruct64Bit>(*funcs);
7516     addScalarFactory<Min<Signature<double, double, double>>>(*funcs);
7517     addScalarFactory<Max<Signature<double, double, double>>>(*funcs);
7518     addScalarFactory<Clamp<Signature<double, double, double, double>>>(*funcs);
7519     addScalarFactory<Mix64Bit>(*funcs);
7520     addScalarFactory<Step<Signature<double, double, double>>>(*funcs);
7521     addScalarFactory<SmoothStep<Signature<double, double, double, double>>>(*funcs);
7522 
7523     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, double>()));
7524     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, double>()));
7525     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, double>()));
7526     funcs->addFactory(createSimpleFuncCaseFactory<Cross64Bit>());
7527     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, double>()));
7528     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, double>()));
7529     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, double>()));
7530     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, double>()));
7531 
7532     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, double>()));
7533     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, double>()));
7534     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, double>()));
7535 #ifndef CTS_USES_VULKANSC
7536     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant64bit>()));
7537     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse64bit>()));
7538 #endif // CTS_USES_VULKANSC
7539 
7540     addScalarFactory<Frexp64Bit>(*funcs);
7541     addScalarFactory<FrexpStruct64Bit>(*funcs);
7542     addScalarFactory<LdExp<Signature<double, double, int>>>(*funcs);
7543     addScalarFactory<Fma<Signature<double, double, double, double>>>(*funcs);
7544 
7545     return MovePtr<const CaseFactories>(funcs.release());
7546 }
7547 
createBuiltinCases16Bit(void)7548 MovePtr<const CaseFactories> createBuiltinCases16Bit(void)
7549 {
7550     MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7551 
7552     addScalarFactory<Comparison<Signature<int, deFloat16, deFloat16>>>(*funcs);
7553     addScalarFactory<Add<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7554     addScalarFactory<Sub<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7555     addScalarFactory<Mul<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7556     addScalarFactory<Div<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7557 
7558     addScalarFactory<Radians16>(*funcs);
7559     addScalarFactory<Degrees16>(*funcs);
7560 
7561     addScalarFactory<Sin<Signature<deFloat16, deFloat16>>>(*funcs);
7562     addScalarFactory<Cos<Signature<deFloat16, deFloat16>>>(*funcs);
7563     addScalarFactory<Tan16Bit>(*funcs);
7564     addScalarFactory<ASin16Bit>(*funcs);
7565     addScalarFactory<ACos16Bit>(*funcs);
7566     addScalarFactory<ATan2<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs, "atan2");
7567     addScalarFactory<ATan<Signature<deFloat16, deFloat16>>>(*funcs);
7568 
7569     addScalarFactory<Sinh16Bit>(*funcs);
7570     addScalarFactory<Cosh16Bit>(*funcs);
7571     addScalarFactory<Tanh16Bit>(*funcs);
7572     addScalarFactory<ASinh16Bit>(*funcs);
7573     addScalarFactory<ACosh16Bit>(*funcs);
7574     addScalarFactory<ATanh16Bit>(*funcs);
7575 
7576     addScalarFactory<Pow16>(*funcs);
7577     addScalarFactory<Exp<Signature<deFloat16, deFloat16>>>(*funcs);
7578     addScalarFactory<Log<Signature<deFloat16, deFloat16>>>(*funcs);
7579     addScalarFactory<Exp2<Signature<deFloat16, deFloat16>>>(*funcs);
7580     addScalarFactory<Log2<Signature<deFloat16, deFloat16>>>(*funcs);
7581     addScalarFactory<Sqrt16Bit>(*funcs);
7582     addScalarFactory<InverseSqrt16Bit>(*funcs);
7583 
7584     addScalarFactory<Abs<Signature<deFloat16, deFloat16>>>(*funcs);
7585     addScalarFactory<Sign<Signature<deFloat16, deFloat16>>>(*funcs);
7586     addScalarFactory<Floor16Bit>(*funcs);
7587     addScalarFactory<Trunc16Bit>(*funcs);
7588     addScalarFactory<Round<Signature<deFloat16, deFloat16>>>(*funcs);
7589     addScalarFactory<RoundEven<Signature<deFloat16, deFloat16>>>(*funcs);
7590     addScalarFactory<Ceil<Signature<deFloat16, deFloat16>>>(*funcs);
7591     addScalarFactory<Fract16Bit>(*funcs);
7592 
7593     addScalarFactory<Mod16Bit>(*funcs, "mod", true);
7594     addScalarFactory<FRem16Bit>(*funcs);
7595 
7596     addScalarFactory<Modf16Bit>(*funcs);
7597     addScalarFactory<ModfStruct16Bit>(*funcs);
7598     addScalarFactory<Min<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7599     addScalarFactory<Max<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7600     addScalarFactory<Clamp<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7601     addScalarFactory<Mix16Bit>(*funcs);
7602     addScalarFactory<Step<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7603     addScalarFactory<SmoothStep<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7604 
7605     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, deFloat16>()));
7606     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, deFloat16>()));
7607     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, deFloat16>()));
7608     funcs->addFactory(createSimpleFuncCaseFactory<Cross16Bit>());
7609     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, deFloat16>()));
7610     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, deFloat16>()));
7611     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, deFloat16>()));
7612     funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, deFloat16>()));
7613 
7614     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, deFloat16>()));
7615     funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, deFloat16>()));
7616 #ifndef CTS_USES_VULKANSC
7617     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant16bit>()));
7618     funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse16bit>()));
7619 #endif // CTS_USES_VULKANSC
7620 
7621     addScalarFactory<Frexp16Bit>(*funcs);
7622     addScalarFactory<FrexpStruct16Bit>(*funcs);
7623     addScalarFactory<LdExp<Signature<deFloat16, deFloat16, int>>>(*funcs);
7624     addScalarFactory<Fma<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7625 
7626     return MovePtr<const CaseFactories>(funcs.release());
7627 }
7628 
createFuncGroup(TestContext & ctx,const CaseFactory & factory,int numRandoms)7629 TestCaseGroup *createFuncGroup(TestContext &ctx, const CaseFactory &factory, int numRandoms)
7630 {
7631     TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7632     const FloatFormat highp(-126, 127, 23, true,
7633                             tcu::MAYBE,  // subnormals
7634                             tcu::YES,    // infinities
7635                             tcu::MAYBE); // NaN
7636     const FloatFormat mediump(-14, 13, 10, false, tcu::MAYBE);
7637 
7638     for (int precNdx = glu::PRECISION_MEDIUMP; precNdx < glu::PRECISION_LAST; ++precNdx)
7639     {
7640         const Precision precision = Precision(precNdx);
7641         const string precName(glu::getPrecisionName(precision));
7642         const FloatFormat &fmt = precNdx == glu::PRECISION_MEDIUMP ? mediump : highp;
7643 
7644         const CaseContext caseCtx(precName, ctx, fmt, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms);
7645 
7646         group->addChild(factory.createCase(caseCtx).release());
7647     }
7648 
7649     return group;
7650 }
7651 
createFuncGroupDouble(TestContext & ctx,const CaseFactory & factory,int numRandoms)7652 TestCaseGroup *createFuncGroupDouble(TestContext &ctx, const CaseFactory &factory, int numRandoms)
7653 {
7654     TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7655     const Precision precision  = Precision(glu::PRECISION_LAST);
7656     const FloatFormat highp(-1022, 1023, 52, true,
7657                             tcu::MAYBE,  // subnormals
7658                             tcu::YES,    // infinities
7659                             tcu::MAYBE); // NaN
7660 
7661     PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT;
7662 
7663     const CaseContext caseCtx("compute", ctx, highp, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms,
7664                               precisionTestFeatures, false, true);
7665     group->addChild(factory.createCase(caseCtx).release());
7666 
7667     return group;
7668 }
7669 
createFuncGroup16Bit(TestContext & ctx,const CaseFactory & factory,int numRandoms,bool storage32)7670 TestCaseGroup *createFuncGroup16Bit(TestContext &ctx, const CaseFactory &factory, int numRandoms, bool storage32)
7671 {
7672     TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7673     const Precision precision  = Precision(glu::PRECISION_LAST);
7674     const FloatFormat float16(-14, 15, 10, true, tcu::MAYBE);
7675 
7676     PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT;
7677     if (!storage32)
7678         precisionTestFeatures |= PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS;
7679 
7680     const CaseContext caseCtx("compute", ctx, float16, float16, precision, glu::SHADERTYPE_COMPUTE, numRandoms,
7681                               precisionTestFeatures, storage32);
7682     group->addChild(factory.createCase(caseCtx).release());
7683 
7684     return group;
7685 }
7686 
// Default number of random inputs per case; used when the command line does not
// request a positive test iteration count.
const int defRandoms = 16384;
7688 
addBuiltinPrecisionTests(TestContext & ctx,TestCaseGroup & dstGroup,const bool test16Bit=false,const bool storage32Bit=false)7689 void addBuiltinPrecisionTests(TestContext &ctx, TestCaseGroup &dstGroup, const bool test16Bit = false,
7690                               const bool storage32Bit = false)
7691 {
7692     const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7693     const int numRandoms  = userRandoms > 0 ? userRandoms : defRandoms;
7694 
7695     MovePtr<const CaseFactories> cases =
7696         (test16Bit && !storage32Bit) ? createBuiltinCases16Bit() : createBuiltinCases();
7697     for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7698     {
7699         if (!test16Bit)
7700             dstGroup.addChild(createFuncGroup(ctx, *cases->getFactories()[ndx], numRandoms));
7701         else
7702             dstGroup.addChild(createFuncGroup16Bit(ctx, *cases->getFactories()[ndx], numRandoms, storage32Bit));
7703     }
7704 }
7705 
addBuiltinPrecisionDoubleTests(TestContext & ctx,TestCaseGroup & dstGroup)7706 void addBuiltinPrecisionDoubleTests(TestContext &ctx, TestCaseGroup &dstGroup)
7707 {
7708     const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7709     const int numRandoms  = userRandoms > 0 ? userRandoms : defRandoms;
7710 
7711     MovePtr<const CaseFactories> cases = createBuiltinDoubleCases();
7712     for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7713     {
7714         dstGroup.addChild(createFuncGroupDouble(ctx, *cases->getFactories()[ndx], numRandoms));
7715     }
7716 }
7717 
BuiltinPrecisionTests(tcu::TestContext & testCtx)7718 BuiltinPrecisionTests::BuiltinPrecisionTests(tcu::TestContext &testCtx) : tcu::TestCaseGroup(testCtx, "precision")
7719 {
7720 }
7721 
~BuiltinPrecisionTests(void)7722 BuiltinPrecisionTests::~BuiltinPrecisionTests(void)
7723 {
7724 }
7725 
init(void)7726 void BuiltinPrecisionTests::init(void)
7727 {
7728     addBuiltinPrecisionTests(m_testCtx, *this);
7729 }
7730 
BuiltinPrecisionDoubleTests(tcu::TestContext & testCtx)7731 BuiltinPrecisionDoubleTests::BuiltinPrecisionDoubleTests(tcu::TestContext &testCtx)
7732     : tcu::TestCaseGroup(testCtx, "precision_double")
7733 {
7734 }
7735 
~BuiltinPrecisionDoubleTests(void)7736 BuiltinPrecisionDoubleTests::~BuiltinPrecisionDoubleTests(void)
7737 {
7738 }
7739 
init(void)7740 void BuiltinPrecisionDoubleTests::init(void)
7741 {
7742     addBuiltinPrecisionDoubleTests(m_testCtx, *this);
7743 }
7744 
BuiltinPrecision16BitTests(tcu::TestContext & testCtx)7745 BuiltinPrecision16BitTests::BuiltinPrecision16BitTests(tcu::TestContext &testCtx)
7746     : tcu::TestCaseGroup(testCtx, "precision_fp16_storage16b")
7747 {
7748 }
7749 
~BuiltinPrecision16BitTests(void)7750 BuiltinPrecision16BitTests::~BuiltinPrecision16BitTests(void)
7751 {
7752 }
7753 
init(void)7754 void BuiltinPrecision16BitTests::init(void)
7755 {
7756     addBuiltinPrecisionTests(m_testCtx, *this, true);
7757 }
7758 
BuiltinPrecision16Storage32BitTests(tcu::TestContext & testCtx)7759 BuiltinPrecision16Storage32BitTests::BuiltinPrecision16Storage32BitTests(tcu::TestContext &testCtx)
7760     : tcu::TestCaseGroup(testCtx, "precision_fp16_storage32b")
7761 {
7762 }
7763 
~BuiltinPrecision16Storage32BitTests(void)7764 BuiltinPrecision16Storage32BitTests::~BuiltinPrecision16Storage32BitTests(void)
7765 {
7766 }
7767 
init(void)7768 void BuiltinPrecision16Storage32BitTests::init(void)
7769 {
7770     addBuiltinPrecisionTests(m_testCtx, *this, true, true);
7771 }
7772 
7773 } // namespace shaderexecutor
7774 } // namespace vkt
7775