1 /*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2018 The Khronos Group Inc.
6 * Copyright (c) 2015 Samsung Electronics Co., Ltd.
7 * Copyright (c) 2016 The Android Open Source Project
8 *
9 * Licensed under the Apache License, Version 2.0 (the "License");
10 * you may not use this file except in compliance with the License.
11 * You may obtain a copy of the License at
12 *
13 * http://www.apache.org/licenses/LICENSE-2.0
14 *
15 * Unless required by applicable law or agreed to in writing, software
16 * distributed under the License is distributed on an "AS IS" BASIS,
17 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 * See the License for the specific language governing permissions and
19 * limitations under the License.
20 *
21 *//*!
22 * \file
23 * \brief Precision and range tests for builtins and types.
24 *
25 *//*--------------------------------------------------------------------*/
26
27 #include "vktShaderBuiltinPrecisionTests.hpp"
28 #include "vktShaderExecutor.hpp"
29 #include "amber/vktAmberTestCase.hpp"
30
31 #include "deMath.h"
32 #include "deMemory.h"
33 #include "deFloat16.h"
34 #include "deDefs.hpp"
35 #include "deRandom.hpp"
36 #include "deSTLUtil.hpp"
37 #include "deStringUtil.hpp"
38 #include "deUniquePtr.hpp"
39 #include "deSharedPtr.hpp"
40 #include "deArrayUtil.hpp"
41
42 #include "tcuCommandLine.hpp"
43 #include "tcuFloatFormat.hpp"
44 #include "tcuInterval.hpp"
45 #include "tcuTestLog.hpp"
46 #include "tcuVector.hpp"
47 #include "tcuMatrix.hpp"
48 #include "tcuResultCollector.hpp"
49 #include "tcuMaybe.hpp"
50
51 #include "gluContextInfo.hpp"
52 #include "gluVarType.hpp"
53 #include "gluRenderContext.hpp"
54 #include "glwDefs.hpp"
55
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <utility>
63
64 // Uncomment this to get evaluation trace dumps to std::cerr
65 // #define GLS_ENABLE_TRACE
66
67 // set this to true to dump even passing results
68 #define GLS_LOG_ALL_RESULTS false
69
// Bit patterns of frequently used constants encoded as 16-bit (half precision) floats.
#define FLOAT16_1_0 0x3C00 // 1.0 as a 16-bit float
#define FLOAT16_2_0 0x4000 // 2.0 as a 16-bit float
#define FLOAT16_3_0 0x4200 // 3.0 as a 16-bit float
#define FLOAT16_0_5 0x3800 // 0.5 as a 16-bit float
#define FLOAT16_0_0 0x0000 // 0.0 as a 16-bit float
75
76 using tcu::Vector;
77 typedef Vector<deFloat16, 1> Vec1_16Bit;
78 typedef Vector<deFloat16, 2> Vec2_16Bit;
79 typedef Vector<deFloat16, 3> Vec3_16Bit;
80 typedef Vector<deFloat16, 4> Vec4_16Bit;
81
82 typedef Vector<double, 1> Vec1_64Bit;
83 typedef Vector<double, 2> Vec2_64Bit;
84 typedef Vector<double, 3> Vec3_64Bit;
85 typedef Vector<double, 4> Vec4_64Bit;
86
enum
{
    // Reference interval computation can be slow, particularly on platforms where
    // switching the floating-point rounding mode is expensive (emulated arm on x86).
    // To keep the watchdog from firing, it is touched periodically while the
    // reference intervals are computed; this constant controls how often.
    TOUCH_WATCHDOG_VALUE_FREQUENCY = 512
};
95
96 namespace vkt
97 {
98 namespace shaderexecutor
99 {
100
101 using std::map;
102 using std::ostream;
103 using std::ostringstream;
104 using std::pair;
105 using std::set;
106 using std::string;
107 using std::vector;
108
109 using de::MovePtr;
110 using de::Random;
111 using de::SharedPtr;
112 using de::UniquePtr;
113 using glu::DataType;
114 using glu::Precision;
115 using glu::ShaderType;
116 using glu::VarType;
117 using tcu::FloatFormat;
118 using tcu::Interval;
119 using tcu::Matrix;
120 using tcu::MessageBuilder;
121 using tcu::TestLog;
122 using tcu::Vector;
123
//! Bit flags naming the optional device features a precision test can require.
//! Bit 0 is not assigned to any flag; the first flag occupies bit 1.
enum PrecisionTestFeatureBits
{
    PRECISION_TEST_FEATURES_NONE                                    = 0x00u,
    PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS                     = 0x02u,
    PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS = 0x04u,
    PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT                     = 0x08u,
    PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT                      = 0x10u,
    PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT                      = 0x20u,
    PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT                      = 0x40u,
};
//! Bitwise combination of PrecisionTestFeatureBits values.
using PrecisionTestFeatures = uint32_t;
135
areFeaturesSupported(const Context & context,uint32_t toCheck)136 void areFeaturesSupported(const Context &context, uint32_t toCheck)
137 {
138 if (toCheck == PRECISION_TEST_FEATURES_NONE)
139 return;
140
141 const vk::VkPhysicalDevice16BitStorageFeatures &extensionFeatures = context.get16BitStorageFeatures();
142
143 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_BUFFER_ACCESS) != 0 &&
144 extensionFeatures.storageBuffer16BitAccess == VK_FALSE)
145 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
146
147 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS) != 0 &&
148 extensionFeatures.uniformAndStorageBuffer16BitAccess == VK_FALSE)
149 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
150
151 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_PUSH_CONSTANT) != 0 &&
152 extensionFeatures.storagePushConstant16 == VK_FALSE)
153 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
154
155 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_INPUT_OUTPUT) != 0 &&
156 extensionFeatures.storageInputOutput16 == VK_FALSE)
157 TCU_THROW(NotSupportedError, "Requested 16bit storage features not supported");
158
159 if ((toCheck & PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT) != 0 &&
160 context.getShaderFloat16Int8Features().shaderFloat16 == VK_FALSE)
161 TCU_THROW(NotSupportedError, "Requested 16-bit floats (halfs) are not supported in shader code");
162
163 if ((toCheck & PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT) != 0 &&
164 context.getDeviceFeatures().shaderFloat64 == VK_FALSE)
165 TCU_THROW(NotSupportedError, "Requested 64-bit floats are not supported in shader code");
166 }
167
/*--------------------------------------------------------------------*//*!
 * \brief Generic singleton accessor.
 *
 * instance<T>() hands out a reference to one shared, default-constructed,
 * immutable object of type T. Its main use is for the GLSL function
 * implementations: every function is modelled by an object of its own
 * distinct class, and keeping a dedicated context object holding one
 * instance of each class would be very laborious, so global singletons
 * are used instead.
 *
 * \return Reference to the single instance of T; repeated calls for the
 *         same T return the same object.
 *//*--------------------------------------------------------------------*/
template <typename T>
const T &instance()
{
    static const T s_singleton = T();
    return s_singleton;
}
185
186 /*--------------------------------------------------------------------*//*!
187 * \brief Empty placeholder type for unused template parameters.
188 *
189 * In the precision tests we are dealing with functions of different arities.
190 * To minimize code duplication, we only define templates with the maximum
191 * number of arguments, currently four. If a function's arity is less than the
192 * maximum, Void us used as the type for unused arguments.
193 *
194 * Although Voids are not used at run-time, they still must be compilable, so
195 * they must support all operations that other types do.
196 *
197 *//*--------------------------------------------------------------------*/
198 struct Void
199 {
200 typedef Void Element;
201 enum
202 {
203 SIZE = 0,
204 };
205
206 template <typename T>
Voidvkt::shaderexecutor::Void207 explicit Void(const T &)
208 {
209 }
Voidvkt::shaderexecutor::Void210 Void(void)
211 {
212 }
operator doublevkt::shaderexecutor::Void213 operator double(void) const
214 {
215 return TCU_NAN;
216 }
217
218 // These are used to make Voids usable as containers in container-generic code.
operator []vkt::shaderexecutor::Void219 Void &operator[](int)
220 {
221 return *this;
222 }
operator []vkt::shaderexecutor::Void223 const Void &operator[](int) const
224 {
225 return *this;
226 }
227 };
228
operator <<(ostream & os,Void)229 ostream &operator<<(ostream &os, Void)
230 {
231 return os << "()";
232 }
233
//! Generic case: a type is considered valid. Only Void, handled by an
//! explicit specialization, reports false.
template <typename T>
bool isTypeValid()
{
    return true;
}
240 template <>
isTypeValid(void)241 bool isTypeValid<Void>(void)
242 {
243 return false;
244 }
245
//! Generic case: a type is not an integer type unless an explicit
//! specialization says otherwise.
template <typename T>
bool isInteger()
{
    return false;
}
251 template <>
isInteger(void)252 bool isInteger<int>(void)
253 {
254 return true;
255 }
256 template <>
isInteger(void)257 bool isInteger<tcu::IVec2>(void)
258 {
259 return true;
260 }
261 template <>
isInteger(void)262 bool isInteger<tcu::IVec3>(void)
263 {
264 return true;
265 }
266 template <>
isInteger(void)267 bool isInteger<tcu::IVec4>(void)
268 {
269 return true;
270 }
271
272 //! Utility function for getting the name of a data type.
273 //! This is used in vector and matrix constructors.
274 template <typename T>
dataTypeNameOf(void)275 const char *dataTypeNameOf(void)
276 {
277 return glu::getDataTypeName(glu::dataTypeOf<T>());
278 }
279
280 template <>
dataTypeNameOf(void)281 const char *dataTypeNameOf<Void>(void)
282 {
283 DE_FATAL("Impossible");
284 return DE_NULL;
285 }
286
287 template <typename T>
getVarTypeOf(Precision prec=glu::PRECISION_LAST)288 VarType getVarTypeOf(Precision prec = glu::PRECISION_LAST)
289 {
290 return glu::varTypeOf<T>(prec);
291 }
292
293 //! A hack to get Void support for VarType.
294 template <>
getVarTypeOf(Precision)295 VarType getVarTypeOf<Void>(Precision)
296 {
297 DE_FATAL("Impossible");
298 return VarType();
299 }
300
301 /*--------------------------------------------------------------------*//*!
302 * \brief Type traits for generalized interval types.
303 *
304 * We are trying to compute sets of acceptable values not only for
305 * float-valued expressions but also for compound values: vectors and
306 * matrices. We approximate a set of vectors as a vector of intervals and
307 * likewise for matrices.
308 *
309 * We now need generalized operations for each type and its interval
310 * approximation. These are given in the type Traits<T>.
311 *
312 * The type Traits<T>::IVal is the approximation of T: it is `Interval` for
313 * scalar types, and a vector or matrix of intervals for container types.
314 *
315 * To allow template inference to take place, there are function wrappers for
316 * the actual operations in Traits<T>. Hence we can just use:
317 *
318 * makeIVal(someFloat)
319 *
320 * instead of:
321 *
322 * Traits<float>::doMakeIVal(value)
323 *
324 *//*--------------------------------------------------------------------*/
325
//! Primary template, declared only. The operations are supplied by the
//! specializations for the individual scalar, vector, matrix and Void types.
template <typename T>
struct Traits;
328
329 //! Create container from elementwise singleton values.
330 template <typename T>
makeIVal(const T & value)331 typename Traits<T>::IVal makeIVal(const T &value)
332 {
333 return Traits<T>::doMakeIVal(value);
334 }
335
336 //! Elementwise union of intervals.
337 template <typename T>
unionIVal(const typename Traits<T>::IVal & a,const typename Traits<T>::IVal & b)338 typename Traits<T>::IVal unionIVal(const typename Traits<T>::IVal &a, const typename Traits<T>::IVal &b)
339 {
340 return Traits<T>::doUnion(a, b);
341 }
342
343 //! Returns true iff every element of `ival` contains the corresponding element of `value`.
344 template <typename T, typename U = Void>
contains(const typename Traits<T>::IVal & ival,const T & value,bool is16Bit=false,const tcu::Maybe<U> & modularDivisor=tcu::Nothing)345 bool contains(const typename Traits<T>::IVal &ival, const T &value, bool is16Bit = false,
346 const tcu::Maybe<U> &modularDivisor = tcu::Nothing)
347 {
348 return Traits<T>::doContains(ival, value, is16Bit, modularDivisor);
349 }
350
351 //! Print out an interval with the precision of `fmt`.
352 template <typename T>
printIVal(const FloatFormat & fmt,const typename Traits<T>::IVal & ival,ostream & os)353 void printIVal(const FloatFormat &fmt, const typename Traits<T>::IVal &ival, ostream &os)
354 {
355 Traits<T>::doPrintIVal(fmt, ival, os);
356 }
357
358 template <typename T>
intervalToString(const FloatFormat & fmt,const typename Traits<T>::IVal & ival)359 string intervalToString(const FloatFormat &fmt, const typename Traits<T>::IVal &ival)
360 {
361 ostringstream oss;
362 printIVal<T>(fmt, ival, oss);
363 return oss.str();
364 }
365
366 //! Print out a value with the precision of `fmt`.
367 template <typename T>
printValue16(const FloatFormat & fmt,const T & value,ostream & os)368 void printValue16(const FloatFormat &fmt, const T &value, ostream &os)
369 {
370 Traits<T>::doPrintValue16(fmt, value, os);
371 }
372
373 template <typename T>
value16ToString(const FloatFormat & fmt,const T & val)374 string value16ToString(const FloatFormat &fmt, const T &val)
375 {
376 ostringstream oss;
377 printValue16(fmt, val, oss);
378 return oss.str();
379 }
380
getComparisonOperation(const int ndx)381 const std::string getComparisonOperation(const int ndx)
382 {
383 const int operationCount = 10;
384 DE_ASSERT(de::inBounds(ndx, 0, operationCount));
385 const std::string operations[operationCount] = {
386 "OpFOrdEqual\t\t\t", "OpFOrdGreaterThan\t", "OpFOrdLessThan\t\t", "OpFOrdGreaterThanEqual",
387 "OpFOrdLessThanEqual\t", "OpFUnordEqual\t\t", "OpFUnordGreaterThan\t", "OpFUnordLessThan\t",
388 "OpFUnordGreaterThanEqual", "OpFUnordLessThanEqual"};
389 return operations[ndx];
390 }
391
392 template <typename T>
comparisonMessage(const T & val)393 string comparisonMessage(const T &val)
394 {
395 DE_UNREF(val);
396 return "";
397 }
398
399 template <>
comparisonMessage(const int & val)400 string comparisonMessage(const int &val)
401 {
402 ostringstream oss;
403
404 int flags = val;
405 for (int ndx = 0; ndx < 10; ++ndx)
406 {
407 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags & 1) == 1 ? "TRUE" : "FALSE") << "\n";
408 flags = flags >> 1;
409 }
410 return oss.str();
411 }
412
413 template <>
comparisonMessage(const tcu::IVec2 & val)414 string comparisonMessage(const tcu::IVec2 &val)
415 {
416 ostringstream oss;
417 tcu::IVec2 flags = val;
418 for (int ndx = 0; ndx < 10; ++ndx)
419 {
420 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
421 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
422 flags.x() = flags.x() >> 1;
423 flags.y() = flags.y() >> 1;
424 }
425 return oss.str();
426 }
427
428 template <>
comparisonMessage(const tcu::IVec3 & val)429 string comparisonMessage(const tcu::IVec3 &val)
430 {
431 ostringstream oss;
432 tcu::IVec3 flags = val;
433 for (int ndx = 0; ndx < 10; ++ndx)
434 {
435 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
436 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
437 flags.x() = flags.x() >> 1;
438 flags.y() = flags.y() >> 1;
439 flags.z() = flags.z() >> 1;
440 }
441 return oss.str();
442 }
443
444 template <>
comparisonMessage(const tcu::IVec4 & val)445 string comparisonMessage(const tcu::IVec4 &val)
446 {
447 ostringstream oss;
448 tcu::IVec4 flags = val;
449 for (int ndx = 0; ndx < 10; ++ndx)
450 {
451 oss << getComparisonOperation(ndx) << "\t:\t" << ((flags.x() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
452 << ((flags.y() & 1) == 1 ? "TRUE" : "FALSE") << "\t" << ((flags.z() & 1) == 1 ? "TRUE" : "FALSE") << "\t"
453 << ((flags.w() & 1) == 1 ? "TRUE" : "FALSE") << "\n";
454 flags.x() = flags.x() >> 1;
455 flags.y() = flags.y() >> 1;
456 flags.z() = flags.z() >> 1;
457 flags.w() = flags.z() >> 1;
458 }
459 return oss.str();
460 }
461 //! Print out a value with the precision of `fmt`.
462 template <typename T>
printValue32(const FloatFormat & fmt,const T & value,ostream & os)463 void printValue32(const FloatFormat &fmt, const T &value, ostream &os)
464 {
465 Traits<T>::doPrintValue32(fmt, value, os);
466 }
467
468 template <typename T>
value32ToString(const FloatFormat & fmt,const T & val)469 string value32ToString(const FloatFormat &fmt, const T &val)
470 {
471 ostringstream oss;
472 printValue32(fmt, val, oss);
473 return oss.str();
474 }
475
476 template <typename T>
printValue64(const FloatFormat & fmt,const T & value,ostream & os)477 void printValue64(const FloatFormat &fmt, const T &value, ostream &os)
478 {
479 Traits<T>::doPrintValue64(fmt, value, os);
480 }
481
482 template <typename T>
value64ToString(const FloatFormat & fmt,const T & val)483 string value64ToString(const FloatFormat &fmt, const T &val)
484 {
485 ostringstream oss;
486 printValue64(fmt, val, oss);
487 return oss.str();
488 }
489
490 //! Approximate `value` elementwise to the float precision defined in `fmt`.
491 //! The resulting interval might not be a singleton if rounding in both
492 //! directions is allowed.
493 template <typename T>
round(const FloatFormat & fmt,const T & value)494 typename Traits<T>::IVal round(const FloatFormat &fmt, const T &value)
495 {
496 return Traits<T>::doRound(fmt, value);
497 }
498
499 template <typename T>
convert(const FloatFormat & fmt,const typename Traits<T>::IVal & value)500 typename Traits<T>::IVal convert(const FloatFormat &fmt, const typename Traits<T>::IVal &value)
501 {
502 return Traits<T>::doConvert(fmt, value);
503 }
504
505 // Matching input and output types. We may be in a modulo case and modularDivisor may have an actual value.
506 template <typename T>
intervalContains(const Interval & interval,T value,const tcu::Maybe<T> & modularDivisor)507 bool intervalContains(const Interval &interval, T value, const tcu::Maybe<T> &modularDivisor)
508 {
509 bool contained = interval.contains(value);
510
511 if (!contained && modularDivisor)
512 {
513 const T divisor = modularDivisor.get();
514
515 // In a modulo operation, if the calculated answer contains the divisor, allow exactly 0.0 as a replacement. Alternatively,
516 // if the calculated answer contains 0.0, allow exactly the divisor as a replacement.
517 if (interval.contains(static_cast<double>(divisor)))
518 contained |= (value == 0.0);
519 if (interval.contains(0.0))
520 contained |= (value == divisor);
521 }
522 return contained;
523 }
524
525 // When the input and output types do not match, we are not in a real modulo operation. Do not take the divisor into account. This
526 // version is provided for syntactical compatibility only.
527 template <typename T, typename U>
intervalContains(const Interval & interval,T value,const tcu::Maybe<U> & modularDivisor)528 bool intervalContains(const Interval &interval, T value, const tcu::Maybe<U> &modularDivisor)
529 {
530 DE_UNREF(modularDivisor); // For release builds.
531 DE_ASSERT(!modularDivisor);
532 return interval.contains(value);
533 }
534
535 //! Common traits for scalar types.
536 template <typename T>
537 struct ScalarTraits
538 {
539 typedef Interval IVal;
540
doMakeIValvkt::shaderexecutor::ScalarTraits541 static Interval doMakeIVal(const T &value)
542 {
543 // Thankfully all scalar types have a well-defined conversion to `double`,
544 // hence Interval can represent their ranges without problems.
545 return Interval(double(value));
546 }
547
doUnionvkt::shaderexecutor::ScalarTraits548 static Interval doUnion(const Interval &a, const Interval &b)
549 {
550 return a | b;
551 }
552
doContainsvkt::shaderexecutor::ScalarTraits553 static bool doContains(const Interval &a, T value)
554 {
555 return a.contains(double(value));
556 }
557
doConvertvkt::shaderexecutor::ScalarTraits558 static Interval doConvert(const FloatFormat &fmt, const IVal &ival)
559 {
560 return fmt.convert(ival);
561 }
562
doConvertvkt::shaderexecutor::ScalarTraits563 static Interval doConvert(const FloatFormat &fmt, const IVal &ival, bool is16Bit)
564 {
565 DE_UNREF(is16Bit);
566 return fmt.convert(ival);
567 }
568
doRoundvkt::shaderexecutor::ScalarTraits569 static Interval doRound(const FloatFormat &fmt, T value)
570 {
571 return fmt.roundOut(double(value), false);
572 }
573 };
574
575 template <>
576 struct ScalarTraits<uint16_t>
577 {
578 typedef Interval IVal;
579
doMakeIValvkt::shaderexecutor::ScalarTraits580 static Interval doMakeIVal(const uint16_t &value)
581 {
582 // Thankfully all scalar types have a well-defined conversion to `double`,
583 // hence Interval can represent their ranges without problems.
584 return Interval(double(deFloat16To32(value)));
585 }
586
doUnionvkt::shaderexecutor::ScalarTraits587 static Interval doUnion(const Interval &a, const Interval &b)
588 {
589 return a | b;
590 }
591
doConvertvkt::shaderexecutor::ScalarTraits592 static Interval doConvert(const FloatFormat &fmt, const IVal &ival)
593 {
594 return fmt.convert(ival);
595 }
596
doRoundvkt::shaderexecutor::ScalarTraits597 static Interval doRound(const FloatFormat &fmt, uint16_t value)
598 {
599 return fmt.roundOut(double(deFloat16To32(value)), false);
600 }
601 };
602
603 template <>
604 struct Traits<float> : ScalarTraits<float>
605 {
doPrintIValvkt::shaderexecutor::Traits606 static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
607 {
608 os << fmt.intervalToHex(ival);
609 }
610
doPrintValue16vkt::shaderexecutor::Traits611 static void doPrintValue16(const FloatFormat &fmt, const float &value, ostream &os)
612 {
613 const uint32_t iRep = reinterpret_cast<const uint32_t &>(value);
614 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
615 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
616 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
617 }
618
doPrintValue32vkt::shaderexecutor::Traits619 static void doPrintValue32(const FloatFormat &fmt, const float &value, ostream &os)
620 {
621 os << fmt.floatToHex(value);
622 }
623
doPrintValue64vkt::shaderexecutor::Traits624 static void doPrintValue64(const FloatFormat &fmt, const float &value, ostream &os)
625 {
626 os << fmt.floatToHex(value);
627 }
628
629 template <typename U>
doContainsvkt::shaderexecutor::Traits630 static bool doContains(const Interval &a, const float &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
631 {
632 if (is16Bit)
633 {
634 // Note: for deFloat16s packed in 32 bits, the original divisor is provided as a float to the shader in the input
635 // buffer, so U is also float here and we call the right interlvalContains() version.
636 const uint32_t iRep = reinterpret_cast<const uint32_t &>(value);
637 float res0 = deFloat16To32((deFloat16)(iRep & 0xFFFF));
638 float res1 = deFloat16To32((deFloat16)(iRep >> 16));
639 return intervalContains(a, res0, modularDivisor) && (res1 == -1.0);
640 }
641 return intervalContains(a, value, modularDivisor);
642 }
643 };
644
645 template <>
646 struct Traits<double> : ScalarTraits<double>
647 {
doPrintIValvkt::shaderexecutor::Traits648 static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
649 {
650 os << fmt.intervalToHex(ival);
651 }
652
doPrintValue16vkt::shaderexecutor::Traits653 static void doPrintValue16(const FloatFormat &fmt, const double &value, ostream &os)
654 {
655 const uint64_t iRep = reinterpret_cast<const uint64_t &>(value);
656 double byte0 = deFloat16To64((deFloat16)((iRep)&0xffff));
657 double byte1 = deFloat16To64((deFloat16)((iRep >> 16) & 0xffff));
658 double byte2 = deFloat16To64((deFloat16)((iRep >> 32) & 0xffff));
659 double byte3 = deFloat16To64((deFloat16)((iRep >> 48) & 0xffff));
660 os << fmt.floatToHex(byte0) << " " << fmt.floatToHex(byte1) << " " << fmt.floatToHex(byte2) << " "
661 << fmt.floatToHex(byte3);
662 }
663
doPrintValue32vkt::shaderexecutor::Traits664 static void doPrintValue32(const FloatFormat &fmt, const double &value, ostream &os)
665 {
666 const uint64_t iRep = reinterpret_cast<const uint64_t &>(value);
667 double res0 = static_cast<double>((float)((iRep)&0xffffffff));
668 double res1 = static_cast<double>((float)((iRep >> 32) & 0xffffffff));
669 os << fmt.floatToHex(res0) << " " << fmt.floatToHex(res1);
670 }
671
doPrintValue64vkt::shaderexecutor::Traits672 static void doPrintValue64(const FloatFormat &fmt, const double &value, ostream &os)
673 {
674 os << fmt.floatToHex(value);
675 }
676
677 template <class U>
doContainsvkt::shaderexecutor::Traits678 static bool doContains(const Interval &a, const double &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
679 {
680 DE_UNREF(is16Bit);
681 DE_ASSERT(!is16Bit);
682 return intervalContains(a, value, modularDivisor);
683 }
684 };
685
686 template <>
687 struct Traits<deFloat16> : ScalarTraits<deFloat16>
688 {
doPrintIValvkt::shaderexecutor::Traits689 static void doPrintIVal(const FloatFormat &fmt, const Interval &ival, ostream &os)
690 {
691 os << fmt.intervalToHex(ival);
692 }
693
doPrintValue16vkt::shaderexecutor::Traits694 static void doPrintValue16(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
695 {
696 const float res0 = deFloat16To32(value);
697 os << fmt.floatToHex(static_cast<double>(res0));
698 }
doPrintValue32vkt::shaderexecutor::Traits699 static void doPrintValue32(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
700 {
701 const float res0 = deFloat16To32(value);
702 os << fmt.floatToHex(static_cast<double>(res0));
703 }
704
doPrintValue64vkt::shaderexecutor::Traits705 static void doPrintValue64(const FloatFormat &fmt, const deFloat16 &value, ostream &os)
706 {
707 const double res0 = deFloat16To64(value);
708 os << fmt.floatToHex(res0);
709 }
710
711 // When the value and divisor are both deFloat16, convert both to float to call the right intervalContains version.
doContainsvkt::shaderexecutor::Traits712 static bool doContains(const Interval &a, const deFloat16 &value, bool is16Bit,
713 const tcu::Maybe<deFloat16> &modularDivisor)
714 {
715 DE_UNREF(is16Bit);
716 float res0 = deFloat16To32(value);
717 const tcu::Maybe<float> convertedDivisor =
718 (modularDivisor ? tcu::just(deFloat16To32(modularDivisor.get())) : tcu::Nothing);
719 return intervalContains(a, res0, convertedDivisor);
720 }
721
722 // If the types don't match we should not be in a modulo operation, so no conversion should take place.
723 template <class U>
doContainsvkt::shaderexecutor::Traits724 static bool doContains(const Interval &a, const deFloat16 &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
725 {
726 DE_UNREF(is16Bit);
727 float res0 = deFloat16To32(value);
728 return intervalContains(a, res0, modularDivisor);
729 }
730 };
731
732 template <>
733 struct Traits<bool> : ScalarTraits<bool>
734 {
doPrintValue16vkt::shaderexecutor::Traits735 static void doPrintValue16(const FloatFormat &, const float &value, ostream &os)
736 {
737 os << (value != 0.0f ? "true" : "false");
738 }
739
doPrintValue32vkt::shaderexecutor::Traits740 static void doPrintValue32(const FloatFormat &, const float &value, ostream &os)
741 {
742 os << (value != 0.0f ? "true" : "false");
743 }
744
doPrintValue64vkt::shaderexecutor::Traits745 static void doPrintValue64(const FloatFormat &, const float &value, ostream &os)
746 {
747 os << (value != 0.0f ? "true" : "false");
748 }
749
doPrintIValvkt::shaderexecutor::Traits750 static void doPrintIVal(const FloatFormat &, const Interval &ival, ostream &os)
751 {
752 os << "{";
753 if (ival.contains(false))
754 os << "false";
755 if (ival.contains(false) && ival.contains(true))
756 os << ", ";
757 if (ival.contains(true))
758 os << "true";
759 os << "}";
760 }
761 };
762
763 template <>
764 struct Traits<int> : ScalarTraits<int>
765 {
doPrintValue16vkt::shaderexecutor::Traits766 static void doPrintValue16(const FloatFormat &, const int &value, ostream &os)
767 {
768 int res0 = value & 0xFFFF;
769 int res1 = value >> 16;
770 os << res0 << " " << res1;
771 }
772
doPrintValue32vkt::shaderexecutor::Traits773 static void doPrintValue32(const FloatFormat &, const int &value, ostream &os)
774 {
775 os << value;
776 }
777
doPrintValue64vkt::shaderexecutor::Traits778 static void doPrintValue64(const FloatFormat &, const int &value, ostream &os)
779 {
780 os << value;
781 }
782
doPrintIValvkt::shaderexecutor::Traits783 static void doPrintIVal(const FloatFormat &, const Interval &ival, ostream &os)
784 {
785 os << "[" << int(ival.lo()) << ", " << int(ival.hi()) << "]";
786 }
787
788 template <typename U>
doContainsvkt::shaderexecutor::Traits789 static bool doContains(const Interval &a, const int &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
790 {
791 DE_UNREF(is16Bit);
792 return intervalContains(a, value, modularDivisor);
793 }
794 };
795
796 //! Common traits for containers, i.e. vectors and matrices.
797 //! T is the container type itself, I is the same type with interval elements.
798 template <typename T, typename I>
799 struct ContainerTraits
800 {
801 typedef typename T::Element Element;
802 typedef I IVal;
803
doMakeIValvkt::shaderexecutor::ContainerTraits804 static IVal doMakeIVal(const T &value)
805 {
806 IVal ret;
807
808 for (int ndx = 0; ndx < T::SIZE; ++ndx)
809 ret[ndx] = makeIVal(value[ndx]);
810
811 return ret;
812 }
813
doUnionvkt::shaderexecutor::ContainerTraits814 static IVal doUnion(const IVal &a, const IVal &b)
815 {
816 IVal ret;
817
818 for (int ndx = 0; ndx < T::SIZE; ++ndx)
819 ret[ndx] = unionIVal<Element>(a[ndx], b[ndx]);
820
821 return ret;
822 }
823
824 // When the input and output types match, we may be in a modulo operation. If the divisor is provided, use each of its
825 // components to determine if the obtained result is fine.
doContainsvkt::shaderexecutor::ContainerTraits826 static bool doContains(const IVal &ival, const T &value, bool is16Bit, const tcu::Maybe<T> &modularDivisor)
827 {
828 using DivisorElement = typename T::Element;
829
830 for (int ndx = 0; ndx < T::SIZE; ++ndx)
831 {
832 const tcu::Maybe<DivisorElement> divisorElement =
833 (modularDivisor ? tcu::just((*modularDivisor)[ndx]) : tcu::Nothing);
834 if (!contains(ival[ndx], value[ndx], is16Bit, divisorElement))
835 return false;
836 }
837
838 return true;
839 }
840
841 // When the input and output types do not match we should not be in a modulo operation. This version is provided for syntactical
842 // compatibility.
843 template <typename U>
doContainsvkt::shaderexecutor::ContainerTraits844 static bool doContains(const IVal &ival, const T &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
845 {
846 for (int ndx = 0; ndx < T::SIZE; ++ndx)
847 {
848 if (!contains(ival[ndx], value[ndx], is16Bit, modularDivisor))
849 return false;
850 }
851
852 return true;
853 }
854
doPrintIValvkt::shaderexecutor::ContainerTraits855 static void doPrintIVal(const FloatFormat &fmt, const IVal ival, ostream &os)
856 {
857 os << "(";
858
859 for (int ndx = 0; ndx < T::SIZE; ++ndx)
860 {
861 if (ndx > 0)
862 os << ", ";
863
864 printIVal<Element>(fmt, ival[ndx], os);
865 }
866
867 os << ")";
868 }
869
doPrintValue16vkt::shaderexecutor::ContainerTraits870 static void doPrintValue16(const FloatFormat &fmt, const T &value, ostream &os)
871 {
872 os << dataTypeNameOf<T>() << "(";
873
874 for (int ndx = 0; ndx < T::SIZE; ++ndx)
875 {
876 if (ndx > 0)
877 os << ", ";
878
879 printValue16<Element>(fmt, value[ndx], os);
880 }
881
882 os << ")";
883 }
884
doPrintValue32vkt::shaderexecutor::ContainerTraits885 static void doPrintValue32(const FloatFormat &fmt, const T &value, ostream &os)
886 {
887 os << dataTypeNameOf<T>() << "(";
888
889 for (int ndx = 0; ndx < T::SIZE; ++ndx)
890 {
891 if (ndx > 0)
892 os << ", ";
893
894 printValue32<Element>(fmt, value[ndx], os);
895 }
896
897 os << ")";
898 }
899
doPrintValue64vkt::shaderexecutor::ContainerTraits900 static void doPrintValue64(const FloatFormat &fmt, const T &value, ostream &os)
901 {
902 os << dataTypeNameOf<T>() << "(";
903
904 for (int ndx = 0; ndx < T::SIZE; ++ndx)
905 {
906 if (ndx > 0)
907 os << ", ";
908
909 printValue64<Element>(fmt, value[ndx], os);
910 }
911
912 os << ")";
913 }
914
doConvertvkt::shaderexecutor::ContainerTraits915 static IVal doConvert(const FloatFormat &fmt, const IVal &value)
916 {
917 IVal ret;
918
919 for (int ndx = 0; ndx < T::SIZE; ++ndx)
920 ret[ndx] = convert<Element>(fmt, value[ndx]);
921
922 return ret;
923 }
924
doRoundvkt::shaderexecutor::ContainerTraits925 static IVal doRound(const FloatFormat &fmt, T value)
926 {
927 IVal ret;
928
929 for (int ndx = 0; ndx < T::SIZE; ++ndx)
930 ret[ndx] = round(fmt, value[ndx]);
931
932 return ret;
933 }
934 };
935
936 template <typename T, int Size>
937 struct Traits<Vector<T, Size>> : ContainerTraits<Vector<T, Size>, Vector<typename Traits<T>::IVal, Size>>
938 {
939 };
940
941 template <typename T, int Rows, int Cols>
942 struct Traits<Matrix<T, Rows, Cols>>
943 : ContainerTraits<Matrix<T, Rows, Cols>, Matrix<typename Traits<T>::IVal, Rows, Cols>>
944 {
945 };
946
947 //! Void traits. These are just dummies, but technically valid: a Void is a
948 //! unit type with a single possible value.
949 template <>
950 struct Traits<Void>
951 {
952 typedef Void IVal;
953
doMakeIValvkt::shaderexecutor::Traits954 static Void doMakeIVal(const Void &value)
955 {
956 return value;
957 }
doUnionvkt::shaderexecutor::Traits958 static Void doUnion(const Void &, const Void &)
959 {
960 return Void();
961 }
doContainsvkt::shaderexecutor::Traits962 static bool doContains(const Void &, Void)
963 {
964 return true;
965 }
966 template <typename U>
doContainsvkt::shaderexecutor::Traits967 static bool doContains(const Void &, const Void &value, bool is16Bit, const tcu::Maybe<U> &modularDivisor)
968 {
969 DE_UNREF(value);
970 DE_UNREF(is16Bit);
971 DE_UNREF(modularDivisor);
972 return true;
973 }
doRoundvkt::shaderexecutor::Traits974 static Void doRound(const FloatFormat &, const Void &value)
975 {
976 return value;
977 }
doConvertvkt::shaderexecutor::Traits978 static Void doConvert(const FloatFormat &, const Void &value)
979 {
980 return value;
981 }
982
doPrintValue16vkt::shaderexecutor::Traits983 static void doPrintValue16(const FloatFormat &, const Void &, ostream &os)
984 {
985 os << "()";
986 }
987
doPrintValue32vkt::shaderexecutor::Traits988 static void doPrintValue32(const FloatFormat &, const Void &, ostream &os)
989 {
990 os << "()";
991 }
992
doPrintValue64vkt::shaderexecutor::Traits993 static void doPrintValue64(const FloatFormat &, const Void &, ostream &os)
994 {
995 os << "()";
996 }
997
doPrintIValvkt::shaderexecutor::Traits998 static void doPrintIVal(const FloatFormat &, const Void &, ostream &os)
999 {
1000 os << "()";
1001 }
1002 };
1003
1004 //! This is needed for container-generic operations.
1005 //! We want a scalar type T to be its own "one-element vector".
1006 template <typename T, int Size>
1007 struct ContainerOf
1008 {
1009 typedef Vector<T, Size> Container;
1010 };
1011
1012 template <typename T>
1013 struct ContainerOf<T, 1>
1014 {
1015 typedef T Container;
1016 };
1017 template <int Size>
1018 struct ContainerOf<Void, Size>
1019 {
1020 typedef Void Container;
1021 };
1022
1023 // This is a kludge that is only needed to get the ExprP::operator[] syntactic sugar to work.
//! General case: the element type is whatever the type itself names as Element.
template <typename T>
struct ElementOf
{
    using Element = typename T::Element;
};

// The built-in scalar types have no nested Element, so they map to void.

template <>
struct ElementOf<float>
{
    using Element = void;
};

template <>
struct ElementOf<double>
{
    using Element = void;
};

template <>
struct ElementOf<bool>
{
    using Element = void;
};

template <>
struct ElementOf<int>
{
    using Element = void;
};
1049
1050 template <typename T>
comparisonMessageInterval(const typename Traits<T>::IVal & val)1051 string comparisonMessageInterval(const typename Traits<T>::IVal &val)
1052 {
1053 DE_UNREF(val);
1054 return "";
1055 }
1056
1057 template <>
comparisonMessageInterval(const Traits<int>::IVal & val)1058 string comparisonMessageInterval<int>(const Traits<int>::IVal &val)
1059 {
1060 return comparisonMessage(static_cast<int>(val.lo()));
1061 }
1062
1063 template <>
comparisonMessageInterval(const Traits<float>::IVal & val)1064 string comparisonMessageInterval<float>(const Traits<float>::IVal &val)
1065 {
1066 return comparisonMessage(static_cast<int>(val.lo()));
1067 }
1068
1069 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,2> & val)1070 string comparisonMessageInterval<tcu::Vector<int, 2>>(const tcu::Vector<tcu::Interval, 2> &val)
1071 {
1072 tcu::IVec2 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()));
1073 return comparisonMessage(result);
1074 }
1075
1076 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,3> & val)1077 string comparisonMessageInterval<tcu::Vector<int, 3>>(const tcu::Vector<tcu::Interval, 3> &val)
1078 {
1079 tcu::IVec3 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()));
1080 return comparisonMessage(result);
1081 }
1082
1083 template <>
comparisonMessageInterval(const tcu::Vector<tcu::Interval,4> & val)1084 string comparisonMessageInterval<tcu::Vector<int, 4>>(const tcu::Vector<tcu::Interval, 4> &val)
1085 {
1086 tcu::IVec4 result(static_cast<int>(val[0].lo()), static_cast<int>(val[1].lo()), static_cast<int>(val[2].lo()),
1087 static_cast<int>(val[3].lo()));
1088 return comparisonMessage(result);
1089 }
1090
1091 /*--------------------------------------------------------------------*//*!
1092 *
1093 * \name Abstract syntax for expressions and statements.
1094 *
1095 * We represent GLSL programs as syntax objects: an Expr<T> represents an
1096 * expression whose GLSL type corresponds to the C++ type T, and a Statement
1097 * represents a statement.
1098 *
1099 * To ease memory management, we use shared pointers to refer to expressions
1100 * and statements. ExprP<T> is a shared pointer to an Expr<T>, and StatementP
1101 * is a shared pointer to a Statement.
1102 *
1103 * \{
1104 *
1105 *//*--------------------------------------------------------------------*/
1106
1107 class ExprBase;
1108 class ExpandContext;
1109 class Statement;
1110 class StatementP;
1111 class FuncBase;
1112 template <typename T>
1113 class ExprP;
1114 template <typename T>
1115 class Variable;
1116 template <typename T>
1117 class VariableP;
1118 template <typename T>
1119 class DefaultSampling;
1120
1121 typedef set<const FuncBase *> FuncSet;
1122
1123 template <typename T>
1124 VariableP<T> variable(const string &name);
1125 StatementP compoundStatement(const vector<StatementP> &statements);
1126
1127 /*--------------------------------------------------------------------*//*!
1128 * \brief A variable environment.
1129 *
1130 * An Environment object maintains the mapping between variables of the
1131 * abstract syntax tree and their values.
1132 *
1133 * \todo [2014-03-28 lauri] At least run-time type safety.
1134 *
1135 *//*--------------------------------------------------------------------*/
1136 class Environment
1137 {
1138 public:
1139 template <typename T>
bind(const Variable<T> & variable,const typename Traits<T>::IVal & value)1140 void bind(const Variable<T> &variable, const typename Traits<T>::IVal &value)
1141 {
1142 uint8_t *const data = new uint8_t[sizeof(value)];
1143
1144 deMemcpy(data, &value, sizeof(value));
1145 de::insert(m_map, variable.getName(), SharedPtr<uint8_t>(data, de::ArrayDeleter<uint8_t>()));
1146 }
1147
1148 template <typename T>
lookup(const Variable<T> & variable) const1149 typename Traits<T>::IVal &lookup(const Variable<T> &variable) const
1150 {
1151 uint8_t *const data = de::lookup(m_map, variable.getName()).get();
1152
1153 return *reinterpret_cast<typename Traits<T>::IVal *>(data);
1154 }
1155
1156 private:
1157 map<string, SharedPtr<uint8_t>> m_map;
1158 };
1159
1160 /*--------------------------------------------------------------------*//*!
1161 * \brief Evaluation context.
1162 *
1163 * The evaluation context contains everything that separates one execution of
1164 * an expression from the next. Currently this means the desired floating
1165 * point precision and the current variable environment.
1166 *
1167 *//*--------------------------------------------------------------------*/
1168 struct EvalContext
1169 {
EvalContextvkt::shaderexecutor::EvalContext1170 EvalContext(const FloatFormat &format_, Precision floatPrecision_, Environment &env_, int callDepth_)
1171 : format(format_)
1172 , floatPrecision(floatPrecision_)
1173 , env(env_)
1174 , callDepth(callDepth_)
1175 {
1176 }
1177
1178 FloatFormat format;
1179 Precision floatPrecision;
1180 Environment &env;
1181 int callDepth;
1182 };
1183
1184 /*--------------------------------------------------------------------*//*!
1185 * \brief Simple incremental counter.
1186 *
1187 * This is used to make sure that different ExpandContexts will not produce
1188 * overlapping temporary names.
1189 *
1190 *//*--------------------------------------------------------------------*/
class Counter
{
public:
    //! Start counting from `count`.
    Counter(int count = 0) : m_count(count)
    {
    }

    //! Return the current count, then advance it by one.
    int operator()(void)
    {
        const int current = m_count;

        m_count = current + 1;
        return current;
    }

private:
    int m_count;
};
1205
1206 class ExpandContext
1207 {
1208 public:
ExpandContext(Counter & symCounter)1209 ExpandContext(Counter &symCounter) : m_symCounter(symCounter)
1210 {
1211 }
ExpandContext(const ExpandContext & parent)1212 ExpandContext(const ExpandContext &parent) : m_symCounter(parent.m_symCounter)
1213 {
1214 }
1215
1216 template <typename T>
genSym(const string & baseName)1217 VariableP<T> genSym(const string &baseName)
1218 {
1219 return variable<T>(baseName + de::toString(m_symCounter()));
1220 }
1221
addStatement(const StatementP & stmt)1222 void addStatement(const StatementP &stmt)
1223 {
1224 m_statements.push_back(stmt);
1225 }
1226
getStatements(void) const1227 vector<StatementP> getStatements(void) const
1228 {
1229 return m_statements;
1230 }
1231
1232 private:
1233 Counter &m_symCounter;
1234 vector<StatementP> m_statements;
1235 };
1236
1237 /*--------------------------------------------------------------------*//*!
1238 * \brief A statement or declaration.
1239 *
1240 * Statements have no values. Instead, they are executed for their side
1241 * effects only: the execute() method should modify at least one variable in
1242 * the environment.
1243 *
1244 * As a bit of a kludge, a Statement object can also represent a declaration:
1245 * when it is evaluated, it can add a variable binding to the environment
1246 * instead of modifying a current one.
1247 *
1248 *//*--------------------------------------------------------------------*/
1249 class Statement
1250 {
1251 public:
~Statement(void)1252 virtual ~Statement(void)
1253 {
1254 }
1255 //! Execute the statement, modifying the environment of `ctx`
execute(EvalContext & ctx) const1256 void execute(EvalContext &ctx) const
1257 {
1258 this->doExecute(ctx);
1259 }
print(ostream & os) const1260 void print(ostream &os) const
1261 {
1262 this->doPrint(os);
1263 }
1264 //! Add the functions used in this statement to `dst`.
getUsedFuncs(FuncSet & dst) const1265 void getUsedFuncs(FuncSet &dst) const
1266 {
1267 this->doGetUsedFuncs(dst);
1268 }
failed(EvalContext & ctx) const1269 void failed(EvalContext &ctx) const
1270 {
1271 this->doFail(ctx);
1272 }
1273
1274 protected:
1275 virtual void doPrint(ostream &os) const = 0;
1276 virtual void doExecute(EvalContext &ctx) const = 0;
1277 virtual void doGetUsedFuncs(FuncSet &dst) const = 0;
doFail(EvalContext & ctx) const1278 virtual void doFail(EvalContext &ctx) const
1279 {
1280 DE_UNREF(ctx);
1281 }
1282 };
1283
operator <<(ostream & os,const Statement & stmt)1284 ostream &operator<<(ostream &os, const Statement &stmt)
1285 {
1286 stmt.print(os);
1287 return os;
1288 }
1289
1290 /*--------------------------------------------------------------------*//*!
1291 * \brief Smart pointer for statements (and declarations)
1292 *
1293 *//*--------------------------------------------------------------------*/
1294 class StatementP : public SharedPtr<const Statement>
1295 {
1296 public:
1297 typedef SharedPtr<const Statement> Super;
1298
StatementP(void)1299 StatementP(void)
1300 {
1301 }
StatementP(const Statement * ptr)1302 explicit StatementP(const Statement *ptr) : Super(ptr)
1303 {
1304 }
StatementP(const Super & ptr)1305 StatementP(const Super &ptr) : Super(ptr)
1306 {
1307 }
1308 };
1309
1310 /*--------------------------------------------------------------------*//*!
 * \brief A statement that modifies a variable or a declaration that binds a variable.
1314 *
1315 *//*--------------------------------------------------------------------*/
1316 template <typename T>
1317 class VariableStatement : public Statement
1318 {
1319 public:
VariableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1320 VariableStatement(const VariableP<T> &variable, const ExprP<T> &value, bool isDeclaration)
1321 : m_variable(variable)
1322 , m_value(value)
1323 , m_isDeclaration(isDeclaration)
1324 {
1325 }
1326
1327 protected:
doPrint(ostream & os) const1328 void doPrint(ostream &os) const
1329 {
1330 if (m_isDeclaration)
1331 os << glu::declare(getVarTypeOf<T>(), m_variable->getName());
1332 else
1333 os << m_variable->getName();
1334
1335 os << " = ";
1336 os << *m_value << ";\n";
1337 }
1338
doExecute(EvalContext & ctx) const1339 void doExecute(EvalContext &ctx) const
1340 {
1341 if (m_isDeclaration)
1342 ctx.env.bind(*m_variable, m_value->evaluate(ctx));
1343 else
1344 ctx.env.lookup(*m_variable) = m_value->evaluate(ctx);
1345 }
1346
doGetUsedFuncs(FuncSet & dst) const1347 void doGetUsedFuncs(FuncSet &dst) const
1348 {
1349 m_value->getUsedFuncs(dst);
1350 }
1351
doFail(EvalContext & ctx) const1352 virtual void doFail(EvalContext &ctx) const
1353 {
1354 if (m_isDeclaration)
1355 ctx.env.bind(*m_variable, m_value->fails(ctx));
1356 else
1357 ctx.env.lookup(*m_variable) = m_value->fails(ctx);
1358 }
1359
1360 VariableP<T> m_variable;
1361 ExprP<T> m_value;
1362 bool m_isDeclaration;
1363 };
1364
1365 template <typename T>
variableStatement(const VariableP<T> & variable,const ExprP<T> & value,bool isDeclaration)1366 StatementP variableStatement(const VariableP<T> &variable, const ExprP<T> &value, bool isDeclaration)
1367 {
1368 return StatementP(new VariableStatement<T>(variable, value, isDeclaration));
1369 }
1370
1371 template <typename T>
variableDeclaration(const VariableP<T> & variable,const ExprP<T> & definiens)1372 StatementP variableDeclaration(const VariableP<T> &variable, const ExprP<T> &definiens)
1373 {
1374 return variableStatement(variable, definiens, true);
1375 }
1376
1377 template <typename T>
variableAssignment(const VariableP<T> & variable,const ExprP<T> & value)1378 StatementP variableAssignment(const VariableP<T> &variable, const ExprP<T> &value)
1379 {
1380 return variableStatement(variable, value, false);
1381 }
1382
1383 /*--------------------------------------------------------------------*//*!
1384 * \brief A compound statement, i.e. a block.
1385 *
1386 * A compound statement is executed by executing its constituent statements in
1387 * sequence.
1388 *
1389 *//*--------------------------------------------------------------------*/
1390 class CompoundStatement : public Statement
1391 {
1392 public:
CompoundStatement(const vector<StatementP> & statements)1393 CompoundStatement(const vector<StatementP> &statements) : m_statements(statements)
1394 {
1395 }
1396
1397 protected:
doPrint(ostream & os) const1398 void doPrint(ostream &os) const
1399 {
1400 os << "{\n";
1401
1402 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1403 os << *m_statements[ndx];
1404
1405 os << "}\n";
1406 }
1407
doExecute(EvalContext & ctx) const1408 void doExecute(EvalContext &ctx) const
1409 {
1410 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1411 m_statements[ndx]->execute(ctx);
1412 }
1413
doGetUsedFuncs(FuncSet & dst) const1414 void doGetUsedFuncs(FuncSet &dst) const
1415 {
1416 for (size_t ndx = 0; ndx < m_statements.size(); ++ndx)
1417 m_statements[ndx]->getUsedFuncs(dst);
1418 }
1419
1420 vector<StatementP> m_statements;
1421 };
1422
compoundStatement(const vector<StatementP> & statements)1423 StatementP compoundStatement(const vector<StatementP> &statements)
1424 {
1425 return StatementP(new CompoundStatement(statements));
1426 }
1427
1428 //! Common base class for all expressions regardless of their type.
1429 class ExprBase
1430 {
1431 public:
~ExprBase(void)1432 virtual ~ExprBase(void)
1433 {
1434 }
printExpr(ostream & os) const1435 void printExpr(ostream &os) const
1436 {
1437 this->doPrintExpr(os);
1438 }
1439
1440 //! Output the functions that this expression refers to
getUsedFuncs(FuncSet & dst) const1441 void getUsedFuncs(FuncSet &dst) const
1442 {
1443 this->doGetUsedFuncs(dst);
1444 }
1445
1446 protected:
doPrintExpr(ostream &) const1447 virtual void doPrintExpr(ostream &) const
1448 {
1449 }
doGetUsedFuncs(FuncSet &) const1450 virtual void doGetUsedFuncs(FuncSet &) const
1451 {
1452 }
1453 };
1454
1455 //! Type-specific operations for an expression representing type T.
1456 template <typename T>
1457 class Expr : public ExprBase
1458 {
1459 public:
1460 typedef T Val;
1461 typedef typename Traits<T>::IVal IVal;
1462
1463 IVal evaluate(const EvalContext &ctx) const;
fails(const EvalContext & ctx) const1464 IVal fails(const EvalContext &ctx) const
1465 {
1466 return this->doFails(ctx);
1467 }
1468
1469 protected:
1470 virtual IVal doEvaluate(const EvalContext &ctx) const = 0;
doFails(const EvalContext & ctx) const1471 virtual IVal doFails(const EvalContext &ctx) const
1472 {
1473 return doEvaluate(ctx);
1474 }
1475 };
1476
1477 //! Evaluate an expression with the given context, optionally tracing the calls to stderr.
1478 template <typename T>
evaluate(const EvalContext & ctx) const1479 typename Traits<T>::IVal Expr<T>::evaluate(const EvalContext &ctx) const
1480 {
1481 #ifdef GLS_ENABLE_TRACE
1482 static const FloatFormat highpFmt(-126, 127, 23, true, tcu::MAYBE, tcu::YES, tcu::MAYBE);
1483 EvalContext newCtx(ctx.format, ctx.floatPrecision, ctx.env, ctx.callDepth + 1);
1484 const IVal ret = this->doEvaluate(newCtx);
1485
1486 if (isTypeValid<T>())
1487 {
1488 std::cerr << string(ctx.callDepth, ' ');
1489 this->printExpr(std::cerr);
1490 std::cerr << " -> " << intervalToString<T>(highpFmt, ret) << std::endl;
1491 }
1492 return ret;
1493 #else
1494 return this->doEvaluate(ctx);
1495 #endif
1496 }
1497
1498 template <typename T>
1499 class ExprPBase : public SharedPtr<const Expr<T>>
1500 {
1501 public:
1502 };
1503
operator <<(ostream & os,const ExprBase & expr)1504 ostream &operator<<(ostream &os, const ExprBase &expr)
1505 {
1506 expr.printExpr(os);
1507 return os;
1508 }
1509
1510 /*--------------------------------------------------------------------*//*!
1511 * \brief Shared pointer to an expression of a container type.
1512 *
 * Container types (i.e. vectors and matrices) support the subscript
 * operator. This class provides a bit of syntactic sugar to allow us to use
 * the C++ subscript operator to create a subscript expression.
1516 *//*--------------------------------------------------------------------*/
1517 template <typename T>
1518 class ContainerExprPBase : public ExprPBase<T>
1519 {
1520 public:
1521 ExprP<typename T::Element> operator[](int i) const;
1522 };
1523
1524 template <typename T>
1525 class ExprP : public ExprPBase<T>
1526 {
1527 };
1528
1529 // We treat Voids as containers since the unused parameters in generalized
1530 // vector functions are represented as Voids.
1531 template <>
1532 class ExprP<Void> : public ContainerExprPBase<Void>
1533 {
1534 };
1535
1536 template <typename T, int Size>
1537 class ExprP<Vector<T, Size>> : public ContainerExprPBase<Vector<T, Size>>
1538 {
1539 };
1540
1541 template <typename T, int Rows, int Cols>
1542 class ExprP<Matrix<T, Rows, Cols>> : public ContainerExprPBase<Matrix<T, Rows, Cols>>
1543 {
1544 };
1545
1546 template <typename T>
exprP(void)1547 ExprP<T> exprP(void)
1548 {
1549 return ExprP<T>();
1550 }
1551
1552 template <typename T>
exprP(const SharedPtr<const Expr<T>> & ptr)1553 ExprP<T> exprP(const SharedPtr<const Expr<T>> &ptr)
1554 {
1555 ExprP<T> ret;
1556 static_cast<SharedPtr<const Expr<T>> &>(ret) = ptr;
1557 return ret;
1558 }
1559
1560 template <typename T>
exprP(const Expr<T> * ptr)1561 ExprP<T> exprP(const Expr<T> *ptr)
1562 {
1563 return exprP(SharedPtr<const Expr<T>>(ptr));
1564 }
1565
1566 /*--------------------------------------------------------------------*//*!
1567 * \brief A shared pointer to a variable expression.
1568 *
1569 * This is just a narrowing of ExprP for the operations that require a variable
1570 * instead of an arbitrary expression.
1571 *
1572 *//*--------------------------------------------------------------------*/
1573 template <typename T>
1574 class VariableP : public SharedPtr<const Variable<T>>
1575 {
1576 public:
1577 typedef SharedPtr<const Variable<T>> Super;
VariableP(const Variable<T> * ptr)1578 explicit VariableP(const Variable<T> *ptr) : Super(ptr)
1579 {
1580 }
VariableP(void)1581 VariableP(void)
1582 {
1583 }
VariableP(const Super & ptr)1584 VariableP(const Super &ptr) : Super(ptr)
1585 {
1586 }
1587
operator ExprP<T>(void) const1588 operator ExprP<T>(void) const
1589 {
1590 SharedPtr<const Expr<T>> ptr = *this;
1591 return exprP(ptr);
1592 }
1593 };
1594
1595 /*--------------------------------------------------------------------*//*!
1596 * \name Syntactic sugar operators for expressions.
1597 *
1598 * @{
1599 *
1600 * These operators allow the use of C++ syntax to construct GLSL expressions
1601 * containing operators: e.g. "a+b" creates an addition expression with
1602 * operands a and b, and so on.
1603 *
1604 *//*--------------------------------------------------------------------*/
1605 ExprP<float> operator+(const ExprP<float> &arg0, const ExprP<float> &arg1);
1606 ExprP<deFloat16> operator+(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1);
1607 ExprP<double> operator+(const ExprP<double> &arg0, const ExprP<double> &arg1);
1608 template <typename T>
1609 ExprP<T> operator-(const ExprP<T> &arg0);
1610 template <typename T>
1611 ExprP<T> operator-(const ExprP<T> &arg0, const ExprP<T> &arg1);
1612 template <int Left, int Mid, int Right, typename T>
1613 ExprP<Matrix<T, Left, Right>> operator*(const ExprP<Matrix<T, Left, Mid>> &left,
1614 const ExprP<Matrix<T, Mid, Right>> &right);
1615 ExprP<float> operator*(const ExprP<float> &arg0, const ExprP<float> &arg1);
1616 ExprP<deFloat16> operator*(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1);
1617 ExprP<double> operator*(const ExprP<double> &arg0, const ExprP<double> &arg1);
1618 template <typename T>
1619 ExprP<T> operator/(const ExprP<T> &arg0, const ExprP<T> &arg1);
1620 template <typename T, int Size>
1621 ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0);
1622 template <typename T, int Size>
1623 ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1);
1624 template <int Size, typename T>
1625 ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1);
1626 template <typename T, int Size>
1627 ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1);
1628 template <int Rows, int Cols, typename T>
1629 ExprP<Vector<T, Rows>> operator*(const ExprP<Vector<T, Cols>> &left, const ExprP<Matrix<T, Rows, Cols>> &right);
1630 template <int Rows, int Cols, typename T>
1631 ExprP<Vector<T, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<Vector<T, Rows>> &right);
1632 template <int Rows, int Cols, typename T>
1633 ExprP<Matrix<T, Rows, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<T> &right);
1634 template <int Rows, int Cols>
1635 ExprP<Matrix<float, Rows, Cols>> operator+(const ExprP<Matrix<float, Rows, Cols>> &left,
1636 const ExprP<Matrix<float, Rows, Cols>> &right);
1637 template <int Rows, int Cols>
1638 ExprP<Matrix<deFloat16, Rows, Cols>> operator+(const ExprP<Matrix<deFloat16, Rows, Cols>> &left,
1639 const ExprP<Matrix<deFloat16, Rows, Cols>> &right);
1640 template <int Rows, int Cols>
1641 ExprP<Matrix<double, Rows, Cols>> operator+(const ExprP<Matrix<double, Rows, Cols>> &left,
1642 const ExprP<Matrix<double, Rows, Cols>> &right);
1643 template <typename T, int Rows, int Cols>
1644 ExprP<Matrix<T, Rows, Cols>> operator-(const ExprP<Matrix<T, Rows, Cols>> &mat);
1645
1646 //! @}
1647
1648 /*--------------------------------------------------------------------*//*!
1649 * \brief Variable expression.
1650 *
1651 * A variable is evaluated by looking up its range of possible values from an
1652 * environment.
1653 *//*--------------------------------------------------------------------*/
1654 template <typename T>
1655 class Variable : public Expr<T>
1656 {
1657 public:
1658 typedef typename Expr<T>::IVal IVal;
1659
Variable(const string & name)1660 Variable(const string &name) : m_name(name)
1661 {
1662 }
getName(void) const1663 string getName(void) const
1664 {
1665 return m_name;
1666 }
1667
1668 protected:
doPrintExpr(ostream & os) const1669 void doPrintExpr(ostream &os) const
1670 {
1671 os << m_name;
1672 }
doEvaluate(const EvalContext & ctx) const1673 IVal doEvaluate(const EvalContext &ctx) const
1674 {
1675 return ctx.env.lookup<T>(*this);
1676 }
1677
1678 private:
1679 string m_name;
1680 };
1681
1682 template <typename T>
variable(const string & name)1683 VariableP<T> variable(const string &name)
1684 {
1685 return VariableP<T>(new Variable<T>(name));
1686 }
1687
1688 template <typename T>
bindExpression(const string & name,ExpandContext & ctx,const ExprP<T> & expr)1689 VariableP<T> bindExpression(const string &name, ExpandContext &ctx, const ExprP<T> &expr)
1690 {
1691 VariableP<T> var = ctx.genSym<T>(name);
1692 ctx.addStatement(variableDeclaration(var, expr));
1693 return var;
1694 }
1695
1696 /*--------------------------------------------------------------------*//*!
1697 * \brief Constant expression.
1698 *
1699 * A constant is evaluated by rounding it to a set of possible values allowed
1700 * by the current floating point precision.
 * TODO: For whatever reason this doesn't happen: the constant is converted to
 * type T and the interval contains only (T)value. See FloatConstant, below.
1703 *//*--------------------------------------------------------------------*/
1704 template <typename T>
1705 class Constant : public Expr<T>
1706 {
1707 public:
1708 typedef typename Expr<T>::IVal IVal;
1709
Constant(const T & value)1710 Constant(const T &value) : m_value(value)
1711 {
1712 }
1713
1714 protected:
doPrintExpr(ostream & os) const1715 void doPrintExpr(ostream &os) const
1716 {
1717 os << m_value;
1718 }
doEvaluate(const EvalContext &) const1719 IVal doEvaluate(const EvalContext &) const
1720 {
1721 return makeIVal(m_value);
1722 }
1723
1724 private:
1725 T m_value;
1726 };
1727
1728 template <typename T>
constant(const T & value)1729 ExprP<T> constant(const T &value)
1730 {
1731 return exprP(new Constant<T>(value));
1732 }
1733
1734 template <typename T>
1735 class FloatConstant : public Expr<T>
1736 {
1737 public:
1738 typedef typename Expr<T>::IVal IVal;
1739
FloatConstant(double value)1740 FloatConstant(double value) : m_value(value)
1741 {
1742 }
1743
1744 protected:
doPrintExpr(ostream & os) const1745 void doPrintExpr(ostream &os) const
1746 {
1747 os << m_value;
1748 }
1749 // TODO: This should probably roundOut to T, not ctx.format, but the templates don't work like that.
doEvaluate(const EvalContext & ctx) const1750 IVal doEvaluate(const EvalContext &ctx) const
1751 {
1752 return ctx.format.roundOut(makeIVal(m_value), true);
1753 }
1754
1755 private:
1756 double m_value;
1757 };
1758
f16Constant(double value)1759 ExprP<deFloat16> f16Constant(double value)
1760 {
1761 return exprP(new FloatConstant<deFloat16>(value));
1762 }
f32Constant(double value)1763 ExprP<float> f32Constant(double value)
1764 {
1765 return exprP(new FloatConstant<float>(value));
1766 }
1767
1768 //! Return a reference to a singleton void constant.
voidP(void)1769 const ExprP<Void> &voidP(void)
1770 {
1771 static const ExprP<Void> singleton = constant(Void());
1772
1773 return singleton;
1774 }
1775
1776 /*--------------------------------------------------------------------*//*!
1777 * \brief Four-element tuple.
1778 *
1779 * This is used for various things where we need one thing for each possible
1780 * function parameter. Currently the maximum supported number of parameters is
1781 * four.
1782 *//*--------------------------------------------------------------------*/
1783 template <typename T0 = Void, typename T1 = Void, typename T2 = Void, typename T3 = Void>
1784 struct Tuple4
1785 {
Tuple4vkt::shaderexecutor::Tuple41786 explicit Tuple4(const T0 e0 = T0(), const T1 e1 = T1(), const T2 e2 = T2(), const T3 e3 = T3())
1787 : a(e0)
1788 , b(e1)
1789 , c(e2)
1790 , d(e3)
1791 {
1792 }
1793
1794 T0 a;
1795 T1 b;
1796 T2 c;
1797 T3 d;
1798 };
1799
1800 /*--------------------------------------------------------------------*//*!
1801 * \brief Function signature.
1802 *
1803 * This is a purely compile-time structure used to bundle all types in a
1804 * function signature together. This makes passing the signature around in
1805 * templates easier, since we only need to take and pass a single Sig instead
1806 * of a bunch of parameter types and a return type.
1807 *
1808 *//*--------------------------------------------------------------------*/
1809 template <typename R, typename P0 = Void, typename P1 = Void, typename P2 = Void, typename P3 = Void>
1810 struct Signature
1811 {
1812 typedef R Ret;
1813 typedef P0 Arg0;
1814 typedef P1 Arg1;
1815 typedef P2 Arg2;
1816 typedef P3 Arg3;
1817 typedef typename Traits<Ret>::IVal IRet;
1818 typedef typename Traits<Arg0>::IVal IArg0;
1819 typedef typename Traits<Arg1>::IVal IArg1;
1820 typedef typename Traits<Arg2>::IVal IArg2;
1821 typedef typename Traits<Arg3>::IVal IArg3;
1822
1823 typedef Tuple4<const Arg0 &, const Arg1 &, const Arg2 &, const Arg3 &> Args;
1824 typedef Tuple4<const IArg0 &, const IArg1 &, const IArg2 &, const IArg3 &> IArgs;
1825 typedef Tuple4<ExprP<Arg0>, ExprP<Arg1>, ExprP<Arg2>, ExprP<Arg3>> ArgExprs;
1826 };
1827
1828 typedef vector<const ExprBase *> BaseArgExprs;
1829
1830 /*--------------------------------------------------------------------*//*!
1831 * \brief Type-independent operations for function objects.
1832 *
1833 *//*--------------------------------------------------------------------*/
1834 class FuncBase
1835 {
1836 public:
~FuncBase(void)1837 virtual ~FuncBase(void)
1838 {
1839 }
1840 virtual string getName(void) const = 0;
1841 //! Name of extension that this function requires, or empty.
getRequiredExtension(void) const1842 virtual string getRequiredExtension(void) const
1843 {
1844 return "";
1845 }
getInputRange(const bool is16bit) const1846 virtual Interval getInputRange(const bool is16bit) const
1847 {
1848 DE_UNREF(is16bit);
1849 return Interval(true, -TCU_INFINITY, TCU_INFINITY);
1850 }
1851 virtual void print(ostream &, const BaseArgExprs &) const = 0;
1852 //! Index of output parameter, or -1 if none of the parameters is output.
getOutParamIndex(void) const1853 virtual int getOutParamIndex(void) const
1854 {
1855 return -1;
1856 }
1857
getSpirvCase(void) const1858 virtual SpirVCaseT getSpirvCase(void) const
1859 {
1860 return SPIRV_CASETYPE_NONE;
1861 }
1862
printDefinition(ostream & os) const1863 void printDefinition(ostream &os) const
1864 {
1865 doPrintDefinition(os);
1866 }
1867
getUsedFuncs(FuncSet & dst) const1868 void getUsedFuncs(FuncSet &dst) const
1869 {
1870 this->doGetUsedFuncs(dst);
1871 }
1872
1873 protected:
1874 virtual void doPrintDefinition(ostream &os) const = 0;
1875 virtual void doGetUsedFuncs(FuncSet &dst) const = 0;
1876 };
1877
1878 typedef Tuple4<string, string, string, string> ParamNames;
1879
1880 /*--------------------------------------------------------------------*//*!
1881 * \brief Function objects.
1882 *
1883 * Each Func object represents a GLSL function. It can be applied to interval
 * arguments, and it returns an interval that is a conservative
1885 * approximation of the image of the GLSL function over the argument
1886 * intervals. That is, it is given a set of possible arguments and it returns
1887 * the set of possible values.
1888 *
1889 *//*--------------------------------------------------------------------*/
1890 template <typename Sig_>
1891 class Func : public FuncBase
1892 {
1893 public:
1894 typedef Sig_ Sig;
1895 typedef typename Sig::Ret Ret;
1896 typedef typename Sig::Arg0 Arg0;
1897 typedef typename Sig::Arg1 Arg1;
1898 typedef typename Sig::Arg2 Arg2;
1899 typedef typename Sig::Arg3 Arg3;
1900 typedef typename Sig::IRet IRet;
1901 typedef typename Sig::IArg0 IArg0;
1902 typedef typename Sig::IArg1 IArg1;
1903 typedef typename Sig::IArg2 IArg2;
1904 typedef typename Sig::IArg3 IArg3;
1905 typedef typename Sig::Args Args;
1906 typedef typename Sig::IArgs IArgs;
1907 typedef typename Sig::ArgExprs ArgExprs;
1908
print(ostream & os,const BaseArgExprs & args) const1909 void print(ostream &os, const BaseArgExprs &args) const
1910 {
1911 this->doPrint(os, args);
1912 }
1913
apply(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1914 IRet apply(const EvalContext &ctx, const IArg0 &arg0 = IArg0(), const IArg1 &arg1 = IArg1(),
1915 const IArg2 &arg2 = IArg2(), const IArg3 &arg3 = IArg3()) const
1916 {
1917 return this->applyArgs(ctx, IArgs(arg0, arg1, arg2, arg3));
1918 }
1919
fail(const EvalContext & ctx,const IArg0 & arg0=IArg0 (),const IArg1 & arg1=IArg1 (),const IArg2 & arg2=IArg2 (),const IArg3 & arg3=IArg3 ()) const1920 IRet fail(const EvalContext &ctx, const IArg0 &arg0 = IArg0(), const IArg1 &arg1 = IArg1(),
1921 const IArg2 &arg2 = IArg2(), const IArg3 &arg3 = IArg3()) const
1922 {
1923 return this->doFail(ctx, IArgs(arg0, arg1, arg2, arg3));
1924 }
applyArgs(const EvalContext & ctx,const IArgs & args) const1925 IRet applyArgs(const EvalContext &ctx, const IArgs &args) const
1926 {
1927 return this->doApply(ctx, args);
1928 }
1929 ExprP<Ret> operator()(const ExprP<Arg0> &arg0 = voidP(), const ExprP<Arg1> &arg1 = voidP(),
1930 const ExprP<Arg2> &arg2 = voidP(), const ExprP<Arg3> &arg3 = voidP()) const;
1931
getParamNames(void) const1932 const ParamNames &getParamNames(void) const
1933 {
1934 return this->doGetParamNames();
1935 }
1936
1937 protected:
1938 virtual IRet doApply(const EvalContext &, const IArgs &) const = 0;
doFail(const EvalContext & ctx,const IArgs & args) const1939 virtual IRet doFail(const EvalContext &ctx, const IArgs &args) const
1940 {
1941 return this->doApply(ctx, args);
1942 }
doPrint(ostream & os,const BaseArgExprs & args) const1943 virtual void doPrint(ostream &os, const BaseArgExprs &args) const
1944 {
1945 os << getName() << "(";
1946
1947 if (isTypeValid<Arg0>())
1948 os << *args[0];
1949
1950 if (isTypeValid<Arg1>())
1951 os << ", " << *args[1];
1952
1953 if (isTypeValid<Arg2>())
1954 os << ", " << *args[2];
1955
1956 if (isTypeValid<Arg3>())
1957 os << ", " << *args[3];
1958
1959 os << ")";
1960 }
1961
doGetParamNames(void) const1962 virtual const ParamNames &doGetParamNames(void) const
1963 {
1964 static ParamNames names("a", "b", "c", "d");
1965 return names;
1966 }
1967 };
1968
1969 template <typename Sig>
1970 class Apply : public Expr<typename Sig::Ret>
1971 {
1972 public:
1973 typedef typename Sig::Ret Ret;
1974 typedef typename Sig::Arg0 Arg0;
1975 typedef typename Sig::Arg1 Arg1;
1976 typedef typename Sig::Arg2 Arg2;
1977 typedef typename Sig::Arg3 Arg3;
1978 typedef typename Expr<Ret>::Val Val;
1979 typedef typename Expr<Ret>::IVal IVal;
1980 typedef Func<Sig> ApplyFunc;
1981 typedef typename ApplyFunc::ArgExprs ArgExprs;
1982
Apply(const ApplyFunc & func,const ExprP<Arg0> & arg0=voidP (),const ExprP<Arg1> & arg1=voidP (),const ExprP<Arg2> & arg2=voidP (),const ExprP<Arg3> & arg3=voidP ())1983 Apply(const ApplyFunc &func, const ExprP<Arg0> &arg0 = voidP(), const ExprP<Arg1> &arg1 = voidP(),
1984 const ExprP<Arg2> &arg2 = voidP(), const ExprP<Arg3> &arg3 = voidP())
1985 : m_func(func)
1986 , m_args(arg0, arg1, arg2, arg3)
1987 {
1988 }
1989
Apply(const ApplyFunc & func,const ArgExprs & args)1990 Apply(const ApplyFunc &func, const ArgExprs &args) : m_func(func), m_args(args)
1991 {
1992 }
1993
1994 protected:
doPrintExpr(ostream & os) const1995 void doPrintExpr(ostream &os) const
1996 {
1997 BaseArgExprs args;
1998 args.push_back(m_args.a.get());
1999 args.push_back(m_args.b.get());
2000 args.push_back(m_args.c.get());
2001 args.push_back(m_args.d.get());
2002 m_func.print(os, args);
2003 }
2004
doEvaluate(const EvalContext & ctx) const2005 IVal doEvaluate(const EvalContext &ctx) const
2006 {
2007 return m_func.apply(ctx, m_args.a->evaluate(ctx), m_args.b->evaluate(ctx), m_args.c->evaluate(ctx),
2008 m_args.d->evaluate(ctx));
2009 }
2010
doGetUsedFuncs(FuncSet & dst) const2011 void doGetUsedFuncs(FuncSet &dst) const
2012 {
2013 m_func.getUsedFuncs(dst);
2014 m_args.a->getUsedFuncs(dst);
2015 m_args.b->getUsedFuncs(dst);
2016 m_args.c->getUsedFuncs(dst);
2017 m_args.d->getUsedFuncs(dst);
2018 }
2019
2020 const ApplyFunc &m_func;
2021 ArgExprs m_args;
2022 };
2023
2024 template <typename T>
2025 class Alternatives : public Func<Signature<T, T, T>>
2026 {
2027 public:
2028 typedef typename Alternatives::Sig Sig;
2029
2030 protected:
2031 typedef typename Alternatives::IRet IRet;
2032 typedef typename Alternatives::IArgs IArgs;
2033
getName(void) const2034 virtual string getName(void) const
2035 {
2036 return "alternatives";
2037 }
doPrintDefinition(std::ostream &) const2038 virtual void doPrintDefinition(std::ostream &) const
2039 {
2040 }
doGetUsedFuncs(FuncSet &) const2041 void doGetUsedFuncs(FuncSet &) const
2042 {
2043 }
2044
doApply(const EvalContext &,const IArgs & args) const2045 virtual IRet doApply(const EvalContext &, const IArgs &args) const
2046 {
2047 return unionIVal<T>(args.a, args.b);
2048 }
2049
doPrint(ostream & os,const BaseArgExprs & args) const2050 virtual void doPrint(ostream &os, const BaseArgExprs &args) const
2051 {
2052 os << "{" << *args[0] << " | " << *args[1] << "}";
2053 }
2054 };
2055
2056 template <typename Sig>
createApply(const Func<Sig> & func,const typename Func<Sig>::ArgExprs & args)2057 ExprP<typename Sig::Ret> createApply(const Func<Sig> &func, const typename Func<Sig>::ArgExprs &args)
2058 {
2059 return exprP(new Apply<Sig>(func, args));
2060 }
2061
2062 template <typename Sig>
createApply(const Func<Sig> & func,const ExprP<typename Sig::Arg0> & arg0=voidP (),const ExprP<typename Sig::Arg1> & arg1=voidP (),const ExprP<typename Sig::Arg2> & arg2=voidP (),const ExprP<typename Sig::Arg3> & arg3=voidP ())2063 ExprP<typename Sig::Ret> createApply(const Func<Sig> &func, const ExprP<typename Sig::Arg0> &arg0 = voidP(),
2064 const ExprP<typename Sig::Arg1> &arg1 = voidP(),
2065 const ExprP<typename Sig::Arg2> &arg2 = voidP(),
2066 const ExprP<typename Sig::Arg3> &arg3 = voidP())
2067 {
2068 return exprP(new Apply<Sig>(func, arg0, arg1, arg2, arg3));
2069 }
2070
2071 template <typename Sig>
operator ()(const ExprP<typename Sig::Arg0> & arg0,const ExprP<typename Sig::Arg1> & arg1,const ExprP<typename Sig::Arg2> & arg2,const ExprP<typename Sig::Arg3> & arg3) const2072 ExprP<typename Sig::Ret> Func<Sig>::operator()(const ExprP<typename Sig::Arg0> &arg0,
2073 const ExprP<typename Sig::Arg1> &arg1,
2074 const ExprP<typename Sig::Arg2> &arg2,
2075 const ExprP<typename Sig::Arg3> &arg3) const
2076 {
2077 return createApply(*this, arg0, arg1, arg2, arg3);
2078 }
2079
2080 template <typename F>
app(const ExprP<typename F::Arg0> & arg0=voidP (),const ExprP<typename F::Arg1> & arg1=voidP (),const ExprP<typename F::Arg2> & arg2=voidP (),const ExprP<typename F::Arg3> & arg3=voidP ())2081 ExprP<typename F::Ret> app(const ExprP<typename F::Arg0> &arg0 = voidP(), const ExprP<typename F::Arg1> &arg1 = voidP(),
2082 const ExprP<typename F::Arg2> &arg2 = voidP(), const ExprP<typename F::Arg3> &arg3 = voidP())
2083 {
2084 return createApply(instance<F>(), arg0, arg1, arg2, arg3);
2085 }
2086
2087 template <typename F>
call(const EvalContext & ctx,const typename F::IArg0 & arg0=Void (),const typename F::IArg1 & arg1=Void (),const typename F::IArg2 & arg2=Void (),const typename F::IArg3 & arg3=Void ())2088 typename F::IRet call(const EvalContext &ctx, const typename F::IArg0 &arg0 = Void(),
2089 const typename F::IArg1 &arg1 = Void(), const typename F::IArg2 &arg2 = Void(),
2090 const typename F::IArg3 &arg3 = Void())
2091 {
2092 return instance<F>().apply(ctx, arg0, arg1, arg2, arg3);
2093 }
2094
2095 template <typename T>
alternatives(const ExprP<T> & arg0,const ExprP<T> & arg1)2096 ExprP<T> alternatives(const ExprP<T> &arg0, const ExprP<T> &arg1)
2097 {
2098 return createApply<typename Alternatives<T>::Sig>(instance<Alternatives<T>>(), arg0, arg1);
2099 }
2100
2101 template <typename Sig>
2102 class ApplyVar : public Apply<Sig>
2103 {
2104 public:
2105 typedef typename Sig::Ret Ret;
2106 typedef typename Sig::Arg0 Arg0;
2107 typedef typename Sig::Arg1 Arg1;
2108 typedef typename Sig::Arg2 Arg2;
2109 typedef typename Sig::Arg3 Arg3;
2110 typedef typename Expr<Ret>::Val Val;
2111 typedef typename Expr<Ret>::IVal IVal;
2112 typedef Func<Sig> ApplyFunc;
2113 typedef typename ApplyFunc::ArgExprs ArgExprs;
2114
ApplyVar(const ApplyFunc & func,const VariableP<Arg0> & arg0,const VariableP<Arg1> & arg1,const VariableP<Arg2> & arg2,const VariableP<Arg3> & arg3)2115 ApplyVar(const ApplyFunc &func, const VariableP<Arg0> &arg0, const VariableP<Arg1> &arg1,
2116 const VariableP<Arg2> &arg2, const VariableP<Arg3> &arg3)
2117 : Apply<Sig>(func, arg0, arg1, arg2, arg3)
2118 {
2119 }
2120
2121 protected:
doEvaluate(const EvalContext & ctx) const2122 IVal doEvaluate(const EvalContext &ctx) const
2123 {
2124 const Variable<Arg0> &var0 = static_cast<const Variable<Arg0> &>(*this->m_args.a);
2125 const Variable<Arg1> &var1 = static_cast<const Variable<Arg1> &>(*this->m_args.b);
2126 const Variable<Arg2> &var2 = static_cast<const Variable<Arg2> &>(*this->m_args.c);
2127 const Variable<Arg3> &var3 = static_cast<const Variable<Arg3> &>(*this->m_args.d);
2128 return this->m_func.apply(ctx, ctx.env.lookup(var0), ctx.env.lookup(var1), ctx.env.lookup(var2),
2129 ctx.env.lookup(var3));
2130 }
2131
doFails(const EvalContext & ctx) const2132 IVal doFails(const EvalContext &ctx) const
2133 {
2134 const Variable<Arg0> &var0 = static_cast<const Variable<Arg0> &>(*this->m_args.a);
2135 const Variable<Arg1> &var1 = static_cast<const Variable<Arg1> &>(*this->m_args.b);
2136 const Variable<Arg2> &var2 = static_cast<const Variable<Arg2> &>(*this->m_args.c);
2137 const Variable<Arg3> &var3 = static_cast<const Variable<Arg3> &>(*this->m_args.d);
2138 return this->m_func.fail(ctx, ctx.env.lookup(var0), ctx.env.lookup(var1), ctx.env.lookup(var2),
2139 ctx.env.lookup(var3));
2140 }
2141 };
2142
2143 template <typename Sig>
applyVar(const Func<Sig> & func,const VariableP<typename Sig::Arg0> & arg0,const VariableP<typename Sig::Arg1> & arg1,const VariableP<typename Sig::Arg2> & arg2,const VariableP<typename Sig::Arg3> & arg3)2144 ExprP<typename Sig::Ret> applyVar(const Func<Sig> &func, const VariableP<typename Sig::Arg0> &arg0,
2145 const VariableP<typename Sig::Arg1> &arg1, const VariableP<typename Sig::Arg2> &arg2,
2146 const VariableP<typename Sig::Arg3> &arg3)
2147 {
2148 return exprP(new ApplyVar<Sig>(func, arg0, arg1, arg2, arg3));
2149 }
2150
2151 template <typename Sig_>
2152 class DerivedFunc : public Func<Sig_>
2153 {
2154 public:
2155 typedef typename DerivedFunc::ArgExprs ArgExprs;
2156 typedef typename DerivedFunc::IRet IRet;
2157 typedef typename DerivedFunc::IArgs IArgs;
2158 typedef typename DerivedFunc::Ret Ret;
2159 typedef typename DerivedFunc::Arg0 Arg0;
2160 typedef typename DerivedFunc::Arg1 Arg1;
2161 typedef typename DerivedFunc::Arg2 Arg2;
2162 typedef typename DerivedFunc::Arg3 Arg3;
2163 typedef typename DerivedFunc::IArg0 IArg0;
2164 typedef typename DerivedFunc::IArg1 IArg1;
2165 typedef typename DerivedFunc::IArg2 IArg2;
2166 typedef typename DerivedFunc::IArg3 IArg3;
2167
2168 protected:
doPrintDefinition(ostream & os) const2169 void doPrintDefinition(ostream &os) const
2170 {
2171 const ParamNames ¶mNames = this->getParamNames();
2172
2173 initialize();
2174
2175 os << dataTypeNameOf<Ret>() << " " << this->getName() << "(";
2176 if (isTypeValid<Arg0>())
2177 os << dataTypeNameOf<Arg0>() << " " << paramNames.a;
2178 if (isTypeValid<Arg1>())
2179 os << ", " << dataTypeNameOf<Arg1>() << " " << paramNames.b;
2180 if (isTypeValid<Arg2>())
2181 os << ", " << dataTypeNameOf<Arg2>() << " " << paramNames.c;
2182 if (isTypeValid<Arg3>())
2183 os << ", " << dataTypeNameOf<Arg3>() << " " << paramNames.d;
2184 os << ")\n{\n";
2185
2186 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2187 os << *m_body[ndx];
2188 os << "return " << *m_ret << ";\n";
2189 os << "}\n";
2190 }
2191
doApply(const EvalContext & ctx,const IArgs & args) const2192 IRet doApply(const EvalContext &ctx, const IArgs &args) const
2193 {
2194 Environment funEnv;
2195 IArgs &mutArgs = const_cast<IArgs &>(args);
2196 IRet ret;
2197
2198 initialize();
2199
2200 funEnv.bind(*m_var0, args.a);
2201 funEnv.bind(*m_var1, args.b);
2202 funEnv.bind(*m_var2, args.c);
2203 funEnv.bind(*m_var3, args.d);
2204
2205 {
2206 EvalContext funCtx(ctx.format, ctx.floatPrecision, funEnv, ctx.callDepth);
2207
2208 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2209 m_body[ndx]->execute(funCtx);
2210
2211 ret = m_ret->evaluate(funCtx);
2212 }
2213
2214 // \todo [lauri] Store references instead of values in environment
2215 const_cast<IArg0 &>(mutArgs.a) = funEnv.lookup(*m_var0);
2216 const_cast<IArg1 &>(mutArgs.b) = funEnv.lookup(*m_var1);
2217 const_cast<IArg2 &>(mutArgs.c) = funEnv.lookup(*m_var2);
2218 const_cast<IArg3 &>(mutArgs.d) = funEnv.lookup(*m_var3);
2219
2220 return ret;
2221 }
2222
doGetUsedFuncs(FuncSet & dst) const2223 void doGetUsedFuncs(FuncSet &dst) const
2224 {
2225 initialize();
2226 if (dst.insert(this).second)
2227 {
2228 for (size_t ndx = 0; ndx < m_body.size(); ++ndx)
2229 m_body[ndx]->getUsedFuncs(dst);
2230 m_ret->getUsedFuncs(dst);
2231 }
2232 }
2233
2234 virtual ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args_) const = 0;
2235
2236 // These are transparently initialized when first needed. They cannot be
2237 // initialized in the constructor because they depend on the doExpand
2238 // method of the subclass.
2239
2240 mutable VariableP<Arg0> m_var0;
2241 mutable VariableP<Arg1> m_var1;
2242 mutable VariableP<Arg2> m_var2;
2243 mutable VariableP<Arg3> m_var3;
2244 mutable vector<StatementP> m_body;
2245 mutable ExprP<Ret> m_ret;
2246
2247 private:
initialize(void) const2248 void initialize(void) const
2249 {
2250 if (!m_ret)
2251 {
2252 const ParamNames ¶mNames = this->getParamNames();
2253 Counter symCounter;
2254 ExpandContext ctx(symCounter);
2255 ArgExprs args;
2256
2257 args.a = m_var0 = variable<Arg0>(paramNames.a);
2258 args.b = m_var1 = variable<Arg1>(paramNames.b);
2259 args.c = m_var2 = variable<Arg2>(paramNames.c);
2260 args.d = m_var3 = variable<Arg3>(paramNames.d);
2261
2262 m_ret = this->doExpand(ctx, args);
2263 m_body = ctx.getStatements();
2264 }
2265 }
2266 };
2267
2268 template <typename Sig>
2269 class PrimitiveFunc : public Func<Sig>
2270 {
2271 public:
2272 typedef typename PrimitiveFunc::Ret Ret;
2273 typedef typename PrimitiveFunc::ArgExprs ArgExprs;
2274
2275 protected:
doPrintDefinition(ostream &) const2276 void doPrintDefinition(ostream &) const
2277 {
2278 }
doGetUsedFuncs(FuncSet &) const2279 void doGetUsedFuncs(FuncSet &) const
2280 {
2281 }
2282 };
2283
2284 template <typename T>
2285 class Cond : public PrimitiveFunc<Signature<T, bool, T, T>>
2286 {
2287 public:
2288 typedef typename Cond::IArgs IArgs;
2289 typedef typename Cond::IRet IRet;
2290
getName(void) const2291 string getName(void) const
2292 {
2293 return "_cond";
2294 }
2295
2296 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2297 void doPrint(ostream &os, const BaseArgExprs &args) const
2298 {
2299 os << "(" << *args[0] << " ? " << *args[1] << " : " << *args[2] << ")";
2300 }
2301
doApply(const EvalContext &,const IArgs & iargs) const2302 IRet doApply(const EvalContext &, const IArgs &iargs) const
2303 {
2304 IRet ret;
2305
2306 if (iargs.a.contains(true))
2307 ret = unionIVal<T>(ret, iargs.b);
2308
2309 if (iargs.a.contains(false))
2310 ret = unionIVal<T>(ret, iargs.c);
2311
2312 return ret;
2313 }
2314 };
2315
2316 template <typename T>
2317 class CompareOperator : public PrimitiveFunc<Signature<bool, T, T>>
2318 {
2319 public:
2320 typedef typename CompareOperator::IArgs IArgs;
2321 typedef typename CompareOperator::IArg0 IArg0;
2322 typedef typename CompareOperator::IArg1 IArg1;
2323 typedef typename CompareOperator::IRet IRet;
2324
2325 protected:
doPrint(ostream & os,const BaseArgExprs & args) const2326 void doPrint(ostream &os, const BaseArgExprs &args) const
2327 {
2328 os << "(" << *args[0] << getSymbol() << *args[1] << ")";
2329 }
2330
doApply(const EvalContext &,const IArgs & iargs) const2331 Interval doApply(const EvalContext &, const IArgs &iargs) const
2332 {
2333 const IArg0 &arg0 = iargs.a;
2334 const IArg1 &arg1 = iargs.b;
2335 IRet ret;
2336
2337 if (canSucceed(arg0, arg1))
2338 ret |= true;
2339 if (canFail(arg0, arg1))
2340 ret |= false;
2341
2342 return ret;
2343 }
2344
2345 virtual string getSymbol(void) const = 0;
2346 virtual bool canSucceed(const IArg0 &, const IArg1 &) const = 0;
2347 virtual bool canFail(const IArg0 &, const IArg1 &) const = 0;
2348 };
2349
2350 template <typename T>
2351 class LessThan : public CompareOperator<T>
2352 {
2353 public:
getName(void) const2354 string getName(void) const
2355 {
2356 return "lessThan";
2357 }
2358
2359 protected:
getSymbol(void) const2360 string getSymbol(void) const
2361 {
2362 return "<";
2363 }
2364
canSucceed(const Interval & a,const Interval & b) const2365 bool canSucceed(const Interval &a, const Interval &b) const
2366 {
2367 return (a.lo() < b.hi());
2368 }
2369
canFail(const Interval & a,const Interval & b) const2370 bool canFail(const Interval &a, const Interval &b) const
2371 {
2372 return !(a.hi() < b.lo());
2373 }
2374 };
2375
2376 template <typename T>
operator <(const ExprP<T> & a,const ExprP<T> & b)2377 ExprP<bool> operator<(const ExprP<T> &a, const ExprP<T> &b)
2378 {
2379 return app<LessThan<T>>(a, b);
2380 }
2381
2382 template <typename T>
cond(const ExprP<bool> & test,const ExprP<T> & consequent,const ExprP<T> & alternative)2383 ExprP<T> cond(const ExprP<bool> &test, const ExprP<T> &consequent, const ExprP<T> &alternative)
2384 {
2385 return app<Cond<T>>(test, consequent, alternative);
2386 }
2387
2388 /*--------------------------------------------------------------------*//*!
2389 *
2390 * @}
2391 *
2392 *//*--------------------------------------------------------------------*/
2393 //Proper parameters for template T
2394 // Signature<float, float> 32bit tests
2395 // Signature<float, deFloat16> 16bit tests
2396 // Signature<double, double> 64bit tests
2397 template <class T>
2398 class FloatFunc1 : public PrimitiveFunc<T>
2399 {
2400 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0>::IArgs & iargs) const2401 Interval doApply(const EvalContext &ctx,
2402 const typename Signature<typename T::Ret, typename T::Arg0>::IArgs &iargs) const
2403 {
2404 return this->applyMonotone(ctx, iargs.a);
2405 }
2406
applyMonotone(const EvalContext & ctx,const Interval & iarg0) const2407 Interval applyMonotone(const EvalContext &ctx, const Interval &iarg0) const
2408 {
2409 Interval ret;
2410
2411 TCU_INTERVAL_APPLY_MONOTONE1(ret, arg0, iarg0, val,
2412 TCU_SET_INTERVAL(val, point, point = this->applyPoint(ctx, arg0)));
2413
2414 ret |= innerExtrema(ctx, iarg0);
2415 ret &= (this->getCodomain(ctx) | TCU_NAN);
2416
2417 return ctx.format.convert(ret);
2418 }
2419
innerExtrema(const EvalContext &,const Interval &) const2420 virtual Interval innerExtrema(const EvalContext &, const Interval &) const
2421 {
2422 return Interval(); // empty interval, i.e. no extrema
2423 }
2424
applyPoint(const EvalContext & ctx,double arg0) const2425 virtual Interval applyPoint(const EvalContext &ctx, double arg0) const
2426 {
2427 const double exact = this->applyExact(arg0);
2428 const double prec = this->precision(ctx, exact, arg0);
2429
2430 return exact + Interval(-prec, prec);
2431 }
2432
applyExact(double) const2433 virtual double applyExact(double) const
2434 {
2435 TCU_THROW(InternalError, "Cannot apply");
2436 }
2437
getCodomain(const EvalContext &) const2438 virtual Interval getCodomain(const EvalContext &) const
2439 {
2440 return Interval::unbounded(true);
2441 }
2442
2443 virtual double precision(const EvalContext &ctx, double, double) const = 0;
2444 };
2445
2446 /*Proper parameters for template T
2447 Signature<double, double> 64bit tests
2448 Signature<float, float> 32bit tests
2449 Signature<float, deFloat16> 16bit tests*/
2450 template <class T>
2451 class CFloatFunc1 : public FloatFunc1<T>
2452 {
2453 public:
CFloatFunc1(const string & name,tcu::DoubleFunc1 & func)2454 CFloatFunc1(const string &name, tcu::DoubleFunc1 &func) : m_name(name), m_func(func)
2455 {
2456 }
2457
getName(void) const2458 string getName(void) const
2459 {
2460 return m_name;
2461 }
2462
2463 protected:
applyExact(double x) const2464 double applyExact(double x) const
2465 {
2466 return m_func(x);
2467 }
2468
2469 const string m_name;
2470 tcu::DoubleFunc1 &m_func;
2471 };
2472
2473 //<Signature<float, deFloat16, deFloat16> >
2474 //<Signature<float, float, float> >
2475 //<Signature<double, double, double> >
2476 template <class T>
2477 class FloatFunc2 : public PrimitiveFunc<T>
2478 {
2479 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2480 Interval doApply(const EvalContext &ctx,
2481 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2482 {
2483 return this->applyMonotone(ctx, iargs.a, iargs.b);
2484 }
2485
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi) const2486 Interval applyMonotone(const EvalContext &ctx, const Interval &xi, const Interval &yi) const
2487 {
2488 Interval reti;
2489
2490 TCU_INTERVAL_APPLY_MONOTONE2(reti, x, xi, y, yi, ret,
2491 TCU_SET_INTERVAL(ret, point, point = this->applyPoint(ctx, x, y)));
2492 reti |= innerExtrema(ctx, xi, yi);
2493 reti &= (this->getCodomain(ctx) | TCU_NAN);
2494
2495 return ctx.format.convert(reti);
2496 }
2497
innerExtrema(const EvalContext &,const Interval &,const Interval &) const2498 virtual Interval innerExtrema(const EvalContext &, const Interval &, const Interval &) const
2499 {
2500 return Interval(); // empty interval, i.e. no extrema
2501 }
2502
applyPoint(const EvalContext & ctx,double x,double y) const2503 virtual Interval applyPoint(const EvalContext &ctx, double x, double y) const
2504 {
2505 const double exact = this->applyExact(x, y);
2506 const double prec = this->precision(ctx, exact, x, y);
2507
2508 return exact + Interval(-prec, prec);
2509 }
2510
applyExact(double,double) const2511 virtual double applyExact(double, double) const
2512 {
2513 TCU_THROW(InternalError, "Cannot apply");
2514 }
2515
getCodomain(const EvalContext &) const2516 virtual Interval getCodomain(const EvalContext &) const
2517 {
2518 return Interval::unbounded(true);
2519 }
2520
2521 virtual double precision(const EvalContext &ctx, double ret, double x, double y) const = 0;
2522 };
2523
2524 template <class T>
2525 class CFloatFunc2 : public FloatFunc2<T>
2526 {
2527 public:
CFloatFunc2(const string & name,tcu::DoubleFunc2 & func)2528 CFloatFunc2(const string &name, tcu::DoubleFunc2 &func) : m_name(name), m_func(func)
2529 {
2530 }
2531
getName(void) const2532 string getName(void) const
2533 {
2534 return m_name;
2535 }
2536
2537 protected:
applyExact(double x,double y) const2538 double applyExact(double x, double y) const
2539 {
2540 return m_func(x, y);
2541 }
2542
2543 const string m_name;
2544 tcu::DoubleFunc2 &m_func;
2545 };
2546
2547 template <class T>
2548 class InfixOperator : public FloatFunc2<T>
2549 {
2550 protected:
2551 virtual string getSymbol(void) const = 0;
2552
doPrint(ostream & os,const BaseArgExprs & args) const2553 void doPrint(ostream &os, const BaseArgExprs &args) const
2554 {
2555 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2556 }
2557
applyPoint(const EvalContext & ctx,double x,double y) const2558 Interval applyPoint(const EvalContext &ctx, double x, double y) const
2559 {
2560 const double exact = this->applyExact(x, y);
2561
2562 // Allow either representable number on both sides of the exact value,
2563 // but require exactly representable values to be preserved.
2564 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2565 }
2566
precision(const EvalContext &,double,double,double) const2567 double precision(const EvalContext &, double, double, double) const
2568 {
2569 return 0.0;
2570 }
2571 };
2572
2573 class InfixOperator16Bit : public FloatFunc2<Signature<float, deFloat16, deFloat16>>
2574 {
2575 protected:
2576 virtual string getSymbol(void) const = 0;
2577
doPrint(ostream & os,const BaseArgExprs & args) const2578 void doPrint(ostream &os, const BaseArgExprs &args) const
2579 {
2580 os << "(" << *args[0] << " " << getSymbol() << " " << *args[1] << ")";
2581 }
2582
applyPoint(const EvalContext & ctx,double x,double y) const2583 Interval applyPoint(const EvalContext &ctx, double x, double y) const
2584 {
2585 const double exact = this->applyExact(x, y);
2586
2587 // Allow either representable number on both sides of the exact value,
2588 // but require exactly representable values to be preserved.
2589 return ctx.format.roundOut(exact, !deIsInf(x) && !deIsInf(y));
2590 }
2591
precision(const EvalContext &,double,double,double) const2592 double precision(const EvalContext &, double, double, double) const
2593 {
2594 return 0.0;
2595 }
2596 };
2597
2598 template <class T>
2599 class FloatFunc3 : public PrimitiveFunc<T>
2600 {
2601 protected:
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1,typename T::Arg2>::IArgs & iargs) const2602 Interval doApply(const EvalContext &ctx,
2603 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1,
2604 typename T::Arg2>::IArgs &iargs) const
2605 {
2606 return this->applyMonotone(ctx, iargs.a, iargs.b, iargs.c);
2607 }
2608
applyMonotone(const EvalContext & ctx,const Interval & xi,const Interval & yi,const Interval & zi) const2609 Interval applyMonotone(const EvalContext &ctx, const Interval &xi, const Interval &yi, const Interval &zi) const
2610 {
2611 Interval reti;
2612 TCU_INTERVAL_APPLY_MONOTONE3(reti, x, xi, y, yi, z, zi, ret,
2613 TCU_SET_INTERVAL(ret, point, point = this->applyPoint(ctx, x, y, z)));
2614 return ctx.format.convert(reti);
2615 }
2616
applyPoint(const EvalContext & ctx,double x,double y,double z) const2617 virtual Interval applyPoint(const EvalContext &ctx, double x, double y, double z) const
2618 {
2619 const double exact = this->applyExact(x, y, z);
2620 const double prec = this->precision(ctx, exact, x, y, z);
2621 return exact + Interval(-prec, prec);
2622 }
2623
applyExact(double,double,double) const2624 virtual double applyExact(double, double, double) const
2625 {
2626 TCU_THROW(InternalError, "Cannot apply");
2627 }
2628
2629 virtual double precision(const EvalContext &ctx, double result, double x, double y, double z) const = 0;
2630 };
2631
2632 // We define syntactic sugar functions for expression constructors. Since
2633 // these have the same names as ordinary mathematical operations (sin, log
2634 // etc.), it's better to give them a dedicated namespace.
2635 namespace Functions
2636 {
2637
2638 using namespace tcu;
2639
2640 template <class T>
2641 class Comparison : public InfixOperator<T>
2642 {
2643 public:
getName(void) const2644 string getName(void) const
2645 {
2646 return "comparison";
2647 }
getSymbol(void) const2648 string getSymbol(void) const
2649 {
2650 return "";
2651 }
2652
getSpirvCase() const2653 SpirVCaseT getSpirvCase() const
2654 {
2655 return SPIRV_CASETYPE_COMPARE;
2656 }
2657
doApply(const EvalContext & ctx,const typename Comparison<T>::IArgs & iargs) const2658 Interval doApply(const EvalContext &ctx, const typename Comparison<T>::IArgs &iargs) const
2659 {
2660 DE_UNREF(ctx);
2661 if (iargs.a.hasNaN() || iargs.b.hasNaN())
2662 {
2663 return TCU_NAN; // one of the floats is NaN: block analysis
2664 }
2665
2666 int operationFlag = 1;
2667 int result = 0;
2668 const double a = iargs.a.midpoint();
2669 const double b = iargs.b.midpoint();
2670
2671 for (int i = 0; i < 2; ++i)
2672 {
2673 if (a == b)
2674 result += operationFlag;
2675 operationFlag = operationFlag << 1;
2676
2677 if (a > b)
2678 result += operationFlag;
2679 operationFlag = operationFlag << 1;
2680
2681 if (a < b)
2682 result += operationFlag;
2683 operationFlag = operationFlag << 1;
2684
2685 if (a >= b)
2686 result += operationFlag;
2687 operationFlag = operationFlag << 1;
2688
2689 if (a <= b)
2690 result += operationFlag;
2691 operationFlag = operationFlag << 1;
2692 }
2693 return result;
2694 }
2695 };
2696
2697 template <class T>
2698 class Add : public InfixOperator<T>
2699 {
2700 public:
getName(void) const2701 string getName(void) const
2702 {
2703 return "add";
2704 }
getSymbol(void) const2705 string getSymbol(void) const
2706 {
2707 return "+";
2708 }
2709
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2710 Interval doApply(const EvalContext &ctx,
2711 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2712 {
2713 // Fast-path for common case
2714 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2715 {
2716 Interval ret;
2717 TCU_SET_INTERVAL_BOUNDS(ret, sum, sum = iargs.a.lo() + iargs.b.lo(), sum = iargs.a.hi() + iargs.b.hi());
2718 return ctx.format.convert(ctx.format.roundOut(ret, true));
2719 }
2720 return this->applyMonotone(ctx, iargs.a, iargs.b);
2721 }
2722
2723 protected:
applyExact(double x,double y) const2724 double applyExact(double x, double y) const
2725 {
2726 return x + y;
2727 }
2728 };
2729
2730 template <class T>
2731 class Mul : public InfixOperator<T>
2732 {
2733 public:
getName(void) const2734 string getName(void) const
2735 {
2736 return "mul";
2737 }
getSymbol(void) const2738 string getSymbol(void) const
2739 {
2740 return "*";
2741 }
2742
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2743 Interval doApply(const EvalContext &ctx,
2744 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2745 {
2746 Interval a = iargs.a;
2747 Interval b = iargs.b;
2748
2749 // Fast-path for common case
2750 if (a.isOrdinary(ctx.format.getMaxValue()) && b.isOrdinary(ctx.format.getMaxValue()))
2751 {
2752 Interval ret;
2753 if (a.hi() < 0)
2754 {
2755 a = -a;
2756 b = -b;
2757 }
2758 if (a.lo() >= 0 && b.lo() >= 0)
2759 {
2760 TCU_SET_INTERVAL_BOUNDS(ret, prod, prod = a.lo() * b.lo(), prod = a.hi() * b.hi());
2761 return ctx.format.convert(ctx.format.roundOut(ret, true));
2762 }
2763 if (a.lo() >= 0 && b.hi() <= 0)
2764 {
2765 TCU_SET_INTERVAL_BOUNDS(ret, prod, prod = a.hi() * b.lo(), prod = a.lo() * b.hi());
2766 return ctx.format.convert(ctx.format.roundOut(ret, true));
2767 }
2768 }
2769 return this->applyMonotone(ctx, iargs.a, iargs.b);
2770 }
2771
2772 protected:
applyExact(double x,double y) const2773 double applyExact(double x, double y) const
2774 {
2775 return x * y;
2776 }
2777
innerExtrema(const EvalContext &,const Interval & xi,const Interval & yi) const2778 Interval innerExtrema(const EvalContext &, const Interval &xi, const Interval &yi) const
2779 {
2780 if (((xi.contains(-TCU_INFINITY) || xi.contains(TCU_INFINITY)) && yi.contains(0.0)) ||
2781 ((yi.contains(-TCU_INFINITY) || yi.contains(TCU_INFINITY)) && xi.contains(0.0)))
2782 return Interval(TCU_NAN);
2783
2784 return Interval();
2785 }
2786 };
2787
2788 template <class T>
2789 class Sub : public InfixOperator<T>
2790 {
2791 public:
getName(void) const2792 string getName(void) const
2793 {
2794 return "sub";
2795 }
getSymbol(void) const2796 string getSymbol(void) const
2797 {
2798 return "-";
2799 }
2800
doApply(const EvalContext & ctx,const typename Signature<typename T::Ret,typename T::Arg0,typename T::Arg1>::IArgs & iargs) const2801 Interval doApply(const EvalContext &ctx,
2802 const typename Signature<typename T::Ret, typename T::Arg0, typename T::Arg1>::IArgs &iargs) const
2803 {
2804 // Fast-path for common case
2805 if (iargs.a.isOrdinary(ctx.format.getMaxValue()) && iargs.b.isOrdinary(ctx.format.getMaxValue()))
2806 {
2807 Interval ret;
2808
2809 TCU_SET_INTERVAL_BOUNDS(ret, diff, diff = iargs.a.lo() - iargs.b.hi(), diff = iargs.a.hi() - iargs.b.lo());
2810 return ctx.format.convert(ctx.format.roundOut(ret, true));
2811 }
2812 else
2813 {
2814 return this->applyMonotone(ctx, iargs.a, iargs.b);
2815 }
2816 }
2817
2818 protected:
applyExact(double x,double y) const2819 double applyExact(double x, double y) const
2820 {
2821 return x - y;
2822 }
2823 };
2824
2825 template <class T>
2826 class Negate : public FloatFunc1<T>
2827 {
2828 public:
getName(void) const2829 string getName(void) const
2830 {
2831 return "_negate";
2832 }
doPrint(ostream & os,const BaseArgExprs & args) const2833 void doPrint(ostream &os, const BaseArgExprs &args) const
2834 {
2835 os << "-" << *args[0];
2836 }
2837
2838 protected:
precision(const EvalContext &,double,double) const2839 double precision(const EvalContext &, double, double) const
2840 {
2841 return 0.0;
2842 }
applyExact(double x) const2843 double applyExact(double x) const
2844 {
2845 return -x;
2846 }
2847 };
2848
2849 template <class T>
2850 class Div : public InfixOperator<T>
2851 {
2852 public:
getName(void) const2853 string getName(void) const
2854 {
2855 return "div";
2856 }
2857
2858 protected:
getSymbol(void) const2859 string getSymbol(void) const
2860 {
2861 return "/";
2862 }
2863
innerExtrema(const EvalContext &,const Interval & nom,const Interval & den) const2864 Interval innerExtrema(const EvalContext &, const Interval &nom, const Interval &den) const
2865 {
2866 Interval ret;
2867
2868 if (den.contains(0.0))
2869 {
2870 if (nom.contains(0.0))
2871 ret |= TCU_NAN;
2872
2873 if (nom.lo() < 0.0 || nom.hi() > 0.0)
2874 ret |= Interval::unbounded();
2875 }
2876
2877 return ret;
2878 }
2879
applyExact(double x,double y) const2880 double applyExact(double x, double y) const
2881 {
2882 return x / y;
2883 }
2884
applyPoint(const EvalContext & ctx,double x,double y) const2885 Interval applyPoint(const EvalContext &ctx, double x, double y) const
2886 {
2887 Interval ret = FloatFunc2<T>::applyPoint(ctx, x, y);
2888
2889 if (!deIsInf(x) && !deIsInf(y) && y != 0.0)
2890 {
2891 const Interval dst = ctx.format.convert(ret);
2892 if (dst.contains(-TCU_INFINITY))
2893 ret |= -ctx.format.getMaxValue();
2894 if (dst.contains(+TCU_INFINITY))
2895 ret |= +ctx.format.getMaxValue();
2896 }
2897
2898 return ret;
2899 }
2900
precision(const EvalContext & ctx,double ret,double,double den) const2901 double precision(const EvalContext &ctx, double ret, double, double den) const
2902 {
2903 const FloatFormat &fmt = ctx.format;
2904
2905 // \todo [2014-03-05 lauri] Check that the limits in GLSL 3.10 are actually correct.
2906 // For now, we assume that division's precision is 2.5 ULP when the value is within
2907 // [2^MINEXP, 2^MAXEXP-1]
2908
2909 if (den == 0.0)
2910 return 0.0; // Result must be exactly inf
2911 else if (de::inBounds(deAbs(den), deLdExp(1.0, fmt.getMinExp()), deLdExp(1.0, fmt.getMaxExp() - 1)))
2912 return fmt.ulp(ret, 2.5);
2913 else
2914 return TCU_INFINITY; // Can be any number, but must be a number.
2915 }
2916 };
2917
2918 template <class T>
2919 class InverseSqrt : public FloatFunc1<T>
2920 {
2921 public:
getName(void) const2922 string getName(void) const
2923 {
2924 return "inversesqrt";
2925 }
2926
2927 protected:
applyExact(double x) const2928 double applyExact(double x) const
2929 {
2930 return 1.0 / deSqrt(x);
2931 }
2932
precision(const EvalContext & ctx,double ret,double x) const2933 double precision(const EvalContext &ctx, double ret, double x) const
2934 {
2935 return x <= 0 ? TCU_NAN : ctx.format.ulp(ret, 2.0);
2936 }
2937
getCodomain(const EvalContext &) const2938 Interval getCodomain(const EvalContext &) const
2939 {
2940 return Interval(0.0, TCU_INFINITY);
2941 }
2942 };
2943
2944 template <class T>
2945 class ExpFunc : public CFloatFunc1<T>
2946 {
2947 public:
ExpFunc(const string & name,DoubleFunc1 & func)2948 ExpFunc(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
2949 {
2950 }
2951
2952 protected:
2953 double precision(const EvalContext &ctx, double ret, double x) const;
getCodomain(const EvalContext &) const2954 Interval getCodomain(const EvalContext &) const
2955 {
2956 return Interval(0.0, TCU_INFINITY);
2957 }
2958 };
2959
2960 template <>
precision(const EvalContext & ctx,double ret,double x) const2961 double ExpFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double x) const
2962 {
2963 switch (ctx.floatPrecision)
2964 {
2965 case glu::PRECISION_HIGHP:
2966 return ctx.format.ulp(ret, 3.0 + 2.0 * deAbs(x));
2967 case glu::PRECISION_MEDIUMP:
2968 case glu::PRECISION_LAST:
2969 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2970 default:
2971 DE_FATAL("Impossible");
2972 }
2973
2974 return 0.0;
2975 }
2976
2977 template <>
precision(const EvalContext & ctx,double ret,double x) const2978 double ExpFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double x) const
2979 {
2980 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2981 }
2982
2983 template <>
precision(const EvalContext & ctx,double ret,double x) const2984 double ExpFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double x) const
2985 {
2986 return ctx.format.ulp(ret, 1.0 + 2.0 * deAbs(x));
2987 }
2988
2989 template <class T>
2990 class Exp2 : public ExpFunc<T>
2991 {
2992 public:
Exp2(void)2993 Exp2(void) : ExpFunc<T>("exp2", deExp2)
2994 {
2995 }
2996 };
2997 template <class T>
2998 class Exp : public ExpFunc<T>
2999 {
3000 public:
Exp(void)3001 Exp(void) : ExpFunc<T>("exp", deExp)
3002 {
3003 }
3004 };
3005
3006 template <typename T>
exp2(const ExprP<T> & x)3007 ExprP<T> exp2(const ExprP<T> &x)
3008 {
3009 return app<Exp2<Signature<T, T>>>(x);
3010 }
3011 template <typename T>
exp(const ExprP<T> & x)3012 ExprP<T> exp(const ExprP<T> &x)
3013 {
3014 return app<Exp<Signature<T, T>>>(x);
3015 }
3016
3017 template <class T>
3018 class LogFunc : public CFloatFunc1<T>
3019 {
3020 public:
LogFunc(const string & name,DoubleFunc1 & func)3021 LogFunc(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
3022 {
3023 }
3024
3025 protected:
3026 double precision(const EvalContext &ctx, double ret, double x) const;
3027 };
3028
3029 template <>
precision(const EvalContext & ctx,double ret,double x) const3030 double LogFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double x) const
3031 {
3032 if (x <= 0)
3033 return TCU_NAN;
3034
3035 switch (ctx.floatPrecision)
3036 {
3037 case glu::PRECISION_HIGHP:
3038 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
3039 case glu::PRECISION_MEDIUMP:
3040 case glu::PRECISION_LAST:
3041 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
3042 default:
3043 DE_FATAL("Impossible");
3044 }
3045
3046 return 0;
3047 }
3048
3049 template <>
precision(const EvalContext & ctx,double ret,double x) const3050 double LogFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double x) const
3051 {
3052 if (x <= 0)
3053 return TCU_NAN;
3054 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -7) : ctx.format.ulp(ret, 3.0);
3055 }
3056
3057 // Spec: "The precision of double-precision instructions is at least that of single precision."
3058 // Lets pick float high precision as a reference.
3059 template <>
precision(const EvalContext & ctx,double ret,double x) const3060 double LogFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double x) const
3061 {
3062 if (x <= 0)
3063 return TCU_NAN;
3064 return (0.5 <= x && x <= 2.0) ? deLdExp(1.0, -21) : ctx.format.ulp(ret, 3.0);
3065 }
3066
3067 template <class T>
3068 class Log2 : public LogFunc<T>
3069 {
3070 public:
Log2(void)3071 Log2(void) : LogFunc<T>("log2", deLog2)
3072 {
3073 }
3074 };
3075 template <class T>
3076 class Log : public LogFunc<T>
3077 {
3078 public:
Log(void)3079 Log(void) : LogFunc<T>("log", deLog)
3080 {
3081 }
3082 };
3083
// Builds the 32-bit float expression log2(x).
ExprP<float> log2(const ExprP<float> &x)
{
    typedef Log2<Signature<float, float>> Log2Float;
    return app<Log2Float>(x);
}
// Builds the 32-bit float expression log(x).
ExprP<float> log(const ExprP<float> &x)
{
    typedef Log<Signature<float, float>> LogFloat;
    return app<LogFloat>(x);
}
3092
// Builds the 16-bit float expression log2(x).
ExprP<deFloat16> log2(const ExprP<deFloat16> &x)
{
    typedef Log2<Signature<deFloat16, deFloat16>> Log2Half;
    return app<Log2Half>(x);
}
// Builds the 16-bit float expression log(x).
ExprP<deFloat16> log(const ExprP<deFloat16> &x)
{
    typedef Log<Signature<deFloat16, deFloat16>> LogHalf;
    return app<LogHalf>(x);
}
3101
// Builds the 64-bit float expression log2(x).
ExprP<double> log2(const ExprP<double> &x)
{
    typedef Log2<Signature<double, double>> Log2Double;
    return app<Log2Double>(x);
}
// Builds the 64-bit float expression log(x).
ExprP<double> log(const ExprP<double> &x)
{
    typedef Log<Signature<double, double>> LogDouble;
    return app<LogDouble>(x);
}
3110
// Defines a free function NAME(arg0) returning ExprP<TRET> that applies the
// one-argument function class CLASS to arg0 through app<CLASS>().
//   CLASS - function class to apply
//   TRET  - element type of the returned expression
//   NAME  - name of the generated free function
//   T0    - element type of the argument expression
3111 #define DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0) \
3112 ExprP<TRET> NAME(const ExprP<T0> &arg0) \
3113 { \
3114 return app<CLASS>(arg0); \
3115 }
3116
// Defines a one-argument derived function: a class CLASS deriving from
// DerivedFunc<Signature<TRET, T0>> plus the matching NAME() constructor function
// (via DEFINE_CONSTRUCTOR1).
//   getName()  returns the stringified NAME.
//   doExpand() binds the single argument expression to the identifier ARG0 and
//              returns EXPANSION, which may refer to ARG0.
3117 #define DEFINE_DERIVED1(CLASS, TRET, NAME, T0, ARG0, EXPANSION) \
3118 class CLASS : public DerivedFunc<Signature<TRET, T0>> /* NOLINT(CLASS) */ \
3119 { \
3120 public: \
3121 string getName(void) const \
3122 { \
3123 return #NAME; \
3124 } \
3125 \
3126 protected: \
3127 ExprP<TRET> doExpand(ExpandContext &, const CLASS::ArgExprs &args_) const \
3128 { \
3129 const ExprP<T0> &ARG0 = args_.a; \
3130 return EXPANSION; \
3131 } \
3132 }; \
3133 DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)
3134
// Shorthands for DEFINE_DERIVED1 where argument and result share one type:
// double -> double and float -> float respectively.
3135 #define DEFINE_DERIVED_DOUBLE1(CLASS, NAME, ARG0, EXPANSION) \
3136 DEFINE_DERIVED1(CLASS, double, NAME, double, ARG0, EXPANSION)
3137
3138 #define DEFINE_DERIVED_FLOAT1(CLASS, NAME, ARG0, EXPANSION) DEFINE_DERIVED1(CLASS, float, NAME, float, ARG0, EXPANSION)
3139
// Same as DEFINE_DERIVED1, with one addition: the generated class also defines
// getInputRange(), which ignores its is16bit argument and returns INTERVAL.
//   getName()  returns the stringified NAME.
//   doExpand() binds the single argument expression to the identifier ARG0 and
//              returns EXPANSION, which may refer to ARG0.
3140 #define DEFINE_DERIVED1_INPUTRANGE(CLASS, TRET, NAME, T0, ARG0, EXPANSION, INTERVAL) \
3141 class CLASS : public DerivedFunc<Signature<TRET, T0>> /* NOLINT(CLASS) */ \
3142 { \
3143 public: \
3144 string getName(void) const \
3145 { \
3146 return #NAME; \
3147 } \
3148 \
3149 protected: \
3150 ExprP<TRET> doExpand(ExpandContext &, const CLASS::ArgExprs &args_) const \
3151 { \
3152 const ExprP<T0> &ARG0 = args_.a; \
3153 return EXPANSION; \
3154 } \
3155 Interval getInputRange(const bool /*is16bit*/) const \
3156 { \
3157 return INTERVAL; \
3158 } \
3159 }; \
3160 DEFINE_CONSTRUCTOR1(CLASS, TRET, NAME, T0)
3161
// Type-fixed shorthands for the one-argument macros:
//   *_FLOAT1_INPUTRANGE        float  -> float,     with an input range
//   *_DOUBLE1_INPUTRANGE       double -> double,    with an input range
//   *_FLOAT1_16BIT             deFloat16 -> deFloat16
//   *_FLOAT1_INPUTRANGE_16BIT  deFloat16 -> deFloat16, with an input range
3162 #define DEFINE_DERIVED_FLOAT1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
3163 DEFINE_DERIVED1_INPUTRANGE(CLASS, float, NAME, float, ARG0, EXPANSION, INTERVAL)
3164
3165 #define DEFINE_DERIVED_DOUBLE1_INPUTRANGE(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
3166 DEFINE_DERIVED1_INPUTRANGE(CLASS, double, NAME, double, ARG0, EXPANSION, INTERVAL)
3167
3168 #define DEFINE_DERIVED_FLOAT1_16BIT(CLASS, NAME, ARG0, EXPANSION) \
3169 DEFINE_DERIVED1(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION)
3170
3171 #define DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(CLASS, NAME, ARG0, EXPANSION, INTERVAL) \
3172 DEFINE_DERIVED1_INPUTRANGE(CLASS, deFloat16, NAME, deFloat16, ARG0, EXPANSION, INTERVAL)
3173
// Defines a free function NAME(arg0, arg1) returning ExprP<TRET> that applies the
// two-argument function class CLASS through app<CLASS>().
3174 #define DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1) \
3175 ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1) \
3176 { \
3177 return app<CLASS>(arg0, arg1); \
3178 }
3179
// Defines a two-argument derived function: a class CLASS deriving from
// DerivedFunc<Signature<TRET, T0, T1>> plus the matching NAME() constructor
// function (via DEFINE_CONSTRUCTOR2).
//   getName()      returns the stringified NAME.
//   getSpirvCase() returns SPIRVCASE.
//   doExpand()     binds the two argument expressions to the identifiers Arg0 and
//                  Arg1 and returns EXPANSION, which may refer to them.
3180 #define DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRVCASE) \
3181 class CLASS : public DerivedFunc<Signature<TRET, T0, T1>> /* NOLINT(CLASS) */ \
3182 { \
3183 public: \
3184 string getName(void) const \
3185 { \
3186 return #NAME; \
3187 } \
3188 \
3189 SpirVCaseT getSpirvCase(void) const \
3190 { \
3191 return SPIRVCASE; \
3192 } \
3193 \
3194 protected: \
3195 ExprP<TRET> doExpand(ExpandContext &, const ArgExprs &args_) const \
3196 { \
3197 const ExprP<T0> &Arg0 = args_.a; \
3198 const ExprP<T1> &Arg1 = args_.b; \
3199 return EXPANSION; \
3200 } \
3201 }; \
3202 DEFINE_CONSTRUCTOR2(CLASS, TRET, NAME, T0, T1)
3203
// Shorthands built on DEFINE_CASED_DERIVED2.
//   DEFINE_DERIVED2 passes SPIRV_CASETYPE_NONE as the SPIR-V case.
//   The *_DOUBLE2 / *_FLOAT2 / *_FLOAT2_16BIT forms fix all three types to
//   double / float / deFloat16; the DEFINE_CASED_* forms do the same while
//   letting the caller choose SPIRVCASE.
3204 #define DEFINE_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION) \
3205 DEFINE_CASED_DERIVED2(CLASS, TRET, NAME, T0, Arg0, T1, Arg1, EXPANSION, SPIRV_CASETYPE_NONE)
3206
3207 #define DEFINE_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
3208 DEFINE_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION)
3209
3210 #define DEFINE_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION) \
3211 DEFINE_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION)
3212
3213 #define DEFINE_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION) \
3214 DEFINE_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION)
3215
3216 #define DEFINE_CASED_DERIVED_FLOAT2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
3217 DEFINE_CASED_DERIVED2(CLASS, float, NAME, float, Arg0, float, Arg1, EXPANSION, SPIRVCASE)
3218
3219 #define DEFINE_CASED_DERIVED_FLOAT2_16BIT(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
3220 DEFINE_CASED_DERIVED2(CLASS, deFloat16, NAME, deFloat16, Arg0, deFloat16, Arg1, EXPANSION, SPIRVCASE)
3221
3222 #define DEFINE_CASED_DERIVED_DOUBLE2(CLASS, NAME, Arg0, Arg1, EXPANSION, SPIRVCASE) \
3223 DEFINE_CASED_DERIVED2(CLASS, double, NAME, double, Arg0, double, Arg1, EXPANSION, SPIRVCASE)
3224
// Defines a free function NAME(arg0, arg1, arg2) returning ExprP<TRET> that applies
// the three-argument function class CLASS through app<CLASS>().
3225 #define DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2) \
3226 ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1, const ExprP<T2> &arg2) \
3227 { \
3228 return app<CLASS>(arg0, arg1, arg2); \
3229 }
3230
// Defines a three-argument derived function: a class CLASS deriving from
// DerivedFunc<Signature<TRET, T0, T1, T2>> plus the matching NAME() constructor
// function (via DEFINE_CONSTRUCTOR3).
//   getName()  returns the stringified NAME.
//   doExpand() binds the three argument expressions to the identifiers ARG0, ARG1
//              and ARG2 and returns EXPANSION, which may refer to them.
3231 #define DEFINE_DERIVED3(CLASS, TRET, NAME, T0, ARG0, T1, ARG1, T2, ARG2, EXPANSION) \
3232 class CLASS : public DerivedFunc<Signature<TRET, T0, T1, T2>> /* NOLINT(CLASS) */ \
3233 { \
3234 public: \
3235 string getName(void) const \
3236 { \
3237 return #NAME; \
3238 } \
3239 \
3240 protected: \
3241 ExprP<TRET> doExpand(ExpandContext &, const ArgExprs &args_) const \
3242 { \
3243 const ExprP<T0> &ARG0 = args_.a; \
3244 const ExprP<T1> &ARG1 = args_.b; \
3245 const ExprP<T2> &ARG2 = args_.c; \
3246 return EXPANSION; \
3247 } \
3248 }; \
3249 DEFINE_CONSTRUCTOR3(CLASS, TRET, NAME, T0, T1, T2)
3250
// Shorthands for DEFINE_DERIVED3 with every argument and the result fixed to
// double, float or deFloat16 respectively.
3251 #define DEFINE_DERIVED_DOUBLE3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
3252 DEFINE_DERIVED3(CLASS, double, NAME, double, ARG0, double, ARG1, double, ARG2, EXPANSION)
3253
3254 #define DEFINE_DERIVED_FLOAT3(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
3255 DEFINE_DERIVED3(CLASS, float, NAME, float, ARG0, float, ARG1, float, ARG2, EXPANSION)
3256
3257 #define DEFINE_DERIVED_FLOAT3_16BIT(CLASS, NAME, ARG0, ARG1, ARG2, EXPANSION) \
3258 DEFINE_DERIVED3(CLASS, deFloat16, NAME, deFloat16, ARG0, deFloat16, ARG1, deFloat16, ARG2, EXPANSION)
3259
// Defines a free function NAME(arg0, arg1, arg2, arg3) returning ExprP<TRET> that
// applies the four-argument function class CLASS through app<CLASS>().
3260 #define DEFINE_CONSTRUCTOR4(CLASS, TRET, NAME, T0, T1, T2, T3) \
3261 ExprP<TRET> NAME(const ExprP<T0> &arg0, const ExprP<T1> &arg1, const ExprP<T2> &arg2, const ExprP<T3> &arg3) \
3262 { \
3263 return app<CLASS>(arg0, arg1, arg2, arg3); \
3264 }
3265
3266 typedef InverseSqrt<Signature<deFloat16, deFloat16>> InverseSqrt16Bit;
3267 typedef InverseSqrt<Signature<float, float>> InverseSqrt32Bit;
3268 typedef InverseSqrt<Signature<double, double>> InverseSqrt64Bit;
3269
// sqrt(x) is modelled as 1 / inversesqrt(x), once per float width
// (32-bit, 16-bit, 64-bit), each using the InverseSqrt typedef of that width.
3270 DEFINE_DERIVED_FLOAT1(Sqrt32Bit, sqrt, x, constant(1.0f) / app<InverseSqrt32Bit>(x))
3271 DEFINE_DERIVED_FLOAT1_16BIT(Sqrt16Bit, sqrt, x, constant((deFloat16)FLOAT16_1_0) / app<InverseSqrt16Bit>(x))
3272 DEFINE_DERIVED_DOUBLE1(Sqrt64Bit, sqrt, x, constant(1.0) / app<InverseSqrt64Bit>(x))
// pow(x, y) is modelled as exp2(y * log2(x)) for 32-bit and 16-bit floats.
3273 DEFINE_DERIVED_FLOAT2(Pow, pow, x, y, exp2<float>(y *log2(x)))
3274 DEFINE_DERIVED_FLOAT2_16BIT(Pow16, pow, x, y, exp2<deFloat16>(y *log2(x)))
// radians(d) multiplies by the constant pi / 180 and degrees(r) by 180 / pi;
// each is defined for 32-bit and 16-bit floats.
3275 DEFINE_DERIVED_FLOAT1(Radians, radians, d, f32Constant(DE_PI_DOUBLE / 180.0f) * d)
3276 DEFINE_DERIVED_FLOAT1_16BIT(Radians16, radians, d, f16Constant(DE_PI_DOUBLE / 180.0f) * d)
3277 DEFINE_DERIVED_FLOAT1(Degrees, degrees, r, f32Constant(180.0 / DE_PI_DOUBLE) * r)
3278 DEFINE_DERIVED_FLOAT1_16BIT(Degrees16, degrees, r, f16Constant(180.0 / DE_PI_DOUBLE) * r)
3279
/* Signatures for which this template is specialised (template parameter T):
   Signature<float, float>          32-bit tests
   Signature<deFloat16, deFloat16>  16-bit tests
   Signature<double, double>        64-bit tests */
3283 template <class T>
3284 class TrigFunc : public CFloatFunc1<T>
3285 {
3286 public:
TrigFunc(const string & name,DoubleFunc1 & func,const Interval & loEx,const Interval & hiEx)3287 TrigFunc(const string &name, DoubleFunc1 &func, const Interval &loEx, const Interval &hiEx)
3288 : CFloatFunc1<T>(name, func)
3289 , m_loExtremum(loEx)
3290 , m_hiExtremum(hiEx)
3291 {
3292 }
3293
3294 protected:
innerExtrema(const EvalContext &,const Interval & angle) const3295 Interval innerExtrema(const EvalContext &, const Interval &angle) const
3296 {
3297 const double lo = angle.lo();
3298 const double hi = angle.hi();
3299 const int loSlope = doGetSlope(lo);
3300 const int hiSlope = doGetSlope(hi);
3301
3302 // Detect the high and low values the function can take between the
3303 // interval endpoints.
3304 if (angle.length() >= 2.0 * DE_PI_DOUBLE)
3305 {
3306 // The interval is longer than a full cycle, so it must get all possible values.
3307 return m_hiExtremum | m_loExtremum;
3308 }
3309 else if (loSlope == 1 && hiSlope == -1)
3310 {
3311 // The slope can change from positive to negative only at the maximum value.
3312 return m_hiExtremum;
3313 }
3314 else if (loSlope == -1 && hiSlope == 1)
3315 {
3316 // The slope can change from negative to positive only at the maximum value.
3317 return m_loExtremum;
3318 }
3319 else if (loSlope == hiSlope &&
3320 deIntSign(CFloatFunc1<T>::applyExact(hi) - CFloatFunc1<T>::applyExact(lo)) * loSlope == -1)
3321 {
3322 // The slope has changed twice between the endpoints, so both extrema are included.
3323 return m_hiExtremum | m_loExtremum;
3324 }
3325
3326 return Interval();
3327 }
3328
getCodomain(const EvalContext &) const3329 Interval getCodomain(const EvalContext &) const
3330 {
3331 // Ensure that result is always within [-1, 1], or NaN (for +-inf)
3332 return Interval(-1.0, 1.0) | TCU_NAN;
3333 }
3334
3335 double precision(const EvalContext &ctx, double ret, double arg) const;
3336
3337 Interval getInputRange(const bool is16bit) const;
3338 virtual int doGetSlope(double angle) const = 0;
3339
3340 Interval m_loExtremum;
3341 Interval m_hiExtremum;
3342 };
3343
3344 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3345 template <>
getInputRange(const bool is16bit) const3346 Interval TrigFunc<Signature<float, float>>::getInputRange(const bool is16bit) const
3347 {
3348 DE_UNREF(is16bit);
3349 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3350 }
3351
3352 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3353 template <>
getInputRange(const bool is16bit) const3354 Interval TrigFunc<Signature<deFloat16, deFloat16>>::getInputRange(const bool is16bit) const
3355 {
3356 DE_UNREF(is16bit);
3357 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3358 }
3359
3360 //Only -DE_PI_DOUBLE, DE_PI_DOUBLE input range
3361 template <>
getInputRange(const bool is16bit) const3362 Interval TrigFunc<Signature<double, double>>::getInputRange(const bool is16bit) const
3363 {
3364 DE_UNREF(is16bit);
3365 return Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE);
3366 }
3367
3368 template <>
precision(const EvalContext & ctx,double ret,double arg) const3369 double TrigFunc<Signature<float, float>>::precision(const EvalContext &ctx, double ret, double arg) const
3370 {
3371 DE_UNREF(ret);
3372 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3373 {
3374 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3375 return deLdExp(1.0, -11);
3376 else
3377 {
3378 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3379 // 2^-11 at x == pi.
3380 return deLdExp(deAbs(arg), -12);
3381 }
3382 }
3383 else
3384 {
3385 DE_ASSERT(ctx.floatPrecision == glu::PRECISION_MEDIUMP || ctx.floatPrecision == glu::PRECISION_LAST);
3386
3387 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3388 return deLdExp(1.0, -7);
3389 else
3390 {
3391 // |x| * 2^-8, slightly larger than 2^-7 at x == pi
3392 return deLdExp(deAbs(arg), -8);
3393 }
3394 }
3395 }
3396 //
3397 /*
3398 * Half tests
3399 * From Spec:
3400 * Absolute error 2^{-7} inside the range [-pi, pi].
3401 */
3402 template <>
precision(const EvalContext & ctx,double ret,double arg) const3403 double TrigFunc<Signature<deFloat16, deFloat16>>::precision(const EvalContext &ctx, double ret, double arg) const
3404 {
3405 DE_UNREF(ctx);
3406 DE_UNREF(ret);
3407 DE_UNREF(arg);
3408 DE_ASSERT(-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE && ctx.floatPrecision == glu::PRECISION_LAST);
3409 return deLdExp(1.0, -7);
3410 }
3411
3412 // Spec: "The precision of double-precision instructions is at least that of single precision."
3413 // Lets pick float high precision as a reference.
3414 template <>
precision(const EvalContext & ctx,double ret,double arg) const3415 double TrigFunc<Signature<double, double>>::precision(const EvalContext &ctx, double ret, double arg) const
3416 {
3417 DE_UNREF(ctx);
3418 DE_UNREF(ret);
3419 if (-DE_PI_DOUBLE <= arg && arg <= DE_PI_DOUBLE)
3420 return deLdExp(1.0, -11);
3421 else
3422 {
3423 // "larger otherwise", let's pick |x| * 2^-12 , which is slightly over
3424 // 2^-11 at x == pi.
3425 return deLdExp(deAbs(arg), -12);
3426 }
3427 }
3428
/* Signatures this template is instantiated with (template parameter T):
   Signature<float, float>          32-bit tests
   Signature<deFloat16, deFloat16>  16-bit tests
   Signature<double, double>        64-bit tests */
3432 template <class T>
3433 class Sin : public TrigFunc<T>
3434 {
3435 public:
Sin(void)3436 Sin(void) : TrigFunc<T>("sin", deSin, -1.0, 1.0)
3437 {
3438 }
3439
3440 protected:
doGetSlope(double angle) const3441 int doGetSlope(double angle) const
3442 {
3443 return deIntSign(deCos(angle));
3444 }
3445 };
3446
// Builds the 32-bit float expression sin(x).
ExprP<float> sin(const ExprP<float> &x)
{
    typedef Sin<Signature<float, float>> SinFloat;
    return app<SinFloat>(x);
}
// Builds the 16-bit float expression sin(x).
ExprP<deFloat16> sin(const ExprP<deFloat16> &x)
{
    typedef Sin<Signature<deFloat16, deFloat16>> SinHalf;
    return app<SinHalf>(x);
}
// Builds the 64-bit float expression sin(x).
ExprP<double> sin(const ExprP<double> &x)
{
    typedef Sin<Signature<double, double>> SinDouble;
    return app<SinDouble>(x);
}
3459
3460 template <class T>
3461 class Cos : public TrigFunc<T>
3462 {
3463 public:
Cos(void)3464 Cos(void) : TrigFunc<T>("cos", deCos, -1.0, 1.0)
3465 {
3466 }
3467
3468 protected:
doGetSlope(double angle) const3469 int doGetSlope(double angle) const
3470 {
3471 return -deIntSign(deSin(angle));
3472 }
3473 };
3474
// Builds the 32-bit float expression cos(x).
ExprP<float> cos(const ExprP<float> &x)
{
    typedef Cos<Signature<float, float>> CosFloat;
    return app<CosFloat>(x);
}
// Builds the 16-bit float expression cos(x).
ExprP<deFloat16> cos(const ExprP<deFloat16> &x)
{
    typedef Cos<Signature<deFloat16, deFloat16>> CosHalf;
    return app<CosHalf>(x);
}
// Builds the 64-bit float expression cos(x).
ExprP<double> cos(const ExprP<double> &x)
{
    typedef Cos<Signature<double, double>> CosDouble;
    return app<CosDouble>(x);
}
3487
// tan(x) is modelled as sin(x) * (1 / cos(x)) for 32-bit and 16-bit floats, with
// the input range given as Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE).
3488 DEFINE_DERIVED_FLOAT1_INPUTRANGE(Tan, tan, x, sin(x) * (constant(1.0f) / cos(x)),
3489 Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3490 DEFINE_DERIVED_FLOAT1_INPUTRANGE_16BIT(Tan16Bit, tan, x, sin(x) * (constant((deFloat16)FLOAT16_1_0) / cos(x)),
3491 Interval(false, -DE_PI_DOUBLE, DE_PI_DOUBLE))
3492
3493 template <class T>
3494 class ATan : public CFloatFunc1<T>
3495 {
3496 public:
ATan(void)3497 ATan(void) : CFloatFunc1<T>("atan", deAtanOver)
3498 {
3499 }
3500
3501 protected:
precision(const EvalContext & ctx,double ret,double) const3502 double precision(const EvalContext &ctx, double ret, double) const
3503 {
3504 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3505 return ctx.format.ulp(ret, 4096.0);
3506 else
3507 return ctx.format.ulp(ret, 5.0);
3508 }
3509
getCodomain(const EvalContext & ctx) const3510 Interval getCodomain(const EvalContext &ctx) const
3511 {
3512 return ctx.format.roundOut(Interval(-0.5 * DE_PI_DOUBLE, 0.5 * DE_PI_DOUBLE), true);
3513 }
3514 };
3515
3516 template <class T>
3517 class ATan2 : public CFloatFunc2<T>
3518 {
3519 public:
ATan2(void)3520 ATan2(void) : CFloatFunc2<T>("atan", deAtan2)
3521 {
3522 }
3523
3524 protected:
innerExtrema(const EvalContext & ctx,const Interval & yi,const Interval & xi) const3525 Interval innerExtrema(const EvalContext &ctx, const Interval &yi, const Interval &xi) const
3526 {
3527 Interval ret;
3528
3529 if (yi.contains(0.0))
3530 {
3531 if (xi.contains(0.0))
3532 ret |= TCU_NAN;
3533 if (xi.intersects(Interval(-TCU_INFINITY, 0.0)))
3534 ret |= ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3535 }
3536
3537 if (!yi.isFinite(ctx.format.getMaxValue()) || !xi.isFinite(ctx.format.getMaxValue()))
3538 {
3539 // Infinities may not be supported, allow anything, including NaN
3540 ret |= TCU_NAN;
3541 }
3542
3543 return ret;
3544 }
3545
precision(const EvalContext & ctx,double ret,double,double) const3546 double precision(const EvalContext &ctx, double ret, double, double) const
3547 {
3548 if (ctx.floatPrecision == glu::PRECISION_HIGHP)
3549 return ctx.format.ulp(ret, 4096.0);
3550 else
3551 return ctx.format.ulp(ret, 5.0);
3552 }
3553
getCodomain(const EvalContext & ctx) const3554 Interval getCodomain(const EvalContext &ctx) const
3555 {
3556 return ctx.format.roundOut(Interval(-DE_PI_DOUBLE, DE_PI_DOUBLE), true);
3557 }
3558 };
3559
// Builds the 32-bit float two-argument arctangent expression. The arguments are
// forwarded to ATan2 in the order given, so the first one plays the role ATan2
// calls "y" and the second one "x", despite the parameter names used here.
ExprP<float> atan2(const ExprP<float> &x, const ExprP<float> &y)
{
    typedef ATan2<Signature<float, float, float>> ATan2Float;
    return app<ATan2Float>(x, y);
}
3564
// Builds the 16-bit float two-argument arctangent expression. The arguments are
// forwarded to ATan2 in the order given, so the first one plays the role ATan2
// calls "y" and the second one "x", despite the parameter names used here.
ExprP<deFloat16> atan2(const ExprP<deFloat16> &x, const ExprP<deFloat16> &y)
{
    typedef ATan2<Signature<deFloat16, deFloat16, deFloat16>> ATan2Half;
    return app<ATan2Half>(x, y);
}
3569
// Builds the 64-bit float two-argument arctangent expression. The arguments are
// forwarded to ATan2 in the order given, so the first one plays the role ATan2
// calls "y" and the second one "x", despite the parameter names used here.
ExprP<double> atan2(const ExprP<double> &x, const ExprP<double> &y)
{
    typedef ATan2<Signature<double, double, double>> ATan2Double;
    return app<ATan2Double>(x, y);
}
3574
// 32-bit hyperbolic functions expressed through exp():
//   sinh(x) = (exp(x) - exp(-x)) / 2
//   cosh(x) = (exp(x) + exp(-x)) / 2
//   tanh(x) = sinh(x) / cosh(x)
3575 DEFINE_DERIVED_FLOAT1(Sinh, sinh, x, (exp<float>(x) - exp<float>(-x)) / constant(2.0f))
3576 DEFINE_DERIVED_FLOAT1(Cosh, cosh, x, (exp<float>(x) + exp<float>(-x)) / constant(2.0f))
3577 DEFINE_DERIVED_FLOAT1(Tanh, tanh, x, sinh(x) / cosh(x))
3578
// 16-bit hyperbolic functions, using the same exp()-based expansions with
// FLOAT16_2_0 as the divisor constant.
3579 DEFINE_DERIVED_FLOAT1_16BIT(Sinh16Bit, sinh, x, (exp(x) - exp(-x)) / constant((deFloat16)FLOAT16_2_0))
3580 DEFINE_DERIVED_FLOAT1_16BIT(Cosh16Bit, cosh, x, (exp(x) + exp(-x)) / constant((deFloat16)FLOAT16_2_0))
3581 DEFINE_DERIVED_FLOAT1_16BIT(Tanh16Bit, tanh, x, sinh(x) / cosh(x))
3582
// 32-bit inverse trigonometric and inverse hyperbolic functions:
//   asin(x)  = atan2(x, sqrt(1 - x*x))
//   acos(x)  = atan2(sqrt(1 - x*x), x)
//   asinh(x) = log(x + sqrt(x*x + 1))
//   acosh(x) = log(x + sqrt(A)), where A is either (x+1)*(x-1) or x*x - 1
//              (both forms are passed to alternatives())
//   atanh(x) = 0.5 * log((1 + x) / (1 - x))
3583 DEFINE_DERIVED_FLOAT1(ASin, asin, x, atan2(x, sqrt(constant(1.0f) - x * x)))
3584 DEFINE_DERIVED_FLOAT1(ACos, acos, x, atan2(sqrt(constant(1.0f) - x * x), x))
3585 DEFINE_DERIVED_FLOAT1(ASinh, asinh, x, log(x + sqrt(x * x + constant(1.0f))))
3586 DEFINE_DERIVED_FLOAT1(ACosh, acosh, x,
3587 log(x +
3588 sqrt(alternatives((x + constant(1.0f)) * (x - constant(1.0f)), (x * x - constant(1.0f))))))
3589 DEFINE_DERIVED_FLOAT1(ATanh, atanh, x, constant(0.5f) * log((constant(1.0f) + x) / (constant(1.0f) - x)))
3590
// 16-bit inverse trigonometric and inverse hyperbolic functions. The expansions
// have the same shape as the 32-bit ones, with FLOAT16_1_0 and FLOAT16_0_5 as
// the constants.
3591 DEFINE_DERIVED_FLOAT1_16BIT(ASin16Bit, asin, x, atan2(x, sqrt(constant((deFloat16)FLOAT16_1_0) - x * x)))
3592 DEFINE_DERIVED_FLOAT1_16BIT(ACos16Bit, acos, x, atan2(sqrt(constant((deFloat16)FLOAT16_1_0) - x * x), x))
3593 DEFINE_DERIVED_FLOAT1_16BIT(ASinh16Bit, asinh, x, log(x + sqrt(x * x + constant((deFloat16)FLOAT16_1_0))))
3594 DEFINE_DERIVED_FLOAT1_16BIT(ACosh16Bit, acosh, x,
3595 log(x + sqrt(alternatives((x + constant((deFloat16)FLOAT16_1_0)) *
3596 (x - constant((deFloat16)FLOAT16_1_0)),
3597 (x * x - constant((deFloat16)FLOAT16_1_0))))))
3598 DEFINE_DERIVED_FLOAT1_16BIT(ATanh16Bit, atanh, x,
3599 constant((deFloat16)FLOAT16_0_5) *
3600 log((constant((deFloat16)FLOAT16_1_0) + x) / (constant((deFloat16)FLOAT16_1_0) - x)))
3601
3602 template <typename T>
3603 class GetComponent : public PrimitiveFunc<Signature<typename T::Element, T, int>>
3604 {
3605 public:
3606 typedef typename GetComponent::IRet IRet;
3607
getName(void) const3608 string getName(void) const
3609 {
3610 return "_getComponent";
3611 }
3612
print(ostream & os,const BaseArgExprs & args) const3613 void print(ostream &os, const BaseArgExprs &args) const
3614 {
3615 os << *args[0] << "[" << *args[1] << "]";
3616 }
3617
3618 protected:
doApply(const EvalContext &,const typename GetComponent::IArgs & iargs) const3619 IRet doApply(const EvalContext &, const typename GetComponent::IArgs &iargs) const
3620 {
3621 IRet ret;
3622
3623 for (int compNdx = 0; compNdx < T::SIZE; ++compNdx)
3624 {
3625 if (iargs.b.contains(compNdx))
3626 ret = unionIVal<typename T::Element>(ret, iargs.a[compNdx]);
3627 }
3628
3629 return ret;
3630 }
3631 };
3632
3633 template <typename T>
getComponent(const ExprP<T> & container,int ndx)3634 ExprP<typename T::Element> getComponent(const ExprP<T> &container, int ndx)
3635 {
3636 DE_ASSERT(0 <= ndx && ndx < T::SIZE);
3637 return app<GetComponent<T>>(container, constant(ndx));
3638 }
3639
3640 template <typename T>
3641 string vecNamePrefix(void);
3642 template <>
vecNamePrefix(void)3643 string vecNamePrefix<float>(void)
3644 {
3645 return "";
3646 }
3647 template <>
vecNamePrefix(void)3648 string vecNamePrefix<deFloat16>(void)
3649 {
3650 return "";
3651 }
3652 template <>
vecNamePrefix(void)3653 string vecNamePrefix<double>(void)
3654 {
3655 return "d";
3656 }
3657 template <>
vecNamePrefix(void)3658 string vecNamePrefix<int>(void)
3659 {
3660 return "i";
3661 }
3662 template <>
vecNamePrefix(void)3663 string vecNamePrefix<bool>(void)
3664 {
3665 return "b";
3666 }
3667
3668 template <typename T, int Size>
vecName(void)3669 string vecName(void)
3670 {
3671 return vecNamePrefix<T>() + "vec" + de::toString(Size);
3672 }
3673
3674 template <typename T, int Size>
3675 class GenVec;
3676
3677 template <typename T>
3678 class GenVec<T, 1> : public DerivedFunc<Signature<T, T>>
3679 {
3680 public:
3681 typedef typename GenVec<T, 1>::ArgExprs ArgExprs;
3682
getName(void) const3683 string getName(void) const
3684 {
3685 return "_" + vecName<T, 1>();
3686 }
3687
3688 protected:
doExpand(ExpandContext &,const ArgExprs & args) const3689 ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
3690 {
3691 return args.a;
3692 }
3693 };
3694
3695 template <typename T>
3696 class GenVec<T, 2> : public PrimitiveFunc<Signature<Vector<T, 2>, T, T>>
3697 {
3698 public:
3699 typedef typename GenVec::IRet IRet;
3700 typedef typename GenVec::IArgs IArgs;
3701
getName(void) const3702 string getName(void) const
3703 {
3704 return vecName<T, 2>();
3705 }
3706
3707 protected:
doApply(const EvalContext &,const IArgs & iargs) const3708 IRet doApply(const EvalContext &, const IArgs &iargs) const
3709 {
3710 return IRet(iargs.a, iargs.b);
3711 }
3712 };
3713
3714 template <typename T>
3715 class GenVec<T, 3> : public PrimitiveFunc<Signature<Vector<T, 3>, T, T, T>>
3716 {
3717 public:
3718 typedef typename GenVec::IRet IRet;
3719 typedef typename GenVec::IArgs IArgs;
3720
getName(void) const3721 string getName(void) const
3722 {
3723 return vecName<T, 3>();
3724 }
3725
3726 protected:
doApply(const EvalContext &,const IArgs & iargs) const3727 IRet doApply(const EvalContext &, const IArgs &iargs) const
3728 {
3729 return IRet(iargs.a, iargs.b, iargs.c);
3730 }
3731 };
3732
3733 template <typename T>
3734 class GenVec<T, 4> : public PrimitiveFunc<Signature<Vector<T, 4>, T, T, T, T>>
3735 {
3736 public:
3737 typedef typename GenVec::IRet IRet;
3738 typedef typename GenVec::IArgs IArgs;
3739
getName(void) const3740 string getName(void) const
3741 {
3742 return vecName<T, 4>();
3743 }
3744
3745 protected:
doApply(const EvalContext &,const IArgs & iargs) const3746 IRet doApply(const EvalContext &, const IArgs &iargs) const
3747 {
3748 return IRet(iargs.a, iargs.b, iargs.c, iargs.d);
3749 }
3750 };
3751
3752 template <typename T, int Rows, int Columns>
3753 class GenMat;
3754
3755 template <typename T, int Rows>
3756 class GenMat<T, Rows, 2> : public PrimitiveFunc<Signature<Matrix<T, Rows, 2>, Vector<T, Rows>, Vector<T, Rows>>>
3757 {
3758 public:
3759 typedef typename GenMat::Ret Ret;
3760 typedef typename GenMat::IRet IRet;
3761 typedef typename GenMat::IArgs IArgs;
3762
getName(void) const3763 string getName(void) const
3764 {
3765 return dataTypeNameOf<Ret>();
3766 }
3767
3768 protected:
doApply(const EvalContext &,const IArgs & iargs) const3769 IRet doApply(const EvalContext &, const IArgs &iargs) const
3770 {
3771 IRet ret;
3772 ret[0] = iargs.a;
3773 ret[1] = iargs.b;
3774 return ret;
3775 }
3776 };
3777
3778 template <typename T, int Rows>
3779 class GenMat<T, Rows, 3>
3780 : public PrimitiveFunc<Signature<Matrix<T, Rows, 3>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>>>
3781 {
3782 public:
3783 typedef typename GenMat::Ret Ret;
3784 typedef typename GenMat::IRet IRet;
3785 typedef typename GenMat::IArgs IArgs;
3786
getName(void) const3787 string getName(void) const
3788 {
3789 return dataTypeNameOf<Ret>();
3790 }
3791
3792 protected:
doApply(const EvalContext &,const IArgs & iargs) const3793 IRet doApply(const EvalContext &, const IArgs &iargs) const
3794 {
3795 IRet ret;
3796 ret[0] = iargs.a;
3797 ret[1] = iargs.b;
3798 ret[2] = iargs.c;
3799 return ret;
3800 }
3801 };
3802
3803 template <typename T, int Rows>
3804 class GenMat<T, Rows, 4>
3805 : public PrimitiveFunc<
3806 Signature<Matrix<T, Rows, 4>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>, Vector<T, Rows>>>
3807 {
3808 public:
3809 typedef typename GenMat::Ret Ret;
3810 typedef typename GenMat::IRet IRet;
3811 typedef typename GenMat::IArgs IArgs;
3812
getName(void) const3813 string getName(void) const
3814 {
3815 return dataTypeNameOf<Ret>();
3816 }
3817
3818 protected:
doApply(const EvalContext &,const IArgs & iargs) const3819 IRet doApply(const EvalContext &, const IArgs &iargs) const
3820 {
3821 IRet ret;
3822 ret[0] = iargs.a;
3823 ret[1] = iargs.b;
3824 ret[2] = iargs.c;
3825 ret[3] = iargs.d;
3826 return ret;
3827 }
3828 };
3829
3830 template <typename T, int Rows>
mat2(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1)3831 ExprP<Matrix<T, Rows, 2>> mat2(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1)
3832 {
3833 return app<GenMat<T, Rows, 2>>(arg0, arg1);
3834 }
3835
3836 template <typename T, int Rows>
mat3(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2)3837 ExprP<Matrix<T, Rows, 3>> mat3(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1,
3838 const ExprP<Vector<T, Rows>> &arg2)
3839 {
3840 return app<GenMat<T, Rows, 3>>(arg0, arg1, arg2);
3841 }
3842
3843 template <typename T, int Rows>
mat4(const ExprP<Vector<T,Rows>> & arg0,const ExprP<Vector<T,Rows>> & arg1,const ExprP<Vector<T,Rows>> & arg2,const ExprP<Vector<T,Rows>> & arg3)3844 ExprP<Matrix<T, Rows, 4>> mat4(const ExprP<Vector<T, Rows>> &arg0, const ExprP<Vector<T, Rows>> &arg1,
3845 const ExprP<Vector<T, Rows>> &arg2, const ExprP<Vector<T, Rows>> &arg3)
3846 {
3847 return app<GenMat<T, Rows, 4>>(arg0, arg1, arg2, arg3);
3848 }
3849
3850 template <typename T, int Rows, int Cols>
3851 class MatNeg : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>>>
3852 {
3853 public:
3854 typedef typename MatNeg::IRet IRet;
3855 typedef typename MatNeg::IArgs IArgs;
3856
getName(void) const3857 string getName(void) const
3858 {
3859 return "_matNeg";
3860 }
3861
3862 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3863 void doPrint(ostream &os, const BaseArgExprs &args) const
3864 {
3865 os << "-(" << *args[0] << ")";
3866 }
3867
doApply(const EvalContext &,const IArgs & iargs) const3868 IRet doApply(const EvalContext &, const IArgs &iargs) const
3869 {
3870 IRet ret;
3871
3872 for (int col = 0; col < Cols; ++col)
3873 {
3874 for (int row = 0; row < Rows; ++row)
3875 ret[col][row] = -iargs.a[col][row];
3876 }
3877
3878 return ret;
3879 }
3880 };
3881
3882 template <typename T, typename Sig>
3883 class CompWiseFunc : public PrimitiveFunc<Sig>
3884 {
3885 public:
3886 typedef Func<Signature<T, T, T>> ScalarFunc;
3887
getName(void) const3888 string getName(void) const
3889 {
3890 return doGetScalarFunc().getName();
3891 }
3892
3893 protected:
doPrint(ostream & os,const BaseArgExprs & args) const3894 void doPrint(ostream &os, const BaseArgExprs &args) const
3895 {
3896 doGetScalarFunc().print(os, args);
3897 }
3898
3899 virtual const ScalarFunc &doGetScalarFunc(void) const = 0;
3900 };
3901
3902 template <typename T, int Rows, int Cols>
3903 class CompMatFuncBase
3904 : public CompWiseFunc<T, Signature<Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>, Matrix<T, Rows, Cols>>>
3905 {
3906 public:
3907 typedef typename CompMatFuncBase::IRet IRet;
3908 typedef typename CompMatFuncBase::IArgs IArgs;
3909
3910 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const3911 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
3912 {
3913 IRet ret;
3914
3915 for (int col = 0; col < Cols; ++col)
3916 {
3917 for (int row = 0; row < Rows; ++row)
3918 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b[col][row]);
3919 }
3920
3921 return ret;
3922 }
3923 };
3924
3925 template <typename F, typename T, int Rows, int Cols>
3926 class CompMatFunc : public CompMatFuncBase<T, Rows, Cols>
3927 {
3928 protected:
doGetScalarFunc(void) const3929 const typename CompMatFunc::ScalarFunc &doGetScalarFunc(void) const
3930 {
3931 return instance<F>();
3932 }
3933 };
3934
3935 template <class T>
3936 class ScalarMatrixCompMult : public Mul<Signature<T, T, T>>
3937 {
3938 public:
getName(void) const3939 string getName(void) const
3940 {
3941 return "matrixCompMult";
3942 }
3943
doPrint(ostream & os,const BaseArgExprs & args) const3944 void doPrint(ostream &os, const BaseArgExprs &args) const
3945 {
3946 Func<Signature<T, T, T>>::doPrint(os, args);
3947 }
3948 };
3949
3950 template <int Rows, int Cols, class T>
3951 class MatrixCompMult : public CompMatFunc<ScalarMatrixCompMult<T>, T, Rows, Cols>
3952 {
3953 };
3954
3955 template <int Rows, int Cols>
3956 class ScalarMatFuncBase
3957 : public CompWiseFunc<float, Signature<Matrix<float, Rows, Cols>, Matrix<float, Rows, Cols>, float>>
3958 {
3959 public:
3960 typedef typename ScalarMatFuncBase::IRet IRet;
3961 typedef typename ScalarMatFuncBase::IArgs IArgs;
3962
3963 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const3964 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
3965 {
3966 IRet ret;
3967
3968 for (int col = 0; col < Cols; ++col)
3969 {
3970 for (int row = 0; row < Rows; ++row)
3971 ret[col][row] = this->doGetScalarFunc().apply(ctx, iargs.a[col][row], iargs.b);
3972 }
3973
3974 return ret;
3975 }
3976 };
3977
3978 template <typename F, int Rows, int Cols>
3979 class ScalarMatFunc : public ScalarMatFuncBase<Rows, Cols>
3980 {
3981 protected:
doGetScalarFunc(void) const3982 const typename ScalarMatFunc::ScalarFunc &doGetScalarFunc(void) const
3983 {
3984 return instance<F>();
3985 }
3986 };
3987
3988 template <typename T, int Size>
3989 struct GenXType;
3990
3991 template <typename T>
3992 struct GenXType<T, 1>
3993 {
genXTypevkt::shaderexecutor::Functions::GenXType3994 static ExprP<T> genXType(const ExprP<T> &x)
3995 {
3996 return x;
3997 }
3998 };
3999
4000 template <typename T>
4001 struct GenXType<T, 2>
4002 {
genXTypevkt::shaderexecutor::Functions::GenXType4003 static ExprP<Vector<T, 2>> genXType(const ExprP<T> &x)
4004 {
4005 return app<GenVec<T, 2>>(x, x);
4006 }
4007 };
4008
4009 template <typename T>
4010 struct GenXType<T, 3>
4011 {
genXTypevkt::shaderexecutor::Functions::GenXType4012 static ExprP<Vector<T, 3>> genXType(const ExprP<T> &x)
4013 {
4014 return app<GenVec<T, 3>>(x, x, x);
4015 }
4016 };
4017
4018 template <typename T>
4019 struct GenXType<T, 4>
4020 {
genXTypevkt::shaderexecutor::Functions::GenXType4021 static ExprP<Vector<T, 4>> genXType(const ExprP<T> &x)
4022 {
4023 return app<GenVec<T, 4>>(x, x, x, x);
4024 }
4025 };
4026
4027 //! Returns an expression of vector of size `Size` (or scalar if Size == 1),
4028 //! with each element initialized with the expression `x`.
4029 template <typename T, int Size>
genXType(const ExprP<T> & x)4030 ExprP<typename ContainerOf<T, Size>::Container> genXType(const ExprP<T> &x)
4031 {
4032 return GenXType<T, Size>::genXType(x);
4033 }
4034
4035 typedef GenVec<float, 2> FloatVec2;
4036 DEFINE_CONSTRUCTOR2(FloatVec2, Vec2, vec2, float, float)
4037
4038 typedef GenVec<deFloat16, 2> FloatVec2_16bit;
4039 DEFINE_CONSTRUCTOR2(FloatVec2_16bit, Vec2_16Bit, vec2, deFloat16, deFloat16)
4040
4041 typedef GenVec<double, 2> DoubleVec2;
4042 DEFINE_CONSTRUCTOR2(DoubleVec2, Vec2_64Bit, vec2, double, double)
4043
4044 typedef GenVec<float, 3> FloatVec3;
4045 DEFINE_CONSTRUCTOR3(FloatVec3, Vec3, vec3, float, float, float)
4046
4047 typedef GenVec<deFloat16, 3> FloatVec3_16bit;
4048 DEFINE_CONSTRUCTOR3(FloatVec3_16bit, Vec3_16Bit, vec3, deFloat16, deFloat16, deFloat16)
4049
4050 typedef GenVec<double, 3> DoubleVec3;
4051 DEFINE_CONSTRUCTOR3(DoubleVec3, Vec3_64Bit, vec3, double, double, double)
4052
4053 typedef GenVec<float, 4> FloatVec4;
4054 DEFINE_CONSTRUCTOR4(FloatVec4, Vec4, vec4, float, float, float, float)
4055
4056 typedef GenVec<deFloat16, 4> FloatVec4_16bit;
4057 DEFINE_CONSTRUCTOR4(FloatVec4_16bit, Vec4_16Bit, vec4, deFloat16, deFloat16, deFloat16, deFloat16)
4058
4059 typedef GenVec<double, 4> DoubleVec4;
4060 DEFINE_CONSTRUCTOR4(DoubleVec4, Vec4_64Bit, vec4, double, double, double, double)
4061
4062 template <class T>
4063 const ExprP<T> getConstZero(void);
4064 template <class T>
4065 const ExprP<T> getConstOne(void);
4066 template <class T>
4067 const ExprP<T> getConstTwo(void);
4068
4069 template <>
getConstZero(void)4070 const ExprP<float> getConstZero<float>(void)
4071 {
4072 return constant(0.0f);
4073 }
4074
4075 template <>
getConstZero(void)4076 const ExprP<deFloat16> getConstZero<deFloat16>(void)
4077 {
4078 return constant((deFloat16)FLOAT16_0_0);
4079 }
4080
4081 template <>
getConstZero(void)4082 const ExprP<double> getConstZero<double>(void)
4083 {
4084 return constant(0.0);
4085 }
4086
4087 template <>
getConstOne(void)4088 const ExprP<float> getConstOne<float>(void)
4089 {
4090 return constant(1.0f);
4091 }
4092
4093 template <>
getConstOne(void)4094 const ExprP<deFloat16> getConstOne<deFloat16>(void)
4095 {
4096 return constant((deFloat16)FLOAT16_1_0);
4097 }
4098
4099 template <>
getConstOne(void)4100 const ExprP<double> getConstOne<double>(void)
4101 {
4102 return constant(1.0);
4103 }
4104
4105 template <>
getConstTwo(void)4106 const ExprP<float> getConstTwo<float>(void)
4107 {
4108 return constant(2.0f);
4109 }
4110
4111 template <>
getConstTwo(void)4112 const ExprP<deFloat16> getConstTwo<deFloat16>(void)
4113 {
4114 return constant((deFloat16)FLOAT16_2_0);
4115 }
4116
4117 template <>
getConstTwo(void)4118 const ExprP<double> getConstTwo<double>(void)
4119 {
4120 return constant(2.0);
4121 }
4122
4123 template <int Size, class T>
4124 class Dot : public DerivedFunc<Signature<T, Vector<T, Size>, Vector<T, Size>>>
4125 {
4126 public:
4127 typedef typename Dot::ArgExprs ArgExprs;
4128
getName(void) const4129 string getName(void) const
4130 {
4131 return "dot";
4132 }
4133
4134 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4135 ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
4136 {
4137 ExprP<T> op[Size];
4138 // Precompute all products.
4139 for (int ndx = 0; ndx < Size; ++ndx)
4140 op[ndx] = args.a[ndx] * args.b[ndx];
4141
4142 int idx[Size];
4143 //Prepare an array of indices.
4144 for (int ndx = 0; ndx < Size; ++ndx)
4145 idx[ndx] = ndx;
4146
4147 ExprP<T> res = op[0];
4148 // Compute the first dot alternative: SUM(a[i]*b[i]), i = 0 .. Size-1
4149 for (int ndx = 1; ndx < Size; ++ndx)
4150 res = res + op[ndx];
4151
4152 // Generate all permutations of indices and
4153 // using a permutation compute a dot alternative.
4154 // Generates all possible variants fo summation of products in the dot product expansion expression.
4155 do
4156 {
4157 ExprP<T> alt = getConstZero<T>();
4158 for (int ndx = 0; ndx < Size; ++ndx)
4159 alt = alt + op[idx[ndx]];
4160 res = alternatives(res, alt);
4161 } while (std::next_permutation(idx, idx + Size));
4162
4163 return res;
4164 }
4165 };
4166
4167 template <class T>
4168 class Dot<1, T> : public DerivedFunc<Signature<T, T, T>>
4169 {
4170 public:
4171 typedef typename DerivedFunc<Signature<T, T, T>>::ArgExprs TArgExprs;
4172
getName(void) const4173 string getName(void) const
4174 {
4175 return "dot";
4176 }
4177
doExpand(ExpandContext &,const TArgExprs & args) const4178 ExprP<T> doExpand(ExpandContext &, const TArgExprs &args) const
4179 {
4180 return args.a * args.b;
4181 }
4182 };
4183
4184 template <int Size>
dot(const ExprP<Vector<deFloat16,Size>> & x,const ExprP<Vector<deFloat16,Size>> & y)4185 ExprP<deFloat16> dot(const ExprP<Vector<deFloat16, Size>> &x, const ExprP<Vector<deFloat16, Size>> &y)
4186 {
4187 return app<Dot<Size, deFloat16>>(x, y);
4188 }
4189
//! dot() expression for deFloat16 scalars.
ExprP<deFloat16> dot(const ExprP<deFloat16> &x, const ExprP<deFloat16> &y)
{
    typedef Dot<1, deFloat16> DotFunc;
    return app<DotFunc>(x, y);
}
4194
4195 template <int Size>
dot(const ExprP<Vector<float,Size>> & x,const ExprP<Vector<float,Size>> & y)4196 ExprP<float> dot(const ExprP<Vector<float, Size>> &x, const ExprP<Vector<float, Size>> &y)
4197 {
4198 return app<Dot<Size, float>>(x, y);
4199 }
4200
//! dot() expression for float scalars.
ExprP<float> dot(const ExprP<float> &x, const ExprP<float> &y)
{
    typedef Dot<1, float> DotFunc;
    return app<DotFunc>(x, y);
}
4205
4206 template <int Size>
dot(const ExprP<Vector<double,Size>> & x,const ExprP<Vector<double,Size>> & y)4207 ExprP<double> dot(const ExprP<Vector<double, Size>> &x, const ExprP<Vector<double, Size>> &y)
4208 {
4209 return app<Dot<Size, double>>(x, y);
4210 }
4211
//! dot() expression for double scalars.
ExprP<double> dot(const ExprP<double> &x, const ExprP<double> &y)
{
    typedef Dot<1, double> DotFunc;
    return app<DotFunc>(x, y);
}
4216
4217 template <int Size, class T>
4218 class Length : public DerivedFunc<Signature<T, typename ContainerOf<T, Size>::Container>>
4219 {
4220 public:
4221 typedef typename Length::ArgExprs ArgExprs;
4222
getName(void) const4223 string getName(void) const
4224 {
4225 return "length";
4226 }
4227
4228 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4229 ExprP<T> doExpand(ExpandContext &, const ArgExprs &args) const
4230 {
4231 return sqrt(dot(args.a, args.a));
4232 }
4233 };
4234
4235 template <class T, class TRet>
length(const ExprP<T> & x)4236 ExprP<TRet> length(const ExprP<T> &x)
4237 {
4238 return app<Length<1, T>>(x);
4239 }
4240
4241 template <int Size, class T, class TRet>
length(const ExprP<typename ContainerOf<T,Size>::Container> & x)4242 ExprP<TRet> length(const ExprP<typename ContainerOf<T, Size>::Container> &x)
4243 {
4244 return app<Length<Size, T>>(x);
4245 }
4246
4247 template <int Size, class T>
4248 class Distance : public DerivedFunc<
4249 Signature<T, typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
4250 {
4251 public:
4252 typedef typename Distance::Ret Ret;
4253 typedef typename Distance::ArgExprs ArgExprs;
4254
getName(void) const4255 string getName(void) const
4256 {
4257 return "distance";
4258 }
4259
4260 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4261 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
4262 {
4263 return length<Size, T, Ret>(args.a - args.b);
4264 }
4265 };
4266
4267 // cross
4268
4269 class Cross : public DerivedFunc<Signature<Vec3, Vec3, Vec3>>
4270 {
4271 public:
getName(void) const4272 string getName(void) const
4273 {
4274 return "cross";
4275 }
4276
4277 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4278 ExprP<Vec3> doExpand(ExpandContext &, const ArgExprs &x) const
4279 {
4280 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
4281 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4282 }
4283 };
4284
4285 class Cross16Bit : public DerivedFunc<Signature<Vec3_16Bit, Vec3_16Bit, Vec3_16Bit>>
4286 {
4287 public:
getName(void) const4288 string getName(void) const
4289 {
4290 return "cross";
4291 }
4292
4293 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4294 ExprP<Vec3_16Bit> doExpand(ExpandContext &, const ArgExprs &x) const
4295 {
4296 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
4297 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4298 }
4299 };
4300
4301 class Cross64Bit : public DerivedFunc<Signature<Vec3_64Bit, Vec3_64Bit, Vec3_64Bit>>
4302 {
4303 public:
getName(void) const4304 string getName(void) const
4305 {
4306 return "cross";
4307 }
4308
4309 protected:
doExpand(ExpandContext &,const ArgExprs & x) const4310 ExprP<Vec3_64Bit> doExpand(ExpandContext &, const ArgExprs &x) const
4311 {
4312 return vec3(x.a[1] * x.b[2] - x.b[1] * x.a[2], x.a[2] * x.b[0] - x.b[2] * x.a[0],
4313 x.a[0] * x.b[1] - x.b[0] * x.a[1]);
4314 }
4315 };
4316
4317 DEFINE_CONSTRUCTOR2(Cross, Vec3, cross, Vec3, Vec3)
4318 DEFINE_CONSTRUCTOR2(Cross16Bit, Vec3_16Bit, cross, Vec3_16Bit, Vec3_16Bit)
4319 DEFINE_CONSTRUCTOR2(Cross64Bit, Vec3_64Bit, cross, Vec3_64Bit, Vec3_64Bit)
4320
4321 template <int Size, class T>
4322 class Normalize
4323 : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
4324 {
4325 public:
4326 typedef typename Normalize::Ret Ret;
4327 typedef typename Normalize::ArgExprs ArgExprs;
4328
getName(void) const4329 string getName(void) const
4330 {
4331 return "normalize";
4332 }
4333
4334 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4335 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
4336 {
4337 return args.a * app<InverseSqrt<Signature<T, T>>>(dot(args.a, args.a));
4338 }
4339 };
4340
4341 template <int Size, class T>
4342 class FaceForward
4343 : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
4344 typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container>>
4345 {
4346 public:
4347 typedef typename FaceForward::Ret Ret;
4348 typedef typename FaceForward::ArgExprs ArgExprs;
4349
getName(void) const4350 string getName(void) const
4351 {
4352 return "faceforward";
4353 }
4354
4355 protected:
doExpand(ExpandContext &,const ArgExprs & args) const4356 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
4357 {
4358 return cond(dot(args.c, args.b) < getConstZero<T>(), args.a, -args.a);
4359 }
4360 };
4361
4362 template <int Size, class T>
4363 class Reflect
4364 : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
4365 typename ContainerOf<T, Size>::Container>>
4366 {
4367 public:
4368 typedef typename Reflect::Ret Ret;
4369 typedef typename Reflect::Arg0 Arg0;
4370 typedef typename Reflect::Arg1 Arg1;
4371 typedef typename Reflect::ArgExprs ArgExprs;
4372
getName(void) const4373 string getName(void) const
4374 {
4375 return "reflect";
4376 }
4377
4378 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4379 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
4380 {
4381 const ExprP<Arg0> &i = args.a;
4382 const ExprP<Arg1> &n = args.b;
4383 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4384
4385 return i - alternatives(
4386 (n * dotNI) * getConstTwo<T>(),
4387 alternatives(n * (dotNI * getConstTwo<T>()),
4388 alternatives(n * dot(i * getConstTwo<T>(), n), n * dot(i, n * getConstTwo<T>()))));
4389 }
4390 };
4391
4392 template <class T>
4393 class Reflect<1, T> : public DerivedFunc<Signature<T, T, T>>
4394 {
4395 public:
4396 typedef typename Reflect::Ret Ret;
4397 typedef typename Reflect::Arg0 Arg0;
4398 typedef typename Reflect::Arg1 Arg1;
4399 typedef typename Reflect::ArgExprs ArgExprs;
4400
getName(void) const4401 string getName(void) const
4402 {
4403 return "reflect";
4404 }
4405
4406 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4407 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
4408 {
4409 const ExprP<Arg0> &i = args.a;
4410 const ExprP<Arg1> &n = args.b;
4411 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4412
4413 return i - alternatives((n * dotNI) * getConstTwo<T>(),
4414 alternatives(n * (dotNI * getConstTwo<T>()),
4415 alternatives(n * dot(i * getConstTwo<T>(), n),
4416 alternatives(n * dot(i, n * getConstTwo<T>()),
4417 dot(n * n, i * getConstTwo<T>())))));
4418 }
4419 };
4420
4421 template <int Size, class T>
4422 class Refract
4423 : public DerivedFunc<Signature<typename ContainerOf<T, Size>::Container, typename ContainerOf<T, Size>::Container,
4424 typename ContainerOf<T, Size>::Container, T>>
4425 {
4426 public:
4427 typedef typename Refract::Ret Ret;
4428 typedef typename Refract::Arg0 Arg0;
4429 typedef typename Refract::Arg1 Arg1;
4430 typedef typename Refract::Arg2 Arg2;
4431 typedef typename Refract::ArgExprs ArgExprs;
4432
getName(void) const4433 string getName(void) const
4434 {
4435 return "refract";
4436 }
4437
4438 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const4439 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
4440 {
4441 const ExprP<Arg0> &i = args.a;
4442 const ExprP<Arg1> &n = args.b;
4443 const ExprP<Arg2> &eta = args.c;
4444 const ExprP<T> dotNI = bindExpression("dotNI", ctx, dot(n, i));
4445 const ExprP<T> k = bindExpression("k", ctx, getConstOne<T>() - eta * eta * (getConstOne<T>() - dotNI * dotNI));
4446 return cond(k < getConstZero<T>(), genXType<T, Size>(getConstZero<T>()), i * eta - n * (eta * dotNI + sqrt(k)));
4447 }
4448 };
4449
4450 template <class T>
4451 class PreciseFunc1 : public CFloatFunc1<T>
4452 {
4453 public:
PreciseFunc1(const string & name,DoubleFunc1 & func)4454 PreciseFunc1(const string &name, DoubleFunc1 &func) : CFloatFunc1<T>(name, func)
4455 {
4456 }
4457
4458 protected:
precision(const EvalContext &,double,double) const4459 double precision(const EvalContext &, double, double) const
4460 {
4461 return 0.0;
4462 }
4463 };
4464
4465 template <class T>
4466 class Abs : public PreciseFunc1<T>
4467 {
4468 public:
Abs(void)4469 Abs(void) : PreciseFunc1<T>("abs", deAbs)
4470 {
4471 }
4472 };
4473
4474 template <class T>
4475 class Sign : public PreciseFunc1<T>
4476 {
4477 public:
Sign(void)4478 Sign(void) : PreciseFunc1<T>("sign", deSign)
4479 {
4480 }
4481 };
4482
4483 template <class T>
4484 class Floor : public PreciseFunc1<T>
4485 {
4486 public:
Floor(void)4487 Floor(void) : PreciseFunc1<T>("floor", deFloor)
4488 {
4489 }
4490 };
4491
4492 template <class T>
4493 class Trunc : public PreciseFunc1<T>
4494 {
4495 public:
Trunc(void)4496 Trunc(void) : PreciseFunc1<T>("trunc", deTrunc)
4497 {
4498 }
4499 };
4500
4501 template <class T>
4502 class Round : public FloatFunc1<T>
4503 {
4504 public:
getName(void) const4505 string getName(void) const
4506 {
4507 return "round";
4508 }
4509
4510 protected:
applyPoint(const EvalContext &,double x) const4511 Interval applyPoint(const EvalContext &, double x) const
4512 {
4513 double truncated = 0.0;
4514 const double fract = deModf(x, &truncated);
4515 Interval ret;
4516
4517 if (fabs(fract) <= 0.5)
4518 ret |= truncated;
4519 if (fabs(fract) >= 0.5)
4520 ret |= truncated + deSign(fract);
4521
4522 return ret;
4523 }
4524
precision(const EvalContext &,double,double) const4525 double precision(const EvalContext &, double, double) const
4526 {
4527 return 0.0;
4528 }
4529 };
4530
4531 template <class T>
4532 class RoundEven : public PreciseFunc1<T>
4533 {
4534 public:
RoundEven(void)4535 RoundEven(void) : PreciseFunc1<T>("roundEven", deRoundEven)
4536 {
4537 }
4538 };
4539
4540 template <class T>
4541 class Ceil : public PreciseFunc1<T>
4542 {
4543 public:
Ceil(void)4544 Ceil(void) : PreciseFunc1<T>("ceil", deCeil)
4545 {
4546 }
4547 };
4548
4549 typedef Floor<Signature<float, float>> Floor32Bit;
4550 typedef Floor<Signature<deFloat16, deFloat16>> Floor16Bit;
4551 typedef Floor<Signature<double, double>> Floor64Bit;
4552
4553 typedef Trunc<Signature<float, float>> Trunc32Bit;
4554 typedef Trunc<Signature<deFloat16, deFloat16>> Trunc16Bit;
4555 typedef Trunc<Signature<double, double>> Trunc64Bit;
4556
4557 typedef Trunc<Signature<float, float>> Trunc32Bit;
4558 typedef Trunc<Signature<deFloat16, deFloat16>> Trunc16Bit;
4559
4560 DEFINE_DERIVED_FLOAT1(Fract, fract, x, x - app<Floor32Bit>(x))
4561 DEFINE_DERIVED_FLOAT1_16BIT(Fract16Bit, fract, x, x - app<Floor16Bit>(x))
4562 DEFINE_DERIVED_DOUBLE1(Fract64Bit, fract, x, x - app<Floor64Bit>(x))
4563
4564 template <class T>
4565 class PreciseFunc2 : public CFloatFunc2<T>
4566 {
4567 public:
PreciseFunc2(const string & name,DoubleFunc2 & func)4568 PreciseFunc2(const string &name, DoubleFunc2 &func) : CFloatFunc2<T>(name, func)
4569 {
4570 }
4571
4572 protected:
precision(const EvalContext &,double,double,double) const4573 double precision(const EvalContext &, double, double, double) const
4574 {
4575 return 0.0;
4576 }
4577 };
4578
4579 DEFINE_DERIVED_FLOAT2(Mod32Bit, mod, x, y, x - y * app<Floor32Bit>(x / y))
4580 DEFINE_DERIVED_FLOAT2_16BIT(Mod16Bit, mod, x, y, x - y * app<Floor16Bit>(x / y))
4581 DEFINE_DERIVED_DOUBLE2(Mod64Bit, mod, x, y, x - y * app<Floor64Bit>(x / y))
4582
4583 DEFINE_CASED_DERIVED_FLOAT2(FRem32Bit, frem, x, y, x - y * app<Trunc32Bit>(x / y), SPIRV_CASETYPE_FREM)
4584 DEFINE_CASED_DERIVED_FLOAT2_16BIT(FRem16Bit, frem, x, y, x - y * app<Trunc16Bit>(x / y), SPIRV_CASETYPE_FREM)
4585 DEFINE_CASED_DERIVED_DOUBLE2(FRem64Bit, frem, x, y, x - y * app<Trunc64Bit>(x / y), SPIRV_CASETYPE_FREM)
4586
4587 template <class T>
4588 class Modf : public PrimitiveFunc<T>
4589 {
4590 public:
4591 typedef typename Modf<T>::IArgs TIArgs;
4592 typedef typename Modf<T>::IRet TIRet;
getName(void) const4593 string getName(void) const
4594 {
4595 return "modf";
4596 }
4597
4598 protected:
doApply(const EvalContext & ctx,const TIArgs & iargs) const4599 TIRet doApply(const EvalContext &ctx, const TIArgs &iargs) const
4600 {
4601 Interval fracIV;
4602 Interval &wholeIV = const_cast<Interval &>(iargs.b);
4603 double intPart = 0;
4604
4605 TCU_INTERVAL_APPLY_MONOTONE1(fracIV, x, iargs.a, frac, frac = deModf(x, &intPart));
4606 TCU_INTERVAL_APPLY_MONOTONE1(wholeIV, x, iargs.a, whole, deModf(x, &intPart); whole = intPart);
4607
4608 if (!iargs.a.isFinite(ctx.format.getMaxValue()))
4609 {
4610 // Behavior on modf(Inf) not well-defined, allow anything as a fractional part
4611 // See Khronos bug 13907
4612 fracIV |= TCU_NAN;
4613 }
4614
4615 return fracIV;
4616 }
4617
getOutParamIndex(void) const4618 int getOutParamIndex(void) const
4619 {
4620 return 1;
4621 }
4622 };
4623 typedef Modf<Signature<float, float, float>> Modf32Bit;
4624 typedef Modf<Signature<deFloat16, deFloat16, deFloat16>> Modf16Bit;
4625 typedef Modf<Signature<double, double, double>> Modf64Bit;
4626
4627 template <class T>
4628 class ModfStruct : public Modf<T>
4629 {
4630 public:
getName(void) const4631 virtual string getName(void) const
4632 {
4633 return "modfstruct";
4634 }
getSpirvCase(void) const4635 virtual SpirVCaseT getSpirvCase(void) const
4636 {
4637 return SPIRV_CASETYPE_MODFSTRUCT;
4638 }
4639 };
4640 typedef ModfStruct<Signature<float, float, float>> ModfStruct32Bit;
4641 typedef ModfStruct<Signature<deFloat16, deFloat16, deFloat16>> ModfStruct16Bit;
4642 typedef ModfStruct<Signature<double, double, double>> ModfStruct64Bit;
4643
4644 template <class T>
4645 class Min : public PreciseFunc2<T>
4646 {
4647 public:
Min(void)4648 Min(void) : PreciseFunc2<T>("min", deMin)
4649 {
4650 }
4651 };
4652 template <class T>
4653 class Max : public PreciseFunc2<T>
4654 {
4655 public:
Max(void)4656 Max(void) : PreciseFunc2<T>("max", deMax)
4657 {
4658 }
4659 };
4660
4661 template <class T>
4662 class Clamp : public FloatFunc3<T>
4663 {
4664 public:
getName(void) const4665 string getName(void) const
4666 {
4667 return "clamp";
4668 }
4669
applyExact(double x,double minVal,double maxVal) const4670 double applyExact(double x, double minVal, double maxVal) const
4671 {
4672 return de::min(de::max(x, minVal), maxVal);
4673 }
4674
precision(const EvalContext &,double,double,double minVal,double maxVal) const4675 double precision(const EvalContext &, double, double, double minVal, double maxVal) const
4676 {
4677 return minVal > maxVal ? TCU_NAN : 0.0;
4678 }
4679 };
4680
//! clamp() expression for deFloat16 operands.
ExprP<deFloat16> clamp(const ExprP<deFloat16> &x, const ExprP<deFloat16> &minVal, const ExprP<deFloat16> &maxVal)
{
    typedef Clamp<Signature<deFloat16, deFloat16, deFloat16, deFloat16>> ClampFunc;
    return app<ClampFunc>(x, minVal, maxVal);
}
4685
//! clamp() expression for float operands.
ExprP<float> clamp(const ExprP<float> &x, const ExprP<float> &minVal, const ExprP<float> &maxVal)
{
    typedef Clamp<Signature<float, float, float, float>> ClampFunc;
    return app<ClampFunc>(x, minVal, maxVal);
}
4690
//! clamp() expression for double operands.
ExprP<double> clamp(const ExprP<double> &x, const ExprP<double> &minVal, const ExprP<double> &maxVal)
{
    typedef Clamp<Signature<double, double, double, double>> ClampFunc;
    return app<ClampFunc>(x, minVal, maxVal);
}
4695
4696 template <class T>
4697 class NanIfGreaterOrEqual : public FloatFunc2<T>
4698 {
4699 public:
getName(void) const4700 string getName(void) const
4701 {
4702 return "nanIfGreaterOrEqual";
4703 }
4704
applyExact(double edge0,double edge1) const4705 double applyExact(double edge0, double edge1) const
4706 {
4707 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4708 }
4709
precision(const EvalContext &,double,double edge0,double edge1) const4710 double precision(const EvalContext &, double, double edge0, double edge1) const
4711 {
4712 return (edge0 >= edge1) ? TCU_NAN : 0.0;
4713 }
4714 };
4715
//! nanIfGreaterOrEqual() expression for deFloat16 operands.
ExprP<deFloat16> nanIfGreaterOrEqual(const ExprP<deFloat16> &edge0, const ExprP<deFloat16> &edge1)
{
    typedef NanIfGreaterOrEqual<Signature<deFloat16, deFloat16, deFloat16>> NanFunc;
    return app<NanFunc>(edge0, edge1);
}
4720
//! nanIfGreaterOrEqual() expression for float operands.
ExprP<float> nanIfGreaterOrEqual(const ExprP<float> &edge0, const ExprP<float> &edge1)
{
    typedef NanIfGreaterOrEqual<Signature<float, float, float>> NanFunc;
    return app<NanFunc>(edge0, edge1);
}
4725
//! nanIfGreaterOrEqual() expression for double operands.
ExprP<double> nanIfGreaterOrEqual(const ExprP<double> &edge0, const ExprP<double> &edge1)
{
    typedef NanIfGreaterOrEqual<Signature<double, double, double>> NanFunc;
    return app<NanFunc>(edge0, edge1);
}
4730
4731 DEFINE_DERIVED_FLOAT3(Mix, mix, x, y, a, alternatives((x * (constant(1.0f) - a)) + y * a, x + (y - x) * a))
4732
4733 DEFINE_DERIVED_FLOAT3_16BIT(Mix16Bit, mix, x, y, a,
4734 alternatives((x * (constant((deFloat16)FLOAT16_1_0) - a)) + y * a, x + (y - x) * a))
4735
4736 DEFINE_DERIVED_DOUBLE3(Mix64Bit, mix, x, y, a, alternatives((x * (constant(1.0) - a)) + y * a, x + (y - x) * a))
4737
// Reference implementation of step(): 0.0 when x lies strictly below edge, 1.0 in every
// other case (including x == edge and unordered comparisons).
static double step(double edge, double x)
{
    if (x < edge)
        return 0.0;

    return 1.0;
}
4742
// step(edge, x) as a two-argument function: forwards the name "step" and the scalar
// reference implementation step() to the PreciseFunc2 base constructor.
template <class T>
class Step : public PreciseFunc2<T>
{
public:
    Step(void) : PreciseFunc2<T>("step", step)
    {
    }
};
4751
// smoothstep(edge0, edge1, x) as a derived function. The class template only declares
// doExpand(); its definitions are supplied as explicit specializations per signature.
template <class T>
class SmoothStep : public DerivedFunc<T>
{
public:
    // Short aliases for the ArgExprs / Ret types reachable through SmoothStep<T>
    // (they are not declared in this class itself).
    typedef typename SmoothStep<T>::ArgExprs TArgExprs;
    typedef typename SmoothStep<T>::Ret TRet;
    string getName(void) const
    {
        return "smoothstep";
    }

protected:
    // Declared only; see the explicit specializations for the actual expansion.
    ExprP<TRet> doExpand(ExpandContext &ctx, const TArgExprs &args) const;
};
4766
4767 template <>
4768 ExprP<SmoothStep<Signature<float, float, float, float>>::Ret> SmoothStep<Signature<float, float, float, float>>::
doExpand(ExpandContext & ctx,const SmoothStep<Signature<float,float,float,float>>::ArgExprs & args) const4769 doExpand(ExpandContext &ctx, const SmoothStep<Signature<float, float, float, float>>::ArgExprs &args) const
4770 {
4771 const ExprP<float> &edge0 = args.a;
4772 const ExprP<float> &edge1 = args.b;
4773 const ExprP<float> &x = args.c;
4774 const ExprP<float> tExpr =
4775 clamp((x - edge0) / (edge1 - edge0), constant(0.0f), constant(1.0f)) +
4776 nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4777 const ExprP<float> t = bindExpression("t", ctx, tExpr);
4778
4779 return (t * t * (constant(3.0f) - constant(2.0f) * t));
4780 }
4781
4782 template <>
4783 ExprP<SmoothStep<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>::TRet> SmoothStep<
doExpand(ExpandContext & ctx,const TArgExprs & args) const4784 Signature<deFloat16, deFloat16, deFloat16, deFloat16>>::doExpand(ExpandContext &ctx, const TArgExprs &args) const
4785 {
4786 const ExprP<deFloat16> &edge0 = args.a;
4787 const ExprP<deFloat16> &edge1 = args.b;
4788 const ExprP<deFloat16> &x = args.c;
4789 const ExprP<deFloat16> tExpr =
4790 clamp((x - edge0) / (edge1 - edge0), constant((deFloat16)FLOAT16_0_0), constant((deFloat16)FLOAT16_1_0)) +
4791 nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4792 const ExprP<deFloat16> t = bindExpression("t", ctx, tExpr);
4793
4794 return (t * t * (constant((deFloat16)FLOAT16_3_0) - constant((deFloat16)FLOAT16_2_0) * t));
4795 }
4796
4797 template <>
4798 ExprP<SmoothStep<Signature<double, double, double, double>>::Ret> SmoothStep<
4799 Signature<double, double, double, double>>::
doExpand(ExpandContext & ctx,const SmoothStep<Signature<double,double,double,double>>::ArgExprs & args) const4800 doExpand(ExpandContext &ctx, const SmoothStep<Signature<double, double, double, double>>::ArgExprs &args) const
4801 {
4802 const ExprP<double> &edge0 = args.a;
4803 const ExprP<double> &edge1 = args.b;
4804 const ExprP<double> &x = args.c;
4805 const ExprP<double> tExpr =
4806 clamp((x - edge0) / (edge1 - edge0), constant(0.0), constant(1.0)) +
4807 nanIfGreaterOrEqual(edge0, edge1); // force NaN (and non-analyzable result) for cases edge0 >= edge1
4808 const ExprP<double> t = bindExpression("t", ctx, tExpr);
4809
4810 return (t * t * (constant(3.0) - constant(2.0) * t));
4811 }
4812
// frexp(x, out exponent): interval model of the significand (return value) and of the
// exponent (out parameter, argument index 1).
// Instantiated with:
//Signature<float, float, int>
//Signature<deFloat16, deFloat16, int>
//Signature<double, double, int>
template <class T>
class FrExp : public PrimitiveFunc<T>
{
public:
    string getName(void) const
    {
        return "frexp";
    }

    typedef typename FrExp::IRet IRet;
    typedef typename FrExp::IArgs IArgs;
    typedef typename FrExp::IArg0 IArg0;
    typedef typename FrExp::IArg1 IArg1;

protected:
    // Returns the interval of possible significands for x = iargs.a and stores the interval
    // of possible exponents into iargs.b. When x is an empty interval, the default-constructed
    // result is returned and iargs.b is left untouched.
    IRet doApply(const EvalContext &, const IArgs &iargs) const
    {
        IRet ret;
        const IArg0 &x = iargs.a;
        // The exponent is an out parameter, so it is written through the const argument pack.
        IArg1 &exponent = const_cast<IArg1 &>(iargs.b);

        if (x.hasNaN() || x.contains(TCU_INFINITY) || x.contains(-TCU_INFINITY))
        {
            // GLSL (in contrast to IEEE) says that result of applying frexp
            // to infinity is undefined
            ret      = Interval::unbounded() | TCU_NAN;
            // Any value in [-2^31, 2^31 - 1] is accepted for the exponent.
            exponent = Interval(-deLdExp(1.0, 31), deLdExp(1.0, 31) - 1);
        }
        else if (!x.empty())
        {
            // Decompose both end points of the input interval.
            int loExp           = 0;
            const double loFrac = deFrExp(x.lo(), &loExp);
            int hiExp           = 0;
            const double hiFrac = deFrExp(x.hi(), &hiExp);

            if (deSign(loFrac) != deSign(hiFrac))
            {
                // End points have different signs: the exponent has no lower bound and is
                // capped by the larger of the two end-point exponents.
                exponent = Interval(-TCU_INFINITY, de::max(loExp, hiExp));
                ret      = Interval();
                // Add the significand range of each side that is actually present.
                if (deSign(loFrac) < 0)
                    ret |= Interval(-1.0 + DBL_EPSILON * 0.5, 0.0);
                if (deSign(hiFrac) > 0)
                    ret |= Interval(0.0, 1.0 - DBL_EPSILON * 0.5);
            }
            else
            {
                exponent = Interval(loExp, hiExp);
                if (loExp == hiExp)
                    // Same exponent at both ends: the significand range is exactly [loFrac, hiFrac].
                    ret = Interval(loFrac, hiFrac);
                else
                    // Different exponents: any significand magnitude in [0.5, 1 - eps/2] is
                    // accepted, carrying the sign of the low end point.
                    ret = deSign(loFrac) * Interval(0.5, 1.0 - DBL_EPSILON * 0.5);
            }
        }

        return ret;
    }

    // Index of the argument that acts as out parameter (the exponent).
    int getOutParamIndex(void) const
    {
        return 1;
    }
};
// Concrete frexp functors for 32-, 16- and 64-bit floats; the exponent is always an int.
typedef FrExp<Signature<float, float, int>> Frexp32Bit;
typedef FrExp<Signature<deFloat16, deFloat16, int>> Frexp16Bit;
typedef FrExp<Signature<double, double, int>> Frexp64Bit;
4881
// Variant of FrExp that keeps the inherited evaluation but reports the name "frexpstruct"
// and selects the SPIRV_CASETYPE_FREXPSTRUCT SPIR-V case type.
template <class T>
class FrexpStruct : public FrExp<T>
{
public:
    virtual string getName(void) const
    {
        return "frexpstruct";
    }
    virtual SpirVCaseT getSpirvCase(void) const
    {
        return SPIRV_CASETYPE_FREXPSTRUCT;
    }
};
// Concrete frexpstruct functors for 32-, 16- and 64-bit floats; the exponent is always an int.
typedef FrexpStruct<Signature<float, float, int>> FrexpStruct32Bit;
typedef FrexpStruct<Signature<deFloat16, deFloat16, int>> FrexpStruct16Bit;
typedef FrexpStruct<Signature<double, double, int>> FrexpStruct64Bit;
4898
4899 //Signature<float, float, int>
4900 //Signature<deFloat16, deFloat16, int>
4901 //Signature<double, double, int>
4902 template <class T>
4903 class LdExp : public PrimitiveFunc<T>
4904 {
4905 public:
4906 typedef typename LdExp::IRet IRet;
4907 typedef typename LdExp::IArgs IArgs;
4908
getName(void) const4909 string getName(void) const
4910 {
4911 return "ldexp";
4912 }
4913
4914 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const4915 Interval doApply(const EvalContext &ctx, const IArgs &iargs) const
4916 {
4917 const int minExp = ctx.format.getMinExp();
4918 const int maxExp = ctx.format.getMaxExp();
4919 // Restrictions from the GLSL.std.450 instruction set.
4920 // See Khronos bugzilla 11180 for rationale.
4921 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4922 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4923 if (iargs.b.lo() < minExp)
4924 ret |= 0.0;
4925 if (!ret.isFinite(ctx.format.getMaxValue()))
4926 ret |= TCU_NAN;
4927 return ctx.format.convert(ret);
4928 }
4929 };
4930
4931 template <>
doApply(const EvalContext & ctx,const IArgs & iargs) const4932 Interval LdExp<Signature<double, double, int>>::doApply(const EvalContext &ctx, const IArgs &iargs) const
4933 {
4934 const int minExp = ctx.format.getMinExp();
4935 const int maxExp = ctx.format.getMaxExp();
4936 // Restrictions from the GLSL.std.450 instruction set.
4937 // See Khronos bugzilla 11180 for rationale.
4938 bool any = iargs.a.hasNaN() || iargs.b.hi() > (maxExp + 1);
4939 Interval ret(any, ldexp(iargs.a.lo(), (int)iargs.b.lo()), ldexp(iargs.a.hi(), (int)iargs.b.hi()));
4940 // Add 1ULP precision tolerance to account for differing rounding modes between the GPU and deLdExp.
4941 ret += Interval(-ctx.format.ulp(ret.lo()), ctx.format.ulp(ret.hi()));
4942 if (iargs.b.lo() < minExp)
4943 ret |= 0.0;
4944 if (!ret.isFinite(ctx.format.getMaxValue()))
4945 ret |= TCU_NAN;
4946 return ctx.format.convert(ret);
4947 }
4948
4949 template <int Rows, int Columns, class T>
4950 class Transpose : public PrimitiveFunc<Signature<Matrix<T, Rows, Columns>, Matrix<T, Columns, Rows>>>
4951 {
4952 public:
4953 typedef typename Transpose::IRet IRet;
4954 typedef typename Transpose::IArgs IArgs;
4955
getName(void) const4956 string getName(void) const
4957 {
4958 return "transpose";
4959 }
4960
4961 protected:
doApply(const EvalContext &,const IArgs & iargs) const4962 IRet doApply(const EvalContext &, const IArgs &iargs) const
4963 {
4964 IRet ret;
4965
4966 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
4967 {
4968 for (int colNdx = 0; colNdx < Columns; ++colNdx)
4969 ret(rowNdx, colNdx) = iargs.a(colNdx, rowNdx);
4970 }
4971
4972 return ret;
4973 }
4974 };
4975
4976 template <typename Ret, typename Arg0, typename Arg1>
4977 class MulFunc : public PrimitiveFunc<Signature<Ret, Arg0, Arg1>>
4978 {
4979 public:
getName(void) const4980 string getName(void) const
4981 {
4982 return "mul";
4983 }
4984
4985 protected:
doPrint(ostream & os,const BaseArgExprs & args) const4986 void doPrint(ostream &os, const BaseArgExprs &args) const
4987 {
4988 os << "(" << *args[0] << " * " << *args[1] << ")";
4989 }
4990 };
4991
4992 template <typename T, int LeftRows, int Middle, int RightCols>
4993 class MatMul : public MulFunc<Matrix<T, LeftRows, RightCols>, Matrix<T, LeftRows, Middle>, Matrix<T, Middle, RightCols>>
4994 {
4995 protected:
4996 typedef typename MatMul::IRet IRet;
4997 typedef typename MatMul::IArgs IArgs;
4998 typedef typename MatMul::IArg0 IArg0;
4999 typedef typename MatMul::IArg1 IArg1;
5000
doApply(const EvalContext & ctx,const IArgs & iargs) const5001 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5002 {
5003 const IArg0 &left = iargs.a;
5004 const IArg1 &right = iargs.b;
5005 IRet ret;
5006
5007 for (int row = 0; row < LeftRows; ++row)
5008 {
5009 for (int col = 0; col < RightCols; ++col)
5010 {
5011 Interval element(0.0);
5012
5013 for (int ndx = 0; ndx < Middle; ++ndx)
5014 element = call<Add<Signature<T, T, T>>>(
5015 ctx, element, call<Mul<Signature<T, T, T>>>(ctx, left[ndx][row], right[col][ndx]));
5016
5017 ret[col][row] = element;
5018 }
5019 }
5020
5021 return ret;
5022 }
5023 };
5024
5025 template <typename T, int Rows, int Cols>
5026 class VecMatMul : public MulFunc<Vector<T, Cols>, Vector<T, Rows>, Matrix<T, Rows, Cols>>
5027 {
5028 public:
5029 typedef typename VecMatMul::IRet IRet;
5030 typedef typename VecMatMul::IArgs IArgs;
5031 typedef typename VecMatMul::IArg0 IArg0;
5032 typedef typename VecMatMul::IArg1 IArg1;
5033
5034 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5035 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5036 {
5037 const IArg0 &left = iargs.a;
5038 const IArg1 &right = iargs.b;
5039 IRet ret;
5040
5041 for (int col = 0; col < Cols; ++col)
5042 {
5043 Interval element(0.0);
5044
5045 for (int row = 0; row < Rows; ++row)
5046 element = call<Add<Signature<T, T, T>>>(ctx, element,
5047 call<Mul<Signature<T, T, T>>>(ctx, left[row], right[col][row]));
5048
5049 ret[col] = element;
5050 }
5051
5052 return ret;
5053 }
5054 };
5055
5056 template <int Rows, int Cols, class T>
5057 class MatVecMul : public MulFunc<Vector<T, Rows>, Matrix<T, Rows, Cols>, Vector<T, Cols>>
5058 {
5059 public:
5060 typedef typename MatVecMul::IRet IRet;
5061 typedef typename MatVecMul::IArgs IArgs;
5062 typedef typename MatVecMul::IArg0 IArg0;
5063 typedef typename MatVecMul::IArg1 IArg1;
5064
5065 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5066 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5067 {
5068 const IArg0 &left = iargs.a;
5069 const IArg1 &right = iargs.b;
5070
5071 return call<VecMatMul<T, Cols, Rows>>(ctx, right, call<Transpose<Rows, Cols, T>>(ctx, left));
5072 }
5073 };
5074
5075 template <int Rows, int Cols, class T>
5076 class OuterProduct : public PrimitiveFunc<Signature<Matrix<T, Rows, Cols>, Vector<T, Rows>, Vector<T, Cols>>>
5077 {
5078 public:
5079 typedef typename OuterProduct::IRet IRet;
5080 typedef typename OuterProduct::IArgs IArgs;
5081
getName(void) const5082 string getName(void) const
5083 {
5084 return "outerProduct";
5085 }
5086
5087 protected:
doApply(const EvalContext & ctx,const IArgs & iargs) const5088 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5089 {
5090 IRet ret;
5091
5092 for (int row = 0; row < Rows; ++row)
5093 {
5094 for (int col = 0; col < Cols; ++col)
5095 ret[col][row] = call<Mul<Signature<T, T, T>>>(ctx, iargs.a[row], iargs.b[col]);
5096 }
5097
5098 return ret;
5099 }
5100 };
5101
// Expression helper: applies OuterProduct<Rows, Cols, T> to a Rows-component and a
// Cols-component vector expression.
template <int Rows, int Cols, class T>
ExprP<Matrix<T, Rows, Cols>> outerProduct(const ExprP<Vector<T, Rows>> &left, const ExprP<Vector<T, Cols>> &right)
{
    return app<OuterProduct<Rows, Cols, T>>(left, right);
}
5107
// Shared base of all determinant functors; it only supplies the function name.
template <class T>
class DeterminantBase : public DerivedFunc<T>
{
public:
    string getName(void) const
    {
        return "determinant";
    }
};
5117
// Forward declarations of the determinant functors for float, 16-bit float and double
// matrices. Only the explicit specializations for sizes 2, 3 and 4 are defined further down.
template <int Size>
class Determinant;
template <int Size>
class Determinant16bit;
template <int Size>
class Determinant64bit;
5124
5125 template <int Size>
determinant(ExprP<Matrix<float,Size,Size>> mat)5126 ExprP<float> determinant(ExprP<Matrix<float, Size, Size>> mat)
5127 {
5128 return app<Determinant<Size>>(mat);
5129 }
5130
5131 template <int Size>
determinant(ExprP<Matrix<deFloat16,Size,Size>> mat)5132 ExprP<deFloat16> determinant(ExprP<Matrix<deFloat16, Size, Size>> mat)
5133 {
5134 return app<Determinant16bit<Size>>(mat);
5135 }
5136
5137 template <int Size>
determinant(ExprP<Matrix<double,Size,Size>> mat)5138 ExprP<double> determinant(ExprP<Matrix<double, Size, Size>> mat)
5139 {
5140 return app<Determinant64bit<Size>>(mat);
5141 }
5142
5143 template <>
5144 class Determinant<2> : public DeterminantBase<Signature<float, Matrix<float, 2, 2>>>
5145 {
5146 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5147 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5148 {
5149 ExprP<Mat2> mat = args.a;
5150
5151 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5152 }
5153 };
5154
5155 template <>
5156 class Determinant<3> : public DeterminantBase<Signature<float, Matrix<float, 3, 3>>>
5157 {
5158 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5159 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5160 {
5161 ExprP<Mat3> mat = args.a;
5162
5163 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5164 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5165 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5166 }
5167 };
5168
5169 template <>
5170 class Determinant<4> : public DeterminantBase<Signature<float, Matrix<float, 4, 4>>>
5171 {
5172 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5173 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5174 {
5175 ExprP<Mat4> mat = args.a;
5176 ExprP<Mat3> minors[4];
5177
5178 for (int ndx = 0; ndx < 4; ++ndx)
5179 {
5180 ExprP<Vec4> minorColumns[3];
5181 ExprP<Vec3> columns[3];
5182
5183 for (int col = 0; col < 3; ++col)
5184 minorColumns[col] = mat[col < ndx ? col : col + 1];
5185
5186 for (int col = 0; col < 3; ++col)
5187 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5188
5189 minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5190 }
5191
5192 return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5193 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5194 }
5195 };
5196
5197 template <>
5198 class Determinant16bit<2> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 2, 2>>>
5199 {
5200 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5201 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5202 {
5203 ExprP<Mat2_16b> mat = args.a;
5204
5205 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5206 }
5207 };
5208
5209 template <>
5210 class Determinant16bit<3> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 3, 3>>>
5211 {
5212 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5213 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5214 {
5215 ExprP<Mat3_16b> mat = args.a;
5216
5217 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5218 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5219 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5220 }
5221 };
5222
5223 template <>
5224 class Determinant16bit<4> : public DeterminantBase<Signature<deFloat16, Matrix<deFloat16, 4, 4>>>
5225 {
5226 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5227 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5228 {
5229 ExprP<Mat4_16b> mat = args.a;
5230 ExprP<Mat3_16b> minors[4];
5231
5232 for (int ndx = 0; ndx < 4; ++ndx)
5233 {
5234 ExprP<Vec4_16Bit> minorColumns[3];
5235 ExprP<Vec3_16Bit> columns[3];
5236
5237 for (int col = 0; col < 3; ++col)
5238 minorColumns[col] = mat[col < ndx ? col : col + 1];
5239
5240 for (int col = 0; col < 3; ++col)
5241 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5242
5243 minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5244 }
5245
5246 return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5247 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5248 }
5249 };
5250
5251 template <>
5252 class Determinant64bit<2> : public DeterminantBase<Signature<double, Matrix<double, 2, 2>>>
5253 {
5254 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5255 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5256 {
5257 ExprP<Matrix2d> mat = args.a;
5258
5259 return mat[0][0] * mat[1][1] - mat[1][0] * mat[0][1];
5260 }
5261 };
5262
5263 template <>
5264 class Determinant64bit<3> : public DeterminantBase<Signature<double, Matrix<double, 3, 3>>>
5265 {
5266 protected:
doExpand(ExpandContext &,const ArgExprs & args) const5267 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &args) const
5268 {
5269 ExprP<Matrix3d> mat = args.a;
5270
5271 return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) +
5272 mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * mat[2][2]) +
5273 mat[0][2] * (mat[1][0] * mat[2][1] - mat[1][1] * mat[2][0]));
5274 }
5275 };
5276
5277 template <>
5278 class Determinant64bit<4> : public DeterminantBase<Signature<double, Matrix<double, 4, 4>>>
5279 {
5280 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5281 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5282 {
5283 ExprP<Matrix4d> mat = args.a;
5284 ExprP<Matrix3d> minors[4];
5285
5286 for (int ndx = 0; ndx < 4; ++ndx)
5287 {
5288 ExprP<Vec4_64Bit> minorColumns[3];
5289 ExprP<Vec3_64Bit> columns[3];
5290
5291 for (int col = 0; col < 3; ++col)
5292 minorColumns[col] = mat[col < ndx ? col : col + 1];
5293
5294 for (int col = 0; col < 3; ++col)
5295 columns[col] = vec3(minorColumns[0][col + 1], minorColumns[1][col + 1], minorColumns[2][col + 1]);
5296
5297 minors[ndx] = bindExpression("minor", ctx, mat3(columns[0], columns[1], columns[2]));
5298 }
5299
5300 return (mat[0][0] * determinant(minors[0]) - mat[1][0] * determinant(minors[1]) +
5301 mat[2][0] * determinant(minors[2]) - mat[3][0] * determinant(minors[3]));
5302 }
5303 };
5304
// Forward declaration of the float matrix inverse functor; only the explicit
// specializations for sizes 2, 3 and 4 are defined further down.
template <int Size>
class Inverse;

// Expression helper: inverse of a square float matrix via Inverse<Size>.
template <int Size>
ExprP<Matrix<float, Size, Size>> inverse(ExprP<Matrix<float, Size, Size>> mat)
{
    return app<Inverse<Size>>(mat);
}
5313
// Forward declaration of the 16-bit float matrix inverse functor; only the explicit
// specializations for sizes 2, 3 and 4 are defined further down.
template <int Size>
class Inverse16bit;

// Expression helper: inverse of a square 16-bit float matrix via Inverse16bit<Size>.
template <int Size>
ExprP<Matrix<deFloat16, Size, Size>> inverse(ExprP<Matrix<deFloat16, Size, Size>> mat)
{
    return app<Inverse16bit<Size>>(mat);
}
5322
// Forward declaration of the double matrix inverse functor; only the explicit
// specializations for sizes 2, 3 and 4 are defined further down.
template <int Size>
class Inverse64bit;

// Expression helper: inverse of a square double matrix via Inverse64bit<Size>.
template <int Size>
ExprP<Matrix<double, Size, Size>> inverse(ExprP<Matrix<double, Size, Size>> mat)
{
    return app<Inverse64bit<Size>>(mat);
}
5331
5332 template <>
5333 class Inverse<2> : public DerivedFunc<Signature<Mat2, Mat2>>
5334 {
5335 public:
getName(void) const5336 string getName(void) const
5337 {
5338 return "inverse";
5339 }
5340
5341 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5342 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5343 {
5344 ExprP<Mat2> mat = args.a;
5345 ExprP<float> det = bindExpression("det", ctx, determinant(mat));
5346
5347 return mat2(vec2(mat[1][1] / det, -mat[0][1] / det), vec2(-mat[1][0] / det, mat[0][0] / det));
5348 }
5349 };
5350
// 3x3 float inverse by block-wise inversion: the matrix is split into a 2x2 block A
// (mat[0..1][0..1]), the 2-vectors B = (mat[2][0], mat[2][1]) and C = (mat[0][2], mat[1][2])
// and the scalar D = mat[2][2]. The order of the bindExpression() calls is kept as is,
// because each call goes through ctx.
template <>
class Inverse<3> : public DerivedFunc<Signature<Mat3, Mat3>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Mat3> mat = args.a;
        // Inverse of the 2x2 block A.
        ExprP<Mat2> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));

        ExprP<Vec2> matB  = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
        ExprP<Vec2> matC  = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
        ExprP<float> matD = bindExpression("matD", ctx, mat[2][2]);

        // Reciprocal of the Schur complement D - C * invA * B (a scalar in the 3x3 case).
        ExprP<float> schur = bindExpression("schur", ctx, constant(1.0f) / (matD - dot(matC * invA, matB)));

        // blockA = invA + outerProduct((invA * B) * schur, C) * invA, built step by step.
        ExprP<Vec2> t1     = invA * matB;
        ExprP<Vec2> t2     = t1 * schur;
        ExprP<Mat2> t3     = outerProduct(t2, matC);
        ExprP<Mat2> t4     = t3 * invA;
        ExprP<Mat2> t5     = invA + t4;
        ExprP<Mat2> blockA = bindExpression("blockA", ctx, t5);
        ExprP<Vec2> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
        ExprP<Vec2> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);

        // Reassemble the 3x3 result from the four blocks.
        return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
                    vec3(blockB[0], blockB[1], schur));
    }
};
5386
// 4x4 float inverse by block-wise inversion with four 2x2 blocks: A = mat[0..1][0..1],
// B = mat[2..3][0..1], C = mat[0..1][2..3] and D = mat[2..3][2..3]. The order of the
// bindExpression() calls is kept as is, because each call goes through ctx.
template <>
class Inverse<4> : public DerivedFunc<Signature<Mat4, Mat4>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Mat4> mat = args.a;
        // Inverse of block A.
        ExprP<Mat2> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
        ExprP<Mat2> matB = bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
        ExprP<Mat2> matC = bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
        ExprP<Mat2> matD = bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
        // Inverse of the Schur complement D - C * invA * B.
        ExprP<Mat2> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
        ExprP<Mat2> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
        ExprP<Mat2> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
        ExprP<Mat2> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);

        // Reassemble the 4x4 result from the four 2x2 blocks.
        return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
                    vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
                    vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
                    vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
    }
};
5416
5417 template <>
5418 class Inverse16bit<2> : public DerivedFunc<Signature<Mat2_16b, Mat2_16b>>
5419 {
5420 public:
getName(void) const5421 string getName(void) const
5422 {
5423 return "inverse";
5424 }
5425
5426 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5427 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5428 {
5429 ExprP<Mat2_16b> mat = args.a;
5430 ExprP<deFloat16> det = bindExpression("det", ctx, determinant(mat));
5431
5432 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)), vec2((-mat[1][0] / det), (mat[0][0] / det)));
5433 }
5434 };
5435
// 3x3 16-bit float inverse by block-wise inversion: the matrix is split into a 2x2 block A
// (mat[0..1][0..1]), the 2-vectors B = (mat[2][0], mat[2][1]) and C = (mat[0][2], mat[1][2])
// and the scalar D = mat[2][2]. The order of the bindExpression() calls is kept as is,
// because each call goes through ctx.
template <>
class Inverse16bit<3> : public DerivedFunc<Signature<Mat3_16b, Mat3_16b>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Mat3_16b> mat = args.a;
        // Inverse of the 2x2 block A.
        ExprP<Mat2_16b> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));

        ExprP<Vec2_16Bit> matB       = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
        ExprP<Vec2_16Bit> matC       = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
        ExprP<Mat3_16b::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);

        // Reciprocal of the Schur complement D - C * invA * B (a scalar in the 3x3 case).
        ExprP<Mat3_16b::Scalar> schur =
            bindExpression("schur", ctx, constant((deFloat16)FLOAT16_1_0) / (matD - dot(matC * invA, matB)));

        // blockA = invA + outerProduct((invA * B) * schur, C) * invA, built step by step.
        ExprP<Vec2_16Bit> t1     = invA * matB;
        ExprP<Vec2_16Bit> t2     = t1 * schur;
        ExprP<Mat2_16b> t3       = outerProduct(t2, matC);
        ExprP<Mat2_16b> t4       = t3 * invA;
        ExprP<Mat2_16b> t5       = invA + t4;
        ExprP<Mat2_16b> blockA   = bindExpression("blockA", ctx, t5);
        ExprP<Vec2_16Bit> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
        ExprP<Vec2_16Bit> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);

        // Reassemble the 3x3 result from the four blocks.
        return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
                    vec3(blockB[0], blockB[1], schur));
    }
};
5472
// 4x4 16-bit float inverse by block-wise inversion with four 2x2 blocks: A = mat[0..1][0..1],
// B = mat[2..3][0..1], C = mat[0..1][2..3] and D = mat[2..3][2..3]. The order of the
// bindExpression() calls is kept as is, because each call goes through ctx.
template <>
class Inverse16bit<4> : public DerivedFunc<Signature<Mat4_16b, Mat4_16b>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Mat4_16b> mat = args.a;
        // Inverse of block A.
        ExprP<Mat2_16b> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
        ExprP<Mat2_16b> matB =
            bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
        ExprP<Mat2_16b> matC =
            bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
        ExprP<Mat2_16b> matD =
            bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
        // Inverse of the Schur complement D - C * invA * B.
        ExprP<Mat2_16b> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
        ExprP<Mat2_16b> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
        ExprP<Mat2_16b> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
        ExprP<Mat2_16b> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);

        // Reassemble the 4x4 result from the four 2x2 blocks.
        return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
                    vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
                    vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
                    vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
    }
};
5505
5506 template <>
5507 class Inverse64bit<2> : public DerivedFunc<Signature<Matrix2d, Matrix2d>>
5508 {
5509 public:
getName(void) const5510 string getName(void) const
5511 {
5512 return "inverse";
5513 }
5514
5515 protected:
doExpand(ExpandContext & ctx,const ArgExprs & args) const5516 ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
5517 {
5518 ExprP<Matrix2d> mat = args.a;
5519 ExprP<double> det = bindExpression("det", ctx, determinant(mat));
5520
5521 return mat2(vec2((mat[1][1] / det), (-mat[0][1] / det)), vec2((-mat[1][0] / det), (mat[0][0] / det)));
5522 }
5523 };
5524
// 3x3 double inverse by block-wise inversion: the matrix is split into a 2x2 block A
// (mat[0..1][0..1]), the 2-vectors B = (mat[2][0], mat[2][1]) and C = (mat[0][2], mat[1][2])
// and the scalar D = mat[2][2]. The order of the bindExpression() calls is kept as is,
// because each call goes through ctx.
template <>
class Inverse64bit<3> : public DerivedFunc<Signature<Matrix3d, Matrix3d>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Matrix3d> mat = args.a;
        // Inverse of the 2x2 block A.
        ExprP<Matrix2d> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));

        ExprP<Vec2_64Bit> matB       = bindExpression("matB", ctx, vec2(mat[2][0], mat[2][1]));
        ExprP<Vec2_64Bit> matC       = bindExpression("matC", ctx, vec2(mat[0][2], mat[1][2]));
        ExprP<Matrix3d::Scalar> matD = bindExpression("matD", ctx, mat[2][2]);

        // Reciprocal of the Schur complement D - C * invA * B (a scalar in the 3x3 case).
        ExprP<Matrix3d::Scalar> schur = bindExpression("schur", ctx, constant(1.0) / (matD - dot(matC * invA, matB)));

        // blockA = invA + outerProduct((invA * B) * schur, C) * invA, built step by step.
        ExprP<Vec2_64Bit> t1     = invA * matB;
        ExprP<Vec2_64Bit> t2     = t1 * schur;
        ExprP<Matrix2d> t3       = outerProduct(t2, matC);
        ExprP<Matrix2d> t4       = t3 * invA;
        ExprP<Matrix2d> t5       = invA + t4;
        ExprP<Matrix2d> blockA   = bindExpression("blockA", ctx, t5);
        ExprP<Vec2_64Bit> blockB = bindExpression("blockB", ctx, (invA * matB) * -schur);
        ExprP<Vec2_64Bit> blockC = bindExpression("blockC", ctx, (matC * invA) * -schur);

        // Reassemble the 3x3 result from the four blocks.
        return mat3(vec3(blockA[0][0], blockA[0][1], blockC[0]), vec3(blockA[1][0], blockA[1][1], blockC[1]),
                    vec3(blockB[0], blockB[1], schur));
    }
};
5560
// 4x4 double inverse by block-wise inversion with four 2x2 blocks: A = mat[0..1][0..1],
// B = mat[2..3][0..1], C = mat[0..1][2..3] and D = mat[2..3][2..3]. The order of the
// bindExpression() calls is kept as is, because each call goes through ctx.
template <>
class Inverse64bit<4> : public DerivedFunc<Signature<Matrix4d, Matrix4d>>
{
public:
    string getName(void) const
    {
        return "inverse";
    }

protected:
    ExprP<Ret> doExpand(ExpandContext &ctx, const ArgExprs &args) const
    {
        ExprP<Matrix4d> mat = args.a;
        // Inverse of block A.
        ExprP<Matrix2d> invA =
            bindExpression("invA", ctx, inverse(mat2(vec2(mat[0][0], mat[0][1]), vec2(mat[1][0], mat[1][1]))));
        ExprP<Matrix2d> matB =
            bindExpression("matB", ctx, mat2(vec2(mat[2][0], mat[2][1]), vec2(mat[3][0], mat[3][1])));
        ExprP<Matrix2d> matC =
            bindExpression("matC", ctx, mat2(vec2(mat[0][2], mat[0][3]), vec2(mat[1][2], mat[1][3])));
        ExprP<Matrix2d> matD =
            bindExpression("matD", ctx, mat2(vec2(mat[2][2], mat[2][3]), vec2(mat[3][2], mat[3][3])));
        // Inverse of the Schur complement D - C * invA * B.
        ExprP<Matrix2d> schur  = bindExpression("schur", ctx, inverse(matD + -(matC * invA * matB)));
        ExprP<Matrix2d> blockA = bindExpression("blockA", ctx, invA + (invA * matB * schur * matC * invA));
        ExprP<Matrix2d> blockB = bindExpression("blockB", ctx, (-invA) * matB * schur);
        ExprP<Matrix2d> blockC = bindExpression("blockC", ctx, (-schur) * matC * invA);

        // Reassemble the 4x4 result from the four 2x2 blocks.
        return mat4(vec4(blockA[0][0], blockA[0][1], blockC[0][0], blockC[0][1]),
                    vec4(blockA[1][0], blockA[1][1], blockC[1][0], blockC[1][1]),
                    vec4(blockB[0][0], blockB[0][1], schur[0][0], schur[0][1]),
                    vec4(blockB[1][0], blockB[1][1], schur[1][0], schur[1][1]));
    }
};
5593
5594 //Signature<float, float, float, float>
5595 //Signature<deFloat16, deFloat16, deFloat16, deFloat16>
5596 //Signature<double, double, double, double>
5597 template <class T>
5598 class Fma : public DerivedFunc<T>
5599 {
5600 public:
5601 typedef typename Fma::ArgExprs ArgExprs;
5602 typedef typename Fma::Ret Ret;
5603
getName(void) const5604 string getName(void) const
5605 {
5606 return "fma";
5607 }
5608
5609 protected:
doExpand(ExpandContext &,const ArgExprs & x) const5610 ExprP<Ret> doExpand(ExpandContext &, const ArgExprs &x) const
5611 {
5612 return x.a * x.b + x.c;
5613 }
5614 };
5615
5616 } // namespace Functions
5617
5618 using namespace Functions;
5619
// Component access on a container expression: builds a getComponent expression that
// selects element i of this container.
template <typename T>
ExprP<typename T::Element> ContainerExprPBase<T>::operator[](int i) const
{
    return Functions::getComponent(exprP<T>(*this), i);
}
5625
//! Scalar float addition expression: applies Add to arg0 and arg1.
ExprP<float> operator+(const ExprP<float> &arg0, const ExprP<float> &arg1)
{
    typedef Add<Signature<float, float, float>> AddF32;
    return app<AddF32>(arg0, arg1);
}
5630
//! Scalar deFloat16 addition expression: applies Add to arg0 and arg1.
ExprP<deFloat16> operator+(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1)
{
    typedef Add<Signature<deFloat16, deFloat16, deFloat16>> AddF16;
    return app<AddF16>(arg0, arg1);
}
5635
//! Scalar double addition expression: applies Add to arg0 and arg1.
ExprP<double> operator+(const ExprP<double> &arg0, const ExprP<double> &arg1)
{
    typedef Add<Signature<double, double, double>> AddF64;
    return app<AddF64>(arg0, arg1);
}
5640
5641 template <typename T>
operator -(const ExprP<T> & arg0,const ExprP<T> & arg1)5642 ExprP<T> operator-(const ExprP<T> &arg0, const ExprP<T> &arg1)
5643 {
5644 return app<Sub<Signature<T, T, T>>>(arg0, arg1);
5645 }
5646
5647 template <typename T>
operator -(const ExprP<T> & arg0)5648 ExprP<T> operator-(const ExprP<T> &arg0)
5649 {
5650 return app<Negate<Signature<T, T>>>(arg0);
5651 }
5652
//! Scalar float multiplication expression: applies Mul to arg0 and arg1.
ExprP<float> operator*(const ExprP<float> &arg0, const ExprP<float> &arg1)
{
    typedef Mul<Signature<float, float, float>> MulF32;
    return app<MulF32>(arg0, arg1);
}
5657
//! Scalar deFloat16 multiplication expression: applies Mul to arg0 and arg1.
ExprP<deFloat16> operator*(const ExprP<deFloat16> &arg0, const ExprP<deFloat16> &arg1)
{
    typedef Mul<Signature<deFloat16, deFloat16, deFloat16>> MulF16;
    return app<MulF16>(arg0, arg1);
}
5662
//! Scalar double multiplication expression: applies Mul to arg0 and arg1.
ExprP<double> operator*(const ExprP<double> &arg0, const ExprP<double> &arg1)
{
    typedef Mul<Signature<double, double, double>> MulF64;
    return app<MulF64>(arg0, arg1);
}
5667
5668 template <typename T>
operator /(const ExprP<T> & arg0,const ExprP<T> & arg1)5669 ExprP<T> operator/(const ExprP<T> &arg0, const ExprP<T> &arg1)
5670 {
5671 return app<Div<Signature<T, T, T>>>(arg0, arg1);
5672 }
5673
5674 template <typename Sig_, int Size>
5675 class GenFunc : public PrimitiveFunc<Signature<typename ContainerOf<typename Sig_::Ret, Size>::Container,
5676 typename ContainerOf<typename Sig_::Arg0, Size>::Container,
5677 typename ContainerOf<typename Sig_::Arg1, Size>::Container,
5678 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5679 typename ContainerOf<typename Sig_::Arg3, Size>::Container>>
5680 {
5681 public:
5682 typedef typename GenFunc::IArgs IArgs;
5683 typedef typename GenFunc::IRet IRet;
5684
GenFunc(const Func<Sig_> & scalarFunc)5685 GenFunc(const Func<Sig_> &scalarFunc) : m_func(scalarFunc)
5686 {
5687 }
5688
getSpirvCase(void) const5689 SpirVCaseT getSpirvCase(void) const
5690 {
5691 return m_func.getSpirvCase();
5692 }
5693
getName(void) const5694 string getName(void) const
5695 {
5696 return m_func.getName();
5697 }
5698
getOutParamIndex(void) const5699 int getOutParamIndex(void) const
5700 {
5701 return m_func.getOutParamIndex();
5702 }
5703
getRequiredExtension(void) const5704 string getRequiredExtension(void) const
5705 {
5706 return m_func.getRequiredExtension();
5707 }
5708
getInputRange(const bool is16bit) const5709 Interval getInputRange(const bool is16bit) const
5710 {
5711 return m_func.getInputRange(is16bit);
5712 }
5713
5714 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5715 void doPrint(ostream &os, const BaseArgExprs &args) const
5716 {
5717 m_func.print(os, args);
5718 }
5719
doApply(const EvalContext & ctx,const IArgs & iargs) const5720 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5721 {
5722 IRet ret;
5723
5724 for (int ndx = 0; ndx < Size; ++ndx)
5725 {
5726 ret[ndx] = m_func.apply(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5727 }
5728
5729 return ret;
5730 }
5731
doFail(const EvalContext & ctx,const IArgs & iargs) const5732 IRet doFail(const EvalContext &ctx, const IArgs &iargs) const
5733 {
5734 IRet ret;
5735
5736 for (int ndx = 0; ndx < Size; ++ndx)
5737 {
5738 ret[ndx] = m_func.fail(ctx, iargs.a[ndx], iargs.b[ndx], iargs.c[ndx], iargs.d[ndx]);
5739 }
5740
5741 return ret;
5742 }
5743
doGetUsedFuncs(FuncSet & dst) const5744 void doGetUsedFuncs(FuncSet &dst) const
5745 {
5746 m_func.getUsedFuncs(dst);
5747 }
5748
5749 const Func<Sig_> &m_func;
5750 };
5751
5752 template <typename F, int Size>
5753 class VectorizedFunc : public GenFunc<typename F::Sig, Size>
5754 {
5755 public:
VectorizedFunc(void)5756 VectorizedFunc(void) : GenFunc<typename F::Sig, Size>(instance<F>())
5757 {
5758 }
5759 };
5760
5761 template <typename Sig_, int Size>
5762 class FixedGenFunc
5763 : public PrimitiveFunc<Signature<typename ContainerOf<typename Sig_::Ret, Size>::Container,
5764 typename ContainerOf<typename Sig_::Arg0, Size>::Container, typename Sig_::Arg1,
5765 typename ContainerOf<typename Sig_::Arg2, Size>::Container,
5766 typename ContainerOf<typename Sig_::Arg3, Size>::Container>>
5767 {
5768 public:
5769 typedef typename FixedGenFunc::IArgs IArgs;
5770 typedef typename FixedGenFunc::IRet IRet;
5771
getName(void) const5772 string getName(void) const
5773 {
5774 return this->doGetScalarFunc().getName();
5775 }
5776
getSpirvCase(void) const5777 SpirVCaseT getSpirvCase(void) const
5778 {
5779 return this->doGetScalarFunc().getSpirvCase();
5780 }
5781
5782 protected:
doPrint(ostream & os,const BaseArgExprs & args) const5783 void doPrint(ostream &os, const BaseArgExprs &args) const
5784 {
5785 this->doGetScalarFunc().print(os, args);
5786 }
5787
doApply(const EvalContext & ctx,const IArgs & iargs) const5788 IRet doApply(const EvalContext &ctx, const IArgs &iargs) const
5789 {
5790 IRet ret;
5791 const Func<Sig_> &func = this->doGetScalarFunc();
5792
5793 for (int ndx = 0; ndx < Size; ++ndx)
5794 ret[ndx] = func.apply(ctx, iargs.a[ndx], iargs.b, iargs.c[ndx], iargs.d[ndx]);
5795
5796 return ret;
5797 }
5798
5799 virtual const Func<Sig_> &doGetScalarFunc(void) const = 0;
5800 };
5801
5802 template <typename F, int Size>
5803 class FixedVecFunc : public FixedGenFunc<typename F::Sig, Size>
5804 {
5805 protected:
doGetScalarFunc(void) const5806 const Func<typename F::Sig> &doGetScalarFunc(void) const
5807 {
5808 return instance<F>();
5809 }
5810 };
5811
5812 template <typename Sig>
5813 struct GenFuncs
5814 {
GenFuncsvkt::shaderexecutor::GenFuncs5815 GenFuncs(const Func<Sig> &func_, const GenFunc<Sig, 2> &func2_, const GenFunc<Sig, 3> &func3_,
5816 const GenFunc<Sig, 4> &func4_)
5817 : func(func_)
5818 , func2(func2_)
5819 , func3(func3_)
5820 , func4(func4_)
5821 {
5822 }
5823
5824 const Func<Sig> &func;
5825 const GenFunc<Sig, 2> &func2;
5826 const GenFunc<Sig, 3> &func3;
5827 const GenFunc<Sig, 4> &func4;
5828 };
5829
5830 template <typename F>
makeVectorizedFuncs(void)5831 GenFuncs<typename F::Sig> makeVectorizedFuncs(void)
5832 {
5833 return GenFuncs<typename F::Sig>(instance<F>(), instance<VectorizedFunc<F, 2>>(), instance<VectorizedFunc<F, 3>>(),
5834 instance<VectorizedFunc<F, 4>>());
5835 }
5836
5837 template <typename T, int Size>
operator /(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5838 ExprP<Vector<T, Size>> operator/(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1)
5839 {
5840 return app<FixedVecFunc<Div<Signature<T, T, T>>, Size>>(arg0, arg1);
5841 }
5842
5843 template <typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0)5844 ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0)
5845 {
5846 return app<VectorizedFunc<Negate<Signature<T, T>>, Size>>(arg0);
5847 }
5848
5849 template <typename T, int Size>
operator -(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5850 ExprP<Vector<T, Size>> operator-(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1)
5851 {
5852 return app<VectorizedFunc<Sub<Signature<T, T, T>>, Size>>(arg0, arg1);
5853 }
5854
5855 template <int Size, typename T>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<T> & arg1)5856 ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<T> &arg1)
5857 {
5858 return app<FixedVecFunc<Mul<Signature<T, T, T>>, Size>>(arg0, arg1);
5859 }
5860
5861 template <typename T, int Size>
operator *(const ExprP<Vector<T,Size>> & arg0,const ExprP<Vector<T,Size>> & arg1)5862 ExprP<Vector<T, Size>> operator*(const ExprP<Vector<T, Size>> &arg0, const ExprP<Vector<T, Size>> &arg1)
5863 {
5864 return app<VectorizedFunc<Mul<Signature<T, T, T>>, Size>>(arg0, arg1);
5865 }
5866
5867 template <int LeftRows, int Middle, int RightCols, typename T>
operator *(const ExprP<Matrix<T,LeftRows,Middle>> & left,const ExprP<Matrix<T,Middle,RightCols>> & right)5868 ExprP<Matrix<T, LeftRows, RightCols>> operator*(const ExprP<Matrix<T, LeftRows, Middle>> &left,
5869 const ExprP<Matrix<T, Middle, RightCols>> &right)
5870 {
5871 return app<MatMul<T, LeftRows, Middle, RightCols>>(left, right);
5872 }
5873
5874 template <int Rows, int Cols, typename T>
operator *(const ExprP<Vector<T,Cols>> & left,const ExprP<Matrix<T,Rows,Cols>> & right)5875 ExprP<Vector<T, Rows>> operator*(const ExprP<Vector<T, Cols>> &left, const ExprP<Matrix<T, Rows, Cols>> &right)
5876 {
5877 return app<VecMatMul<T, Rows, Cols>>(left, right);
5878 }
5879
5880 template <int Rows, int Cols, class T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<Vector<T,Rows>> & right)5881 ExprP<Vector<T, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<Vector<T, Rows>> &right)
5882 {
5883 return app<MatVecMul<Rows, Cols, T>>(left, right);
5884 }
5885
5886 template <int Rows, int Cols, typename T>
operator *(const ExprP<Matrix<T,Rows,Cols>> & left,const ExprP<T> & right)5887 ExprP<Matrix<T, Rows, Cols>> operator*(const ExprP<Matrix<T, Rows, Cols>> &left, const ExprP<T> &right)
5888 {
5889 return app<ScalarMatFunc<Mul<Signature<T, T, T>>, Rows, Cols>>(left, right);
5890 }
5891
5892 template <int Rows, int Cols>
operator +(const ExprP<Matrix<float,Rows,Cols>> & left,const ExprP<Matrix<float,Rows,Cols>> & right)5893 ExprP<Matrix<float, Rows, Cols>> operator+(const ExprP<Matrix<float, Rows, Cols>> &left,
5894 const ExprP<Matrix<float, Rows, Cols>> &right)
5895 {
5896 return app<CompMatFunc<Add<Signature<float, float, float>>, float, Rows, Cols>>(left, right);
5897 }
5898
5899 template <int Rows, int Cols>
operator +(const ExprP<Matrix<deFloat16,Rows,Cols>> & left,const ExprP<Matrix<deFloat16,Rows,Cols>> & right)5900 ExprP<Matrix<deFloat16, Rows, Cols>> operator+(const ExprP<Matrix<deFloat16, Rows, Cols>> &left,
5901 const ExprP<Matrix<deFloat16, Rows, Cols>> &right)
5902 {
5903 return app<CompMatFunc<Add<Signature<deFloat16, deFloat16, deFloat16>>, deFloat16, Rows, Cols>>(left, right);
5904 }
5905
5906 template <int Rows, int Cols>
operator +(const ExprP<Matrix<double,Rows,Cols>> & left,const ExprP<Matrix<double,Rows,Cols>> & right)5907 ExprP<Matrix<double, Rows, Cols>> operator+(const ExprP<Matrix<double, Rows, Cols>> &left,
5908 const ExprP<Matrix<double, Rows, Cols>> &right)
5909 {
5910 return app<CompMatFunc<Add<Signature<double, double, double>>, double, Rows, Cols>>(left, right);
5911 }
5912
5913 template <typename T, int Rows, int Cols>
operator -(const ExprP<Matrix<T,Rows,Cols>> & mat)5914 ExprP<Matrix<T, Rows, Cols>> operator-(const ExprP<Matrix<T, Rows, Cols>> &mat)
5915 {
5916 return app<MatNeg<T, Rows, Cols>>(mat);
5917 }
5918
5919 template <typename T>
5920 class Sampling
5921 {
5922 public:
~Sampling()5923 virtual ~Sampling()
5924 {
5925 }
5926
genFixeds(const FloatFormat &,const Precision,vector<T> &,const Interval &) const5927 virtual void genFixeds(const FloatFormat &, const Precision, vector<T> &, const Interval &) const
5928 {
5929 }
genRandom(const FloatFormat &,const Precision,Random &,const Interval &) const5930 virtual T genRandom(const FloatFormat &, const Precision, Random &, const Interval &) const
5931 {
5932 return T();
5933 }
removeNotInRange(vector<T> &,const Interval &,const Precision) const5934 virtual void removeNotInRange(vector<T> &, const Interval &, const Precision) const
5935 {
5936 }
5937 };
5938
5939 template <>
5940 class DefaultSampling<Void> : public Sampling<Void>
5941 {
5942 public:
genFixeds(const FloatFormat &,const Precision,vector<Void> & dst,const Interval &) const5943 void genFixeds(const FloatFormat &, const Precision, vector<Void> &dst, const Interval &) const
5944 {
5945 dst.push_back(Void());
5946 }
5947 };
5948
5949 template <>
5950 class DefaultSampling<bool> : public Sampling<bool>
5951 {
5952 public:
genFixeds(const FloatFormat &,const Precision,vector<bool> & dst,const Interval &) const5953 void genFixeds(const FloatFormat &, const Precision, vector<bool> &dst, const Interval &) const
5954 {
5955 dst.push_back(true);
5956 dst.push_back(false);
5957 }
5958 };
5959
5960 template <>
5961 class DefaultSampling<int> : public Sampling<int>
5962 {
5963 public:
genRandom(const FloatFormat &,const Precision prec,Random & rnd,const Interval &) const5964 int genRandom(const FloatFormat &, const Precision prec, Random &rnd, const Interval &) const
5965 {
5966 const int exp = rnd.getInt(0, getNumBits(prec) - 2);
5967 const int sign = rnd.getBool() ? -1 : 1;
5968
5969 return sign * rnd.getInt(0, (int32_t)1 << exp);
5970 }
5971
genFixeds(const FloatFormat &,const Precision,vector<int> & dst,const Interval &) const5972 void genFixeds(const FloatFormat &, const Precision, vector<int> &dst, const Interval &) const
5973 {
5974 dst.push_back(0);
5975 dst.push_back(-1);
5976 dst.push_back(1);
5977 }
5978
5979 private:
getNumBits(Precision prec)5980 static inline int getNumBits(Precision prec)
5981 {
5982 switch (prec)
5983 {
5984 case glu::PRECISION_LAST:
5985 case glu::PRECISION_MEDIUMP:
5986 return 16;
5987 case glu::PRECISION_HIGHP:
5988 return 32;
5989 default:
5990 DE_ASSERT(false);
5991 return 0;
5992 }
5993 }
5994 };
5995
5996 template <>
5997 class DefaultSampling<float> : public Sampling<float>
5998 {
5999 public:
6000 float genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
6001 void genFixeds(const FloatFormat &format, const Precision prec, vector<float> &dst,
6002 const Interval &inputRange) const;
6003 void removeNotInRange(vector<float> &dst, const Interval &inputRange, const Precision prec) const;
6004 };
6005
6006 template <>
6007 class DefaultSampling<double> : public Sampling<double>
6008 {
6009 public:
6010 double genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
6011 void genFixeds(const FloatFormat &format, const Precision prec, vector<double> &dst,
6012 const Interval &inputRange) const;
6013 void removeNotInRange(vector<double> &dst, const Interval &inputRange, const Precision prec) const;
6014 };
6015
isDenorm16(deFloat16 v)6016 static bool isDenorm16(deFloat16 v)
6017 {
6018 const uint16_t mantissa = 0x03FF;
6019 const uint16_t exponent = 0x7C00;
6020 return ((exponent & v) == 0 && (mantissa & v) != 0);
6021 }
6022
6023 //! Generate a random double from a reasonable general-purpose distribution.
randomDouble(const FloatFormat & format,Random & rnd,const Interval & inputRange)6024 double randomDouble(const FloatFormat &format, Random &rnd, const Interval &inputRange)
6025 {
6026 // No testing of subnormals. TODO: Could integrate float controls for some operations.
6027 const int minExp = format.getMinExp();
6028 const int maxExp = format.getMaxExp();
6029 const bool haveSubnormal = false;
6030 const double midpoint = inputRange.midpoint();
6031
6032 // Choose exponent so that the cumulative distribution is cubic.
6033 // This makes the probability distribution quadratic, with the peak centered on zero.
6034 const double minRoot = deCbrt(minExp - 0.5 - (haveSubnormal ? 1.0 : 0.0));
6035 const double maxRoot = deCbrt(maxExp + 0.5);
6036 const int fractionBits = format.getFractionBits();
6037 const int exp = int(deRoundEven(dePow(rnd.getDouble(minRoot, maxRoot), 3.0)));
6038
6039 // Generate some occasional special numbers
6040 switch (rnd.getInt(0, 64))
6041 {
6042 case 0:
6043 return inputRange.contains(0) ? 0 : midpoint;
6044 case 1:
6045 return inputRange.contains(TCU_INFINITY) ? TCU_INFINITY : midpoint;
6046 case 2:
6047 return inputRange.contains(-TCU_INFINITY) ? -TCU_INFINITY : midpoint;
6048 case 3:
6049 return inputRange.contains(TCU_NAN) ? TCU_NAN : midpoint;
6050 default:
6051 break;
6052 }
6053
6054 DE_ASSERT(fractionBits < std::numeric_limits<double>::digits);
6055
6056 // Normal number
6057 double base = deLdExp(1.0, exp);
6058 double quantum = deLdExp(1.0, exp - fractionBits); // smallest representable difference in the binade
6059 double significand = 0.0;
6060 switch (rnd.getInt(0, 16))
6061 {
6062 case 0: // The highest number in this binade, significand is all bits one.
6063 significand = base - quantum;
6064 break;
6065 case 1: // Significand is one.
6066 significand = quantum;
6067 break;
6068 case 2: // Significand is zero.
6069 significand = 0.0;
6070 break;
6071 default: // Random (evenly distributed) significand.
6072 {
6073 uint64_t intFraction = rnd.getUint64() & ((1 << fractionBits) - 1);
6074 significand = double(intFraction) * quantum;
6075 }
6076 }
6077
6078 // Produce positive numbers more often than negative.
6079 double value = (rnd.getInt(0, 3) == 0 ? -1.0 : 1.0) * (base + significand);
6080 return inputRange.contains(value) ? value : midpoint;
6081 }
6082
6083 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const6084 float DefaultSampling<float>::genRandom(const FloatFormat &format, Precision prec, Random &rnd,
6085 const Interval &inputRange) const
6086 {
6087 DE_UNREF(prec);
6088 return (float)randomDouble(format, rnd, inputRange);
6089 }
6090
6091 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<float> & dst,const Interval & inputRange) const6092 void DefaultSampling<float>::genFixeds(const FloatFormat &format, const Precision prec, vector<float> &dst,
6093 const Interval &inputRange) const
6094 {
6095 const int minExp = format.getMinExp();
6096 const int maxExp = format.getMaxExp();
6097 const int fractionBits = format.getFractionBits();
6098 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
6099 const float minNormalized = deFloatLdExp(1.0f, minExp);
6100 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
6101
6102 // NaN
6103 dst.push_back(TCU_NAN);
6104 // Zero
6105 dst.push_back(0.0f);
6106
6107 for (int sign = -1; sign <= 1; sign += 2)
6108 {
6109 // Smallest normalized
6110 dst.push_back((float)sign * minNormalized);
6111
6112 // Next smallest normalized
6113 dst.push_back((float)sign * (minNormalized + minQuantum));
6114
6115 dst.push_back((float)sign * 0.5f);
6116 dst.push_back((float)sign * 1.0f);
6117 dst.push_back((float)sign * 2.0f);
6118
6119 // Largest number
6120 dst.push_back((float)sign * (deFloatLdExp(1.0f, maxExp) + (deFloatLdExp(1.0f, maxExp) - maxQuantum)));
6121
6122 dst.push_back((float)sign * TCU_INFINITY);
6123 }
6124 removeNotInRange(dst, inputRange, prec);
6125 }
6126
removeNotInRange(vector<float> & dst,const Interval & inputRange,const Precision prec) const6127 void DefaultSampling<float>::removeNotInRange(vector<float> &dst, const Interval &inputRange,
6128 const Precision prec) const
6129 {
6130 for (vector<float>::iterator it = dst.begin(); it < dst.end();)
6131 {
6132 // Remove out of range values. PRECISION_LAST means this is an FP16 test so remove any values that
6133 // will be denorms when converted to FP16. (This is used in the precision_fp16_storage32b test group).
6134 if (!inputRange.contains(static_cast<double>(*it)) ||
6135 (prec == glu::PRECISION_LAST && isDenorm16(deFloat32To16Round(*it, DE_ROUNDINGMODE_TO_ZERO))))
6136 it = dst.erase(it);
6137 else
6138 ++it;
6139 }
6140 }
6141
6142 //! Generate a random double from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,Precision prec,Random & rnd,const Interval & inputRange) const6143 double DefaultSampling<double>::genRandom(const FloatFormat &format, Precision prec, Random &rnd,
6144 const Interval &inputRange) const
6145 {
6146 DE_UNREF(prec);
6147 return randomDouble(format, rnd, inputRange);
6148 }
6149
6150 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<double> & dst,const Interval & inputRange) const6151 void DefaultSampling<double>::genFixeds(const FloatFormat &format, const Precision prec, vector<double> &dst,
6152 const Interval &inputRange) const
6153 {
6154 const int minExp = format.getMinExp();
6155 const int maxExp = format.getMaxExp();
6156 const int fractionBits = format.getFractionBits();
6157 const double minQuantum = deLdExp(1.0, minExp - fractionBits);
6158 const double minNormalized = deLdExp(1.0, minExp);
6159 const double maxQuantum = deLdExp(1.0, maxExp - fractionBits);
6160
6161 // NaN
6162 dst.push_back(TCU_NAN);
6163 // Zero
6164 dst.push_back(0.0);
6165
6166 for (int sign = -1; sign <= 1; sign += 2)
6167 {
6168 // Smallest normalized
6169 dst.push_back((double)sign * minNormalized);
6170
6171 // Next smallest normalized
6172 dst.push_back((double)sign * (minNormalized + minQuantum));
6173
6174 dst.push_back((double)sign * 0.5);
6175 dst.push_back((double)sign * 1.0);
6176 dst.push_back((double)sign * 2.0);
6177
6178 // Largest number
6179 dst.push_back((double)sign * (deLdExp(1.0, maxExp) + (deLdExp(1.0, maxExp) - maxQuantum)));
6180
6181 dst.push_back((double)sign * TCU_INFINITY);
6182 }
6183 removeNotInRange(dst, inputRange, prec);
6184 }
6185
removeNotInRange(vector<double> & dst,const Interval & inputRange,const Precision) const6186 void DefaultSampling<double>::removeNotInRange(vector<double> &dst, const Interval &inputRange, const Precision) const
6187 {
6188 for (vector<double>::iterator it = dst.begin(); it < dst.end();)
6189 {
6190 if (!inputRange.contains(*it))
6191 it = dst.erase(it);
6192 else
6193 ++it;
6194 }
6195 }
6196
6197 template <>
6198 class DefaultSampling<deFloat16> : public Sampling<deFloat16>
6199 {
6200 public:
6201 deFloat16 genRandom(const FloatFormat &format, const Precision prec, Random &rnd, const Interval &inputRange) const;
6202 void genFixeds(const FloatFormat &format, const Precision prec, vector<deFloat16> &dst,
6203 const Interval &inputRange) const;
6204
6205 private:
6206 void removeNotInRange(vector<deFloat16> &dst, const Interval &inputRange, const Precision prec) const;
6207 };
6208
6209 //! Generate a random float from a reasonable general-purpose distribution.
genRandom(const FloatFormat & format,const Precision prec,Random & rnd,const Interval & inputRange) const6210 deFloat16 DefaultSampling<deFloat16>::genRandom(const FloatFormat &format, const Precision prec, Random &rnd,
6211 const Interval &inputRange) const
6212 {
6213 DE_UNREF(prec);
6214 return deFloat64To16Round(randomDouble(format, rnd, inputRange), DE_ROUNDINGMODE_TO_NEAREST_EVEN);
6215 }
6216
6217 //! Generate a standard set of floats that should always be tested.
genFixeds(const FloatFormat & format,const Precision prec,vector<deFloat16> & dst,const Interval & inputRange) const6218 void DefaultSampling<deFloat16>::genFixeds(const FloatFormat &format, const Precision prec, vector<deFloat16> &dst,
6219 const Interval &inputRange) const
6220 {
6221 dst.push_back(uint16_t(0x3E00)); //1.5
6222 dst.push_back(uint16_t(0x3D00)); //1.25
6223 dst.push_back(uint16_t(0x3F00)); //1.75
6224 // Zero
6225 dst.push_back(uint16_t(0x0000));
6226 dst.push_back(uint16_t(0x8000));
6227 // Infinity
6228 dst.push_back(uint16_t(0x7c00));
6229 dst.push_back(uint16_t(0xfc00));
6230 // SNaN
6231 dst.push_back(uint16_t(0x7c0f));
6232 dst.push_back(uint16_t(0xfc0f));
6233 // QNaN
6234 dst.push_back(uint16_t(0x7cf0));
6235 dst.push_back(uint16_t(0xfcf0));
6236 // Normalized
6237 dst.push_back(uint16_t(0x0401));
6238 dst.push_back(uint16_t(0x8401));
6239 // Some normal number
6240 dst.push_back(uint16_t(0x14cb));
6241 dst.push_back(uint16_t(0x94cb));
6242
6243 const int minExp = format.getMinExp();
6244 const int maxExp = format.getMaxExp();
6245 const int fractionBits = format.getFractionBits();
6246 const float minQuantum = deFloatLdExp(1.0f, minExp - fractionBits);
6247 const float minNormalized = deFloatLdExp(1.0f, minExp);
6248 const float maxQuantum = deFloatLdExp(1.0f, maxExp - fractionBits);
6249
6250 for (float sign = -1.0; sign <= 1.0f; sign += 2.0f)
6251 {
6252 // Smallest normalized
6253 dst.push_back(deFloat32To16Round(sign * minNormalized, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6254
6255 // Next smallest normalized
6256 dst.push_back(deFloat32To16Round(sign * (minNormalized + minQuantum), DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6257
6258 dst.push_back(deFloat32To16Round(sign * 0.5f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6259 dst.push_back(deFloat32To16Round(sign * 1.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6260 dst.push_back(deFloat32To16Round(sign * 2.0f, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6261
6262 // Largest number
6263 dst.push_back(
6264 deFloat32To16Round(sign * (deFloatLdExp(1.0f, maxExp) + (deFloatLdExp(1.0f, maxExp) - maxQuantum)),
6265 DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6266
6267 dst.push_back(deFloat32To16Round(sign * TCU_INFINITY, DE_ROUNDINGMODE_TO_NEAREST_EVEN));
6268 }
6269 removeNotInRange(dst, inputRange, prec);
6270 }
6271
removeNotInRange(vector<deFloat16> & dst,const Interval & inputRange,const Precision) const6272 void DefaultSampling<deFloat16>::removeNotInRange(vector<deFloat16> &dst, const Interval &inputRange,
6273 const Precision) const
6274 {
6275 for (vector<deFloat16>::iterator it = dst.begin(); it < dst.end();)
6276 {
6277 if (inputRange.contains(static_cast<double>(*it)))
6278 ++it;
6279 else
6280 it = dst.erase(it);
6281 }
6282 }
6283
6284 template <typename T, int Size>
6285 class DefaultSampling<Vector<T, Size>> : public Sampling<Vector<T, Size>>
6286 {
6287 public:
6288 typedef Vector<T, Size> Value;
6289
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6290 Value genRandom(const FloatFormat &fmt, const Precision prec, Random &rnd, const Interval &inputRange) const
6291 {
6292 Value ret;
6293
6294 for (int ndx = 0; ndx < Size; ++ndx)
6295 ret[ndx] = instance<DefaultSampling<T>>().genRandom(fmt, prec, rnd, inputRange);
6296
6297 return ret;
6298 }
6299
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6300 void genFixeds(const FloatFormat &fmt, const Precision prec, vector<Value> &dst, const Interval &inputRange) const
6301 {
6302 vector<T> scalars;
6303
6304 instance<DefaultSampling<T>>().genFixeds(fmt, prec, scalars, inputRange);
6305
6306 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6307 dst.push_back(Value(scalars[scalarNdx]));
6308 }
6309 };
6310
6311 template <typename T, int Rows, int Columns>
6312 class DefaultSampling<Matrix<T, Rows, Columns>> : public Sampling<Matrix<T, Rows, Columns>>
6313 {
6314 public:
6315 typedef Matrix<T, Rows, Columns> Value;
6316
genRandom(const FloatFormat & fmt,const Precision prec,Random & rnd,const Interval & inputRange) const6317 Value genRandom(const FloatFormat &fmt, const Precision prec, Random &rnd, const Interval &inputRange) const
6318 {
6319 Value ret;
6320
6321 for (int rowNdx = 0; rowNdx < Rows; ++rowNdx)
6322 for (int colNdx = 0; colNdx < Columns; ++colNdx)
6323 ret(rowNdx, colNdx) = instance<DefaultSampling<T>>().genRandom(fmt, prec, rnd, inputRange);
6324
6325 return ret;
6326 }
6327
genFixeds(const FloatFormat & fmt,const Precision prec,vector<Value> & dst,const Interval & inputRange) const6328 void genFixeds(const FloatFormat &fmt, const Precision prec, vector<Value> &dst, const Interval &inputRange) const
6329 {
6330 vector<T> scalars;
6331
6332 instance<DefaultSampling<T>>().genFixeds(fmt, prec, scalars, inputRange);
6333
6334 for (size_t scalarNdx = 0; scalarNdx < scalars.size(); ++scalarNdx)
6335 dst.push_back(Value(scalars[scalarNdx]));
6336
6337 if (Columns == Rows)
6338 {
6339 Value mat(T(0.0));
6340 T x = T(1.0f);
6341 mat[0][0] = x;
6342 for (int ndx = 0; ndx < Columns; ++ndx)
6343 {
6344 mat[Columns - 1 - ndx][ndx] = x;
6345 x = static_cast<T>(x * static_cast<T>(2.0f));
6346 }
6347 dst.push_back(mat);
6348 }
6349 }
6350 };
6351
6352 struct CaseContext
6353 {
CaseContextvkt::shaderexecutor::CaseContext6354 CaseContext(const string &name_, TestContext &testContext_, const FloatFormat &floatFormat_,
6355 const FloatFormat &highpFormat_, const Precision precision_, const ShaderType shaderType_,
6356 const size_t numRandoms_,
6357 const PrecisionTestFeatures precisionTestFeatures_ = PRECISION_TEST_FEATURES_NONE,
6358 const bool isPackFloat16b_ = false, const bool isFloat64b_ = false)
6359 : name(name_)
6360 , testContext(testContext_)
6361 , floatFormat(floatFormat_)
6362 , highpFormat(highpFormat_)
6363 , precision(precision_)
6364 , shaderType(shaderType_)
6365 , numRandoms(numRandoms_)
6366 , inputRange(-TCU_INFINITY, TCU_INFINITY)
6367 , precisionTestFeatures(precisionTestFeatures_)
6368 , isPackFloat16b(isPackFloat16b_)
6369 , isFloat64b(isFloat64b_)
6370 {
6371 }
6372
6373 string name;
6374 TestContext &testContext;
6375 FloatFormat floatFormat;
6376 FloatFormat highpFormat;
6377 Precision precision;
6378 ShaderType shaderType;
6379 size_t numRandoms;
6380 Interval inputRange;
6381 PrecisionTestFeatures precisionTestFeatures;
6382 bool isPackFloat16b;
6383 bool isFloat64b;
6384 };
6385
6386 template <typename In0_ = Void, typename In1_ = Void, typename In2_ = Void, typename In3_ = Void>
6387 struct InTypes
6388 {
6389 typedef In0_ In0;
6390 typedef In1_ In1;
6391 typedef In2_ In2;
6392 typedef In3_ In3;
6393 };
6394
6395 template <typename In>
numInputs(void)6396 int numInputs(void)
6397 {
6398 return (!isTypeValid<typename In::In0>() ? 0 :
6399 !isTypeValid<typename In::In1>() ? 1 :
6400 !isTypeValid<typename In::In2>() ? 2 :
6401 !isTypeValid<typename In::In3>() ? 3 :
6402 4);
6403 }
6404
6405 template <typename Out0_, typename Out1_ = Void>
6406 struct OutTypes
6407 {
6408 typedef Out0_ Out0;
6409 typedef Out1_ Out1;
6410 };
6411
6412 template <typename Out>
numOutputs(void)6413 int numOutputs(void)
6414 {
6415 return (!isTypeValid<typename Out::Out0>() ? 0 : !isTypeValid<typename Out::Out1>() ? 1 : 2);
6416 }
6417
6418 template <typename In>
6419 struct Inputs
6420 {
6421 vector<typename In::In0> in0;
6422 vector<typename In::In1> in1;
6423 vector<typename In::In2> in2;
6424 vector<typename In::In3> in3;
6425 };
6426
6427 template <typename Out>
6428 struct Outputs
6429 {
Outputsvkt::shaderexecutor::Outputs6430 Outputs(size_t size) : out0(size), out1(size)
6431 {
6432 }
6433
6434 vector<typename Out::Out0> out0;
6435 vector<typename Out::Out1> out1;
6436 };
6437
6438 template <typename In, typename Out>
6439 struct Variables
6440 {
6441 VariableP<typename In::In0> in0;
6442 VariableP<typename In::In1> in1;
6443 VariableP<typename In::In2> in2;
6444 VariableP<typename In::In3> in3;
6445 VariableP<typename Out::Out0> out0;
6446 VariableP<typename Out::Out1> out1;
6447 };
6448
6449 template <typename In>
6450 struct Samplings
6451 {
Samplingsvkt::shaderexecutor::Samplings6452 Samplings(const Sampling<typename In::In0> &in0_, const Sampling<typename In::In1> &in1_,
6453 const Sampling<typename In::In2> &in2_, const Sampling<typename In::In3> &in3_)
6454 : in0(in0_)
6455 , in1(in1_)
6456 , in2(in2_)
6457 , in3(in3_)
6458 {
6459 }
6460
6461 const Sampling<typename In::In0> &in0;
6462 const Sampling<typename In::In1> &in1;
6463 const Sampling<typename In::In2> &in2;
6464 const Sampling<typename In::In3> &in3;
6465 };
6466
6467 template <typename In>
6468 struct DefaultSamplings : Samplings<In>
6469 {
DefaultSamplingsvkt::shaderexecutor::DefaultSamplings6470 DefaultSamplings(void)
6471 : Samplings<In>(instance<DefaultSampling<typename In::In0>>(), instance<DefaultSampling<typename In::In1>>(),
6472 instance<DefaultSampling<typename In::In2>>(), instance<DefaultSampling<typename In::In3>>())
6473 {
6474 }
6475 };
6476
6477 template <typename In, typename Out>
6478 class BuiltinPrecisionCaseTestInstance : public TestInstance
6479 {
6480 public:
BuiltinPrecisionCaseTestInstance(Context & context,const CaseContext caseCtx,const ShaderSpec & shaderSpec,const Variables<In,Out> variables,const Samplings<In> & samplings,const StatementP stmt,bool modularOp=false)6481 BuiltinPrecisionCaseTestInstance(Context &context, const CaseContext caseCtx, const ShaderSpec &shaderSpec,
6482 const Variables<In, Out> variables, const Samplings<In> &samplings,
6483 const StatementP stmt, bool modularOp = false)
6484 : TestInstance(context)
6485 , m_caseCtx(caseCtx)
6486 , m_variables(variables)
6487 , m_samplings(samplings)
6488 , m_stmt(stmt)
6489 , m_executor(createExecutor(context, caseCtx.shaderType, shaderSpec))
6490 , m_modularOp(modularOp)
6491 {
6492 }
6493 virtual tcu::TestStatus iterate(void);
6494
6495 protected:
6496 CaseContext m_caseCtx;
6497 Variables<In, Out> m_variables;
6498 const Samplings<In> &m_samplings;
6499 StatementP m_stmt;
6500 de::UniquePtr<ShaderExecutor> m_executor;
6501 bool m_modularOp;
6502 };
6503
6504 template <class In, class Out>
iterate(void)6505 tcu::TestStatus BuiltinPrecisionCaseTestInstance<In, Out>::iterate(void)
6506 {
6507 typedef typename In::In0 In0;
6508 typedef typename In::In1 In1;
6509 typedef typename In::In2 In2;
6510 typedef typename In::In3 In3;
6511 typedef typename Out::Out0 Out0;
6512 typedef typename Out::Out1 Out1;
6513
6514 areFeaturesSupported(m_context, m_caseCtx.precisionTestFeatures);
6515 Inputs<In> inputs =
6516 generateInputs(m_samplings, m_caseCtx.floatFormat, m_caseCtx.precision, m_caseCtx.numRandoms,
6517 0xdeadbeefu + m_caseCtx.testContext.getCommandLine().getBaseSeed(), m_caseCtx.inputRange);
6518 const FloatFormat &fmt = m_caseCtx.floatFormat;
6519 const int inCount = numInputs<In>();
6520 const int outCount = numOutputs<Out>();
6521 const size_t numValues = (inCount > 0) ? inputs.in0.size() : 1;
6522 Outputs<Out> outputs(numValues);
6523 const FloatFormat highpFmt = m_caseCtx.highpFormat;
6524 const int maxMsgs = 100;
6525 int numErrors = 0;
6526 Environment env; // Hoisted out of the inner loop for optimization.
6527 ResultCollector status;
6528 TestLog &testLog = m_context.getTestContext().getLog();
6529
6530 // Module operations need exactly two inputs and have exactly one output.
6531 if (m_modularOp)
6532 {
6533 DE_ASSERT(inCount == 2);
6534 DE_ASSERT(outCount == 1);
6535 }
6536
6537 const void *inputArr[] = {
6538 inputs.in0.data(),
6539 inputs.in1.data(),
6540 inputs.in2.data(),
6541 inputs.in3.data(),
6542 };
6543 void *outputArr[] = {
6544 outputs.out0.data(),
6545 outputs.out1.data(),
6546 };
6547
6548 // Print out the statement and its definitions
6549 testLog << TestLog::Message << "Statement: " << m_stmt << TestLog::EndMessage;
6550 {
6551 ostringstream oss;
6552 FuncSet funcs;
6553
6554 m_stmt->getUsedFuncs(funcs);
6555 for (FuncSet::const_iterator it = funcs.begin(); it != funcs.end(); ++it)
6556 {
6557 (*it)->printDefinition(oss);
6558 }
6559 if (!funcs.empty())
6560 testLog << TestLog::Message << "Reference definitions:\n" << oss.str() << TestLog::EndMessage;
6561 }
6562 switch (inCount)
6563 {
6564 case 4:
6565 DE_ASSERT(inputs.in3.size() == numValues);
6566 // Fallthrough
6567 case 3:
6568 DE_ASSERT(inputs.in2.size() == numValues);
6569 // Fallthrough
6570 case 2:
6571 DE_ASSERT(inputs.in1.size() == numValues);
6572 // Fallthrough
6573 case 1:
6574 DE_ASSERT(inputs.in0.size() == numValues);
6575 // Fallthrough
6576 default:
6577 break;
6578 }
6579
6580 m_executor->execute(int(numValues), inputArr, outputArr);
6581
6582 // Initialize environment with unused values so we don't need to bind in inner loop.
6583 {
6584 const typename Traits<In0>::IVal in0;
6585 const typename Traits<In1>::IVal in1;
6586 const typename Traits<In2>::IVal in2;
6587 const typename Traits<In3>::IVal in3;
6588 const typename Traits<Out0>::IVal reference0;
6589 const typename Traits<Out1>::IVal reference1;
6590
6591 env.bind(*m_variables.in0, in0);
6592 env.bind(*m_variables.in1, in1);
6593 env.bind(*m_variables.in2, in2);
6594 env.bind(*m_variables.in3, in3);
6595 env.bind(*m_variables.out0, reference0);
6596 env.bind(*m_variables.out1, reference1);
6597 }
6598
6599 // For each input tuple, compute output reference interval and compare
6600 // shader output to the reference.
6601 for (size_t valueNdx = 0; valueNdx < numValues; valueNdx++)
6602 {
6603 bool result = true;
6604 const bool isInput16Bit = m_executor->areInputs16Bit();
6605 const bool isInput64Bit = m_executor->areInputs64Bit();
6606
6607 DE_ASSERT(!(isInput16Bit && isInput64Bit));
6608
6609 typename Traits<Out0>::IVal reference0;
6610 typename Traits<Out1>::IVal reference1;
6611
6612 if (valueNdx % (size_t)TOUCH_WATCHDOG_VALUE_FREQUENCY == 0)
6613 m_context.getTestContext().touchWatchdog();
6614
6615 env.lookup(*m_variables.in0) = convert<In0>(fmt, round(fmt, inputs.in0[valueNdx]));
6616 env.lookup(*m_variables.in1) = convert<In1>(fmt, round(fmt, inputs.in1[valueNdx]));
6617 env.lookup(*m_variables.in2) = convert<In2>(fmt, round(fmt, inputs.in2[valueNdx]));
6618 env.lookup(*m_variables.in3) = convert<In3>(fmt, round(fmt, inputs.in3[valueNdx]));
6619
6620 {
6621 EvalContext ctx(fmt, m_caseCtx.precision, env, 0);
6622 m_stmt->execute(ctx);
6623
6624 switch (outCount)
6625 {
6626 case 2:
6627 reference1 = convert<Out1>(highpFmt, env.lookup(*m_variables.out1));
6628 if (!status.check(contains(reference1, outputs.out1[valueNdx], m_caseCtx.isPackFloat16b),
6629 "Shader output 1 is outside acceptable range"))
6630 result = false;
6631 // Fallthrough
6632 case 1:
6633 {
6634 // Pass b from mod(a, b) if we are in the modulo operation.
6635 const tcu::Maybe<In1> modularDivisor = (m_modularOp ? tcu::just(inputs.in1[valueNdx]) : tcu::Nothing);
6636
6637 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6638 if (!status.check(
6639 contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor),
6640 "Shader output 0 is outside acceptable range"))
6641 {
6642 m_stmt->failed(ctx);
6643 reference0 = convert<Out0>(highpFmt, env.lookup(*m_variables.out0));
6644 if (!status.check(
6645 contains(reference0, outputs.out0[valueNdx], m_caseCtx.isPackFloat16b, modularDivisor),
6646 "Shader output 0 is outside acceptable range"))
6647 result = false;
6648 }
6649 }
6650 // Fallthrough
6651 default:
6652 break;
6653 }
6654 }
6655 if (!result)
6656 ++numErrors;
6657
6658 if ((!result && numErrors <= maxMsgs) || GLS_LOG_ALL_RESULTS)
6659 {
6660 MessageBuilder builder = testLog.message();
6661
6662 builder << (result ? "Passed" : "Failed") << " sample:\n";
6663
6664 if (inCount > 0)
6665 {
6666 builder << "\t" << m_variables.in0->getName() << " = "
6667 << (isInput64Bit ? value64ToString(highpFmt, inputs.in0[valueNdx]) :
6668 (isInput16Bit ? value16ToString(highpFmt, inputs.in0[valueNdx]) :
6669 value32ToString(highpFmt, inputs.in0[valueNdx])))
6670 << "\n";
6671 }
6672
6673 if (inCount > 1)
6674 {
6675 builder << "\t" << m_variables.in1->getName() << " = "
6676 << (isInput64Bit ? value64ToString(highpFmt, inputs.in1[valueNdx]) :
6677 (isInput16Bit ? value16ToString(highpFmt, inputs.in1[valueNdx]) :
6678 value32ToString(highpFmt, inputs.in1[valueNdx])))
6679 << "\n";
6680 }
6681
6682 if (inCount > 2)
6683 {
6684 builder << "\t" << m_variables.in2->getName() << " = "
6685 << (isInput64Bit ? value64ToString(highpFmt, inputs.in2[valueNdx]) :
6686 (isInput16Bit ? value16ToString(highpFmt, inputs.in2[valueNdx]) :
6687 value32ToString(highpFmt, inputs.in2[valueNdx])))
6688 << "\n";
6689 }
6690
6691 if (inCount > 3)
6692 {
6693 builder << "\t" << m_variables.in3->getName() << " = "
6694 << (isInput64Bit ? value64ToString(highpFmt, inputs.in3[valueNdx]) :
6695 (isInput16Bit ? value16ToString(highpFmt, inputs.in3[valueNdx]) :
6696 value32ToString(highpFmt, inputs.in3[valueNdx])))
6697 << "\n";
6698 }
6699
6700 if (outCount > 0)
6701 {
6702 if (m_executor->spirvCase() == SPIRV_CASETYPE_COMPARE)
6703 {
6704 builder << "Output:\n"
6705 << comparisonMessage(outputs.out0[valueNdx]) << "Expected result:\n"
6706 << comparisonMessageInterval<typename Out::Out0>(reference0) << "\n";
6707 }
6708 else
6709 {
6710 builder << "\t" << m_variables.out0->getName() << " = "
6711 << (m_executor->isOutput64Bit(0u) ?
6712 value64ToString(highpFmt, outputs.out0[valueNdx]) :
6713 (m_executor->isOutput16Bit(0u) || m_caseCtx.isPackFloat16b ?
6714 value16ToString(highpFmt, outputs.out0[valueNdx]) :
6715 value32ToString(highpFmt, outputs.out0[valueNdx])))
6716 << "\n"
6717 << "\tExpected range: " << intervalToString<typename Out::Out0>(highpFmt, reference0)
6718 << "\n";
6719 }
6720 }
6721
6722 if (outCount > 1)
6723 {
6724 builder << "\t" << m_variables.out1->getName() << " = "
6725 << (m_executor->isOutput64Bit(1u) ? value64ToString(highpFmt, outputs.out1[valueNdx]) :
6726 (m_executor->isOutput16Bit(1u) || m_caseCtx.isPackFloat16b ?
6727 value16ToString(highpFmt, outputs.out1[valueNdx]) :
6728 value32ToString(highpFmt, outputs.out1[valueNdx])))
6729 << "\n"
6730 << "\tExpected range: " << intervalToString<typename Out::Out1>(highpFmt, reference1) << "\n";
6731 }
6732
6733 builder << TestLog::EndMessage;
6734 }
6735 }
6736
6737 if (numErrors > maxMsgs)
6738 {
6739 testLog << TestLog::Message << "(Skipped " << (numErrors - maxMsgs) << " messages.)" << TestLog::EndMessage;
6740 }
6741
6742 if (numErrors == 0)
6743 {
6744 testLog << TestLog::Message << "All " << numValues << " inputs passed." << TestLog::EndMessage;
6745 }
6746 else
6747 {
6748 testLog << TestLog::Message << numErrors << "/" << numValues << " inputs failed." << TestLog::EndMessage;
6749 }
6750
6751 if (numErrors)
6752 return tcu::TestStatus::fail(de::toString(numErrors) + string(" test failed. Check log for the details"));
6753 else
6754 return tcu::TestStatus::pass("Pass");
6755 }
6756
6757 class PrecisionCase : public TestCase
6758 {
6759 protected:
PrecisionCase(const CaseContext & context,const string & name,const Interval & inputRange,const string & extension="")6760 PrecisionCase(const CaseContext &context, const string &name, const Interval &inputRange,
6761 const string &extension = "")
6762 : TestCase(context.testContext, name.c_str())
6763 , m_ctx(context)
6764 , m_extension(extension)
6765 {
6766 m_ctx.inputRange = inputRange;
6767 m_spec.packFloat16Bit = context.isPackFloat16b;
6768 }
6769
initPrograms(vk::SourceCollections & programCollection) const6770 virtual void initPrograms(vk::SourceCollections &programCollection) const
6771 {
6772 generateSources(m_ctx.shaderType, m_spec, programCollection);
6773 }
6774
getFormat(void) const6775 const FloatFormat &getFormat(void) const
6776 {
6777 return m_ctx.floatFormat;
6778 }
6779
6780 template <typename In, typename Out>
6781 void testStatement(const Variables<In, Out> &variables, const Statement &stmt, SpirVCaseT spirvCase);
6782
6783 template <typename T>
makeSymbol(const Variable<T> & variable)6784 Symbol makeSymbol(const Variable<T> &variable)
6785 {
6786 return Symbol(variable.getName(), getVarTypeOf<T>(m_ctx.precision));
6787 }
6788
6789 CaseContext m_ctx;
6790 const string m_extension;
6791 ShaderSpec m_spec;
6792 };
6793
6794 template <typename In, typename Out>
testStatement(const Variables<In,Out> & variables,const Statement & stmt,SpirVCaseT spirvCase)6795 void PrecisionCase::testStatement(const Variables<In, Out> &variables, const Statement &stmt, SpirVCaseT spirvCase)
6796 {
6797 const int inCount = numInputs<In>();
6798 const int outCount = numOutputs<Out>();
6799 Environment env; // Hoisted out of the inner loop for optimization.
6800
6801 // Initialize ShaderSpec from precision, variables and statement.
6802 if (m_ctx.precision != glu::PRECISION_LAST)
6803 {
6804 ostringstream os;
6805 os << "precision " << glu::getPrecisionName(m_ctx.precision) << " float;\n";
6806 m_spec.globalDeclarations = os.str();
6807 }
6808
6809 if (!m_extension.empty())
6810 m_spec.globalDeclarations = "#extension " + m_extension + " : require\n";
6811
6812 m_spec.inputs.resize(inCount);
6813
6814 switch (inCount)
6815 {
6816 case 4:
6817 m_spec.inputs[3] = makeSymbol(*variables.in3);
6818 // Fallthrough
6819 case 3:
6820 m_spec.inputs[2] = makeSymbol(*variables.in2);
6821 // Fallthrough
6822 case 2:
6823 m_spec.inputs[1] = makeSymbol(*variables.in1);
6824 // Fallthrough
6825 case 1:
6826 m_spec.inputs[0] = makeSymbol(*variables.in0);
6827 // Fallthrough
6828 default:
6829 break;
6830 }
6831
6832 bool inputs16Bit = false;
6833 for (vector<Symbol>::const_iterator symIter = m_spec.inputs.begin(); symIter != m_spec.inputs.end(); ++symIter)
6834 inputs16Bit = inputs16Bit || glu::isDataTypeFloat16OrVec(symIter->varType.getBasicType());
6835
6836 if (inputs16Bit || m_spec.packFloat16Bit)
6837 m_spec.globalDeclarations += "#extension GL_EXT_shader_explicit_arithmetic_types: require\n";
6838
6839 m_spec.outputs.resize(outCount);
6840
6841 switch (outCount)
6842 {
6843 case 2:
6844 m_spec.outputs[1] = makeSymbol(*variables.out1);
6845 // Fallthrough
6846 case 1:
6847 m_spec.outputs[0] = makeSymbol(*variables.out0);
6848 // Fallthrough
6849 default:
6850 break;
6851 }
6852
6853 m_spec.source = de::toString(stmt);
6854 m_spec.spirvCase = spirvCase;
6855 }
6856
// Ordering functor for test inputs.
// The primary template orders two values with the type's own operator<.
template <typename T>
struct InputLess
{
    bool operator()(const T &val1, const T &val2) const
    {
        const bool firstIsSmaller = val1 < val2;
        return firstIsSmaller;
    }
};
6865
6866 template <typename T>
inputLess(const T & val1,const T & val2)6867 bool inputLess(const T &val1, const T &val2)
6868 {
6869 return InputLess<T>()(val1, val2);
6870 }
6871
6872 template <>
6873 struct InputLess<float>
6874 {
operator ()vkt::shaderexecutor::InputLess6875 bool operator()(const float &val1, const float &val2) const
6876 {
6877 if (deIsNaN(val1))
6878 return false;
6879 if (deIsNaN(val2))
6880 return true;
6881 return val1 < val2;
6882 }
6883 };
6884
6885 template <typename T, int Size>
6886 struct InputLess<Vector<T, Size>>
6887 {
operator ()vkt::shaderexecutor::InputLess6888 bool operator()(const Vector<T, Size> &vec1, const Vector<T, Size> &vec2) const
6889 {
6890 for (int ndx = 0; ndx < Size; ++ndx)
6891 {
6892 if (inputLess(vec1[ndx], vec2[ndx]))
6893 return true;
6894 if (inputLess(vec2[ndx], vec1[ndx]))
6895 return false;
6896 }
6897
6898 return false;
6899 }
6900 };
6901
6902 template <typename T, int Rows, int Cols>
6903 struct InputLess<Matrix<T, Rows, Cols>>
6904 {
operator ()vkt::shaderexecutor::InputLess6905 bool operator()(const Matrix<T, Rows, Cols> &mat1, const Matrix<T, Rows, Cols> &mat2) const
6906 {
6907 for (int col = 0; col < Cols; ++col)
6908 {
6909 if (inputLess(mat1[col], mat2[col]))
6910 return true;
6911 if (inputLess(mat2[col], mat1[col]))
6912 return false;
6913 }
6914
6915 return false;
6916 }
6917 };
6918
6919 template <typename In>
6920 struct InTuple : public Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>
6921 {
InTuplevkt::shaderexecutor::InTuple6922 InTuple(const typename In::In0 &in0, const typename In::In1 &in1, const typename In::In2 &in2,
6923 const typename In::In3 &in3)
6924 : Tuple4<typename In::In0, typename In::In1, typename In::In2, typename In::In3>(in0, in1, in2, in3)
6925 {
6926 }
6927 };
6928
6929 template <typename In>
6930 struct InputLess<InTuple<In>>
6931 {
operator ()vkt::shaderexecutor::InputLess6932 bool operator()(const InTuple<In> &in1, const InTuple<In> &in2) const
6933 {
6934 if (inputLess(in1.a, in2.a))
6935 return true;
6936 if (inputLess(in2.a, in1.a))
6937 return false;
6938 if (inputLess(in1.b, in2.b))
6939 return true;
6940 if (inputLess(in2.b, in1.b))
6941 return false;
6942 if (inputLess(in1.c, in2.c))
6943 return true;
6944 if (inputLess(in2.c, in1.c))
6945 return false;
6946 if (inputLess(in1.d, in2.d))
6947 return true;
6948 return false;
6949 }
6950 };
6951
6952 template <typename In>
generateInputs(const Samplings<In> & samplings,const FloatFormat & floatFormat,Precision intPrecision,size_t numSamples,uint32_t seed,const Interval & inputRange)6953 Inputs<In> generateInputs(const Samplings<In> &samplings, const FloatFormat &floatFormat, Precision intPrecision,
6954 size_t numSamples, uint32_t seed, const Interval &inputRange)
6955 {
6956 Random rnd(seed);
6957 Inputs<In> ret;
6958 Inputs<In> fixedInputs;
6959 set<InTuple<In>, InputLess<InTuple<In>>> seenInputs;
6960
6961 samplings.in0.genFixeds(floatFormat, intPrecision, fixedInputs.in0, inputRange);
6962 samplings.in1.genFixeds(floatFormat, intPrecision, fixedInputs.in1, inputRange);
6963 samplings.in2.genFixeds(floatFormat, intPrecision, fixedInputs.in2, inputRange);
6964 samplings.in3.genFixeds(floatFormat, intPrecision, fixedInputs.in3, inputRange);
6965
6966 for (size_t ndx0 = 0; ndx0 < fixedInputs.in0.size(); ++ndx0)
6967 {
6968 for (size_t ndx1 = 0; ndx1 < fixedInputs.in1.size(); ++ndx1)
6969 {
6970 for (size_t ndx2 = 0; ndx2 < fixedInputs.in2.size(); ++ndx2)
6971 {
6972 for (size_t ndx3 = 0; ndx3 < fixedInputs.in3.size(); ++ndx3)
6973 {
6974 const InTuple<In> tuple(fixedInputs.in0[ndx0], fixedInputs.in1[ndx1], fixedInputs.in2[ndx2],
6975 fixedInputs.in3[ndx3]);
6976
6977 seenInputs.insert(tuple);
6978 ret.in0.push_back(tuple.a);
6979 ret.in1.push_back(tuple.b);
6980 ret.in2.push_back(tuple.c);
6981 ret.in3.push_back(tuple.d);
6982 }
6983 }
6984 }
6985 }
6986
6987 for (size_t ndx = 0; ndx < numSamples; ++ndx)
6988 {
6989 const typename In::In0 in0 = samplings.in0.genRandom(floatFormat, intPrecision, rnd, inputRange);
6990 const typename In::In1 in1 = samplings.in1.genRandom(floatFormat, intPrecision, rnd, inputRange);
6991 const typename In::In2 in2 = samplings.in2.genRandom(floatFormat, intPrecision, rnd, inputRange);
6992 const typename In::In3 in3 = samplings.in3.genRandom(floatFormat, intPrecision, rnd, inputRange);
6993 const InTuple<In> tuple(in0, in1, in2, in3);
6994
6995 if (de::contains(seenInputs, tuple))
6996 continue;
6997
6998 seenInputs.insert(tuple);
6999 ret.in0.push_back(in0);
7000 ret.in1.push_back(in1);
7001 ret.in2.push_back(in2);
7002 ret.in3.push_back(in3);
7003 }
7004
7005 return ret;
7006 }
7007
7008 class FuncCaseBase : public PrecisionCase
7009 {
7010 protected:
FuncCaseBase(const CaseContext & context,const string & name,const FuncBase & func)7011 FuncCaseBase(const CaseContext &context, const string &name, const FuncBase &func)
7012 : PrecisionCase(context, name,
7013 func.getInputRange(!context.isFloat64b &&
7014 (context.precision == glu::PRECISION_LAST || context.isPackFloat16b)),
7015 func.getRequiredExtension())
7016 {
7017 }
7018
7019 StatementP m_stmt;
7020 };
7021
7022 template <typename Sig>
7023 class FuncCase : public FuncCaseBase
7024 {
7025 public:
7026 typedef Func<Sig> CaseFunc;
7027 typedef typename Sig::Ret Ret;
7028 typedef typename Sig::Arg0 Arg0;
7029 typedef typename Sig::Arg1 Arg1;
7030 typedef typename Sig::Arg2 Arg2;
7031 typedef typename Sig::Arg3 Arg3;
7032 typedef InTypes<Arg0, Arg1, Arg2, Arg3> In;
7033 typedef OutTypes<Ret> Out;
7034
FuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)7035 FuncCase(const CaseContext &context, const string &name, const CaseFunc &func, bool modularOp = false)
7036 : FuncCaseBase(context, name, func)
7037 , m_func(func)
7038 , m_modularOp(modularOp)
7039 {
7040 buildTest();
7041 }
7042
createInstance(Context & context) const7043 virtual TestInstance *createInstance(Context &context) const
7044 {
7045 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(),
7046 m_stmt, m_modularOp);
7047 }
7048
7049 protected:
7050 void buildTest(void);
getSamplings(void) const7051 virtual const Samplings<In> &getSamplings(void) const
7052 {
7053 return instance<DefaultSamplings<In>>();
7054 }
7055
7056 private:
7057 const CaseFunc &m_func;
7058 Variables<In, Out> m_variables;
7059 bool m_modularOp;
7060 };
7061
7062 template <typename Sig>
buildTest(void)7063 void FuncCase<Sig>::buildTest(void)
7064 {
7065 m_variables.out0 = variable<Ret>("out0");
7066 m_variables.out1 = variable<Void>("out1");
7067 m_variables.in0 = variable<Arg0>("in0");
7068 m_variables.in1 = variable<Arg1>("in1");
7069 m_variables.in2 = variable<Arg2>("in2");
7070 m_variables.in3 = variable<Arg3>("in3");
7071
7072 {
7073 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.in1, m_variables.in2, m_variables.in3);
7074 m_stmt = variableAssignment(m_variables.out0, expr);
7075
7076 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
7077 }
7078 }
7079
7080 template <typename Sig>
7081 class InOutFuncCase : public FuncCaseBase
7082 {
7083 public:
7084 typedef Func<Sig> CaseFunc;
7085 typedef typename Sig::Ret Ret;
7086 typedef typename Sig::Arg0 Arg0;
7087 typedef typename Sig::Arg1 Arg1;
7088 typedef typename Sig::Arg2 Arg2;
7089 typedef typename Sig::Arg3 Arg3;
7090 typedef InTypes<Arg0, Arg2, Arg3> In;
7091 typedef OutTypes<Ret, Arg1> Out;
7092
InOutFuncCase(const CaseContext & context,const string & name,const CaseFunc & func,bool modularOp=false)7093 InOutFuncCase(const CaseContext &context, const string &name, const CaseFunc &func, bool modularOp = false)
7094 : FuncCaseBase(context, name, func)
7095 , m_func(func)
7096 , m_modularOp(modularOp)
7097 {
7098 buildTest();
7099 }
createInstance(Context & context) const7100 virtual TestInstance *createInstance(Context &context) const
7101 {
7102 return new BuiltinPrecisionCaseTestInstance<In, Out>(context, m_ctx, m_spec, m_variables, getSamplings(),
7103 m_stmt, m_modularOp);
7104 }
7105
7106 protected:
7107 void buildTest(void);
getSamplings(void) const7108 virtual const Samplings<In> &getSamplings(void) const
7109 {
7110 return instance<DefaultSamplings<In>>();
7111 }
7112
7113 private:
7114 const CaseFunc &m_func;
7115 Variables<In, Out> m_variables;
7116 bool m_modularOp;
7117 };
7118
7119 template <typename Sig>
buildTest(void)7120 void InOutFuncCase<Sig>::buildTest(void)
7121 {
7122 m_variables.out0 = variable<Ret>("out0");
7123 m_variables.out1 = variable<Arg1>("out1");
7124 m_variables.in0 = variable<Arg0>("in0");
7125 m_variables.in1 = variable<Arg2>("in1");
7126 m_variables.in2 = variable<Arg3>("in2");
7127 m_variables.in3 = variable<Void>("in3");
7128
7129 {
7130 ExprP<Ret> expr = applyVar(m_func, m_variables.in0, m_variables.out1, m_variables.in1, m_variables.in2);
7131 m_stmt = variableAssignment(m_variables.out0, expr);
7132
7133 this->testStatement(m_variables, *m_stmt, m_func.getSpirvCase());
7134 }
7135 }
7136
7137 template <typename Sig>
createFuncCase(const CaseContext & context,const string & name,const Func<Sig> & func,bool modularOp=false)7138 PrecisionCase *createFuncCase(const CaseContext &context, const string &name, const Func<Sig> &func,
7139 bool modularOp = false)
7140 {
7141 switch (func.getOutParamIndex())
7142 {
7143 case -1:
7144 return new FuncCase<Sig>(context, name, func, modularOp);
7145 case 1:
7146 return new InOutFuncCase<Sig>(context, name, func, modularOp);
7147 default:
7148 DE_FATAL("Impossible");
7149 }
7150 return DE_NULL;
7151 }
7152
7153 class CaseFactory
7154 {
7155 public:
~CaseFactory(void)7156 virtual ~CaseFactory(void)
7157 {
7158 }
7159 virtual MovePtr<TestNode> createCase(const CaseContext &ctx) const = 0;
7160 virtual string getName(void) const = 0;
7161 virtual string getDesc(void) const = 0;
7162 };
7163
7164 class FuncCaseFactory : public CaseFactory
7165 {
7166 public:
7167 virtual const FuncBase &getFunc(void) const = 0;
getName(void) const7168 string getName(void) const
7169 {
7170 return de::toLower(getFunc().getName());
7171 }
getDesc(void) const7172 string getDesc(void) const
7173 {
7174 return "Function '" + getFunc().getName() + "'";
7175 }
7176 };
7177
7178 template <typename Sig>
7179 class GenFuncCaseFactory : public CaseFactory
7180 {
7181 public:
GenFuncCaseFactory(const GenFuncs<Sig> & funcs,const string & name,bool modularOp=false)7182 GenFuncCaseFactory(const GenFuncs<Sig> &funcs, const string &name, bool modularOp = false)
7183 : m_funcs(funcs)
7184 , m_name(de::toLower(name))
7185 , m_modularOp(modularOp)
7186 {
7187 }
7188
createCase(const CaseContext & ctx) const7189 MovePtr<TestNode> createCase(const CaseContext &ctx) const
7190 {
7191 TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7192
7193 group->addChild(createFuncCase(ctx, "scalar", m_funcs.func, m_modularOp));
7194 group->addChild(createFuncCase(ctx, "vec2", m_funcs.func2, m_modularOp));
7195 group->addChild(createFuncCase(ctx, "vec3", m_funcs.func3, m_modularOp));
7196 group->addChild(createFuncCase(ctx, "vec4", m_funcs.func4, m_modularOp));
7197 return MovePtr<TestNode>(group);
7198 }
7199
getName(void) const7200 string getName(void) const
7201 {
7202 return m_name;
7203 }
getDesc(void) const7204 string getDesc(void) const
7205 {
7206 return "Function '" + m_funcs.func.getName() + "'";
7207 }
7208
7209 private:
7210 const GenFuncs<Sig> m_funcs;
7211 string m_name;
7212 bool m_modularOp;
7213 };
7214
7215 template <template <int, class> class GenF, typename T>
7216 class TemplateFuncCaseFactory : public FuncCaseFactory
7217 {
7218 public:
createCase(const CaseContext & ctx) const7219 MovePtr<TestNode> createCase(const CaseContext &ctx) const
7220 {
7221 TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7222
7223 group->addChild(createFuncCase(ctx, "scalar", instance<GenF<1, T>>()));
7224 group->addChild(createFuncCase(ctx, "vec2", instance<GenF<2, T>>()));
7225 group->addChild(createFuncCase(ctx, "vec3", instance<GenF<3, T>>()));
7226 group->addChild(createFuncCase(ctx, "vec4", instance<GenF<4, T>>()));
7227
7228 return MovePtr<TestNode>(group);
7229 }
7230
getFunc(void) const7231 const FuncBase &getFunc(void) const
7232 {
7233 return instance<GenF<1, T>>();
7234 }
7235 };
7236
7237 #ifndef CTS_USES_VULKANSC
7238 template <template <int> class GenF>
7239 class SquareMatrixFuncCaseFactory : public FuncCaseFactory
7240 {
7241 public:
createCase(const CaseContext & ctx) const7242 MovePtr<TestNode> createCase(const CaseContext &ctx) const
7243 {
7244 TestCaseGroup *group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7245
7246 group->addChild(createFuncCase(ctx, "mat2", instance<GenF<2>>()));
7247
7248 // There is no defined precision for mediump/RelaxedPrecision in Vulkan
7249 if (ctx.name != "mediump")
7250 {
7251 static const char dataDir[] = "builtin/precision/square_matrix";
7252 std::string fileName = getFunc().getName() + "_" + ctx.name;
7253 std::vector<std::string> requirements;
7254
7255 if (ctx.name == "compute")
7256 {
7257 if (ctx.isFloat64b)
7258 {
7259 requirements.push_back("Features.shaderFloat64");
7260 fileName += "_fp64";
7261 }
7262 else
7263 {
7264 requirements.push_back("Float16Int8Features.shaderFloat16");
7265 requirements.push_back("VK_KHR_16bit_storage");
7266 requirements.push_back("VK_KHR_storage_buffer_storage_class");
7267 fileName += "_fp16";
7268
7269 if (ctx.isPackFloat16b == true)
7270 {
7271 fileName += "_32bit";
7272 }
7273 else
7274 {
7275 requirements.push_back("Storage16BitFeatures.storageBuffer16BitAccess");
7276 }
7277 }
7278 }
7279
7280 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat3", "Square matrix 3x3 precision tests",
7281 dataDir, fileName + "_mat_3x3.amber", requirements));
7282 group->addChild(cts_amber::createAmberTestCase(ctx.testContext, "mat4", "Square matrix 4x4 precision tests",
7283 dataDir, fileName + "_mat_4x4.amber", requirements));
7284 }
7285
7286 return MovePtr<TestNode>(group);
7287 }
7288
getFunc(void) const7289 const FuncBase &getFunc(void) const
7290 {
7291 return instance<GenF<2>>();
7292 }
7293 };
7294 #endif // CTS_USES_VULKANSC
7295
7296 template <template <int, int, class> class GenF, typename T>
7297 class MatrixFuncCaseFactory : public FuncCaseFactory
7298 {
7299 public:
createCase(const CaseContext & ctx) const7300 MovePtr<TestNode> createCase(const CaseContext &ctx) const
7301 {
7302 TestCaseGroup *const group = new TestCaseGroup(ctx.testContext, ctx.name.c_str());
7303
7304 this->addCase<2, 2>(ctx, group);
7305 this->addCase<3, 2>(ctx, group);
7306 this->addCase<4, 2>(ctx, group);
7307 this->addCase<2, 3>(ctx, group);
7308 this->addCase<3, 3>(ctx, group);
7309 this->addCase<4, 3>(ctx, group);
7310 this->addCase<2, 4>(ctx, group);
7311 this->addCase<3, 4>(ctx, group);
7312 this->addCase<4, 4>(ctx, group);
7313
7314 return MovePtr<TestNode>(group);
7315 }
7316
getFunc(void) const7317 const FuncBase &getFunc(void) const
7318 {
7319 return instance<GenF<2, 2, T>>();
7320 }
7321
7322 private:
7323 template <int Rows, int Cols>
addCase(const CaseContext & ctx,TestCaseGroup * group) const7324 void addCase(const CaseContext &ctx, TestCaseGroup *group) const
7325 {
7326 const char *const name = dataTypeNameOf<Matrix<float, Rows, Cols>>();
7327 group->addChild(createFuncCase(ctx, name, instance<GenF<Rows, Cols, T>>()));
7328 }
7329 };
7330
7331 template <typename Sig>
7332 class SimpleFuncCaseFactory : public CaseFactory
7333 {
7334 public:
SimpleFuncCaseFactory(const Func<Sig> & func)7335 SimpleFuncCaseFactory(const Func<Sig> &func) : m_func(func)
7336 {
7337 }
7338
createCase(const CaseContext & ctx) const7339 MovePtr<TestNode> createCase(const CaseContext &ctx) const
7340 {
7341 return MovePtr<TestNode>(createFuncCase(ctx, ctx.name.c_str(), m_func));
7342 }
getName(void) const7343 string getName(void) const
7344 {
7345 return de::toLower(m_func.getName());
7346 }
getDesc(void) const7347 string getDesc(void) const
7348 {
7349 return "Function '" + getName() + "'";
7350 }
7351
7352 private:
7353 const Func<Sig> &m_func;
7354 };
7355
7356 template <typename F>
createSimpleFuncCaseFactory(void)7357 SharedPtr<SimpleFuncCaseFactory<typename F::Sig>> createSimpleFuncCaseFactory(void)
7358 {
7359 return SharedPtr<SimpleFuncCaseFactory<typename F::Sig>>(new SimpleFuncCaseFactory<typename F::Sig>(instance<F>()));
7360 }
7361
7362 class CaseFactories
7363 {
7364 public:
~CaseFactories(void)7365 virtual ~CaseFactories(void)
7366 {
7367 }
7368 virtual const std::vector<const CaseFactory *> getFactories(void) const = 0;
7369 };
7370
7371 class BuiltinFuncs : public CaseFactories
7372 {
7373 public:
getFactories(void) const7374 const vector<const CaseFactory *> getFactories(void) const
7375 {
7376 vector<const CaseFactory *> ret;
7377
7378 for (size_t ndx = 0; ndx < m_factories.size(); ++ndx)
7379 ret.push_back(m_factories[ndx].get());
7380
7381 return ret;
7382 }
7383
addFactory(SharedPtr<const CaseFactory> fact)7384 void addFactory(SharedPtr<const CaseFactory> fact)
7385 {
7386 m_factories.push_back(fact);
7387 }
7388
7389 private:
7390 vector<SharedPtr<const CaseFactory>> m_factories;
7391 };
7392
7393 template <typename F>
addScalarFactory(BuiltinFuncs & funcs,string name="",bool modularOp=false)7394 void addScalarFactory(BuiltinFuncs &funcs, string name = "", bool modularOp = false)
7395 {
7396 if (name.empty())
7397 name = instance<F>().getName();
7398
7399 funcs.addFactory(SharedPtr<const CaseFactory>(
7400 new GenFuncCaseFactory<typename F::Sig>(makeVectorizedFuncs<F>(), name, modularOp)));
7401 }
7402
createBuiltinCases()7403 MovePtr<const CaseFactories> createBuiltinCases()
7404 {
7405 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7406
7407 // Tests for ES3 builtins
7408 addScalarFactory<Comparison<Signature<int, float, float>>>(*funcs);
7409 addScalarFactory<Add<Signature<float, float, float>>>(*funcs);
7410 addScalarFactory<Sub<Signature<float, float, float>>>(*funcs);
7411 addScalarFactory<Mul<Signature<float, float, float>>>(*funcs);
7412 addScalarFactory<Div<Signature<float, float, float>>>(*funcs);
7413
7414 addScalarFactory<Radians>(*funcs);
7415 addScalarFactory<Degrees>(*funcs);
7416 addScalarFactory<Sin<Signature<float, float>>>(*funcs);
7417 addScalarFactory<Cos<Signature<float, float>>>(*funcs);
7418 addScalarFactory<Tan>(*funcs);
7419
7420 addScalarFactory<ASin>(*funcs);
7421 addScalarFactory<ACos>(*funcs);
7422 addScalarFactory<ATan2<Signature<float, float, float>>>(*funcs, "atan2");
7423 addScalarFactory<ATan<Signature<float, float>>>(*funcs);
7424 addScalarFactory<Sinh>(*funcs);
7425 addScalarFactory<Cosh>(*funcs);
7426 addScalarFactory<Tanh>(*funcs);
7427 addScalarFactory<ASinh>(*funcs);
7428 addScalarFactory<ACosh>(*funcs);
7429 addScalarFactory<ATanh>(*funcs);
7430
7431 addScalarFactory<Pow>(*funcs);
7432 addScalarFactory<Exp<Signature<float, float>>>(*funcs);
7433 addScalarFactory<Log<Signature<float, float>>>(*funcs);
7434 addScalarFactory<Exp2<Signature<float, float>>>(*funcs);
7435 addScalarFactory<Log2<Signature<float, float>>>(*funcs);
7436 addScalarFactory<Sqrt32Bit>(*funcs);
7437 addScalarFactory<InverseSqrt<Signature<float, float>>>(*funcs);
7438
7439 addScalarFactory<Abs<Signature<float, float>>>(*funcs);
7440 addScalarFactory<Sign<Signature<float, float>>>(*funcs);
7441 addScalarFactory<Floor32Bit>(*funcs);
7442 addScalarFactory<Trunc32Bit>(*funcs);
7443 addScalarFactory<Round<Signature<float, float>>>(*funcs);
7444 addScalarFactory<RoundEven<Signature<float, float>>>(*funcs);
7445 addScalarFactory<Ceil<Signature<float, float>>>(*funcs);
7446 addScalarFactory<Fract>(*funcs);
7447
7448 addScalarFactory<Mod32Bit>(*funcs, "mod", true);
7449 addScalarFactory<FRem32Bit>(*funcs);
7450
7451 addScalarFactory<Modf32Bit>(*funcs);
7452 addScalarFactory<ModfStruct32Bit>(*funcs);
7453 addScalarFactory<Min<Signature<float, float, float>>>(*funcs);
7454 addScalarFactory<Max<Signature<float, float, float>>>(*funcs);
7455 addScalarFactory<Clamp<Signature<float, float, float, float>>>(*funcs);
7456 addScalarFactory<Mix>(*funcs);
7457 addScalarFactory<Step<Signature<float, float, float>>>(*funcs);
7458 addScalarFactory<SmoothStep<Signature<float, float, float, float>>>(*funcs);
7459
7460 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, float>()));
7461 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, float>()));
7462 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, float>()));
7463 funcs->addFactory(createSimpleFuncCaseFactory<Cross>());
7464 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, float>()));
7465 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, float>()));
7466 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, float>()));
7467 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, float>()));
7468
7469 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, float>()));
7470 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, float>()));
7471 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, float>()));
7472 #ifndef CTS_USES_VULKANSC
7473 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant>()));
7474 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse>()));
7475 #endif // CTS_USES_VULKANSC
7476
7477 addScalarFactory<Frexp32Bit>(*funcs);
7478 addScalarFactory<FrexpStruct32Bit>(*funcs);
7479 addScalarFactory<LdExp<Signature<float, float, int>>>(*funcs);
7480 addScalarFactory<Fma<Signature<float, float, float, float>>>(*funcs);
7481
7482 return MovePtr<const CaseFactories>(funcs.release());
7483 }
7484
createBuiltinDoubleCases()7485 MovePtr<const CaseFactories> createBuiltinDoubleCases()
7486 {
7487 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7488
7489 // Tests for ES3 builtins
7490 addScalarFactory<Comparison<Signature<int, double, double>>>(*funcs);
7491 addScalarFactory<Add<Signature<double, double, double>>>(*funcs);
7492 addScalarFactory<Sub<Signature<double, double, double>>>(*funcs);
7493 addScalarFactory<Mul<Signature<double, double, double>>>(*funcs);
7494 addScalarFactory<Div<Signature<double, double, double>>>(*funcs);
7495
7496 // Radians, degrees, sin, cos, tan, asin, acos, atan, sinh, cosh, tanh, asinh, acosh, atanh, atan2, pow, exp, log, exp2 and log2
7497 // only work with 16-bit and 32-bit floating point types according to the spec.
7498
7499 addScalarFactory<Sqrt64Bit>(*funcs);
7500 addScalarFactory<InverseSqrt<Signature<double, double>>>(*funcs);
7501
7502 addScalarFactory<Abs<Signature<double, double>>>(*funcs);
7503 addScalarFactory<Sign<Signature<double, double>>>(*funcs);
7504 addScalarFactory<Floor64Bit>(*funcs);
7505 addScalarFactory<Trunc64Bit>(*funcs);
7506 addScalarFactory<Round<Signature<double, double>>>(*funcs);
7507 addScalarFactory<RoundEven<Signature<double, double>>>(*funcs);
7508 addScalarFactory<Ceil<Signature<double, double>>>(*funcs);
7509 addScalarFactory<Fract64Bit>(*funcs);
7510
7511 addScalarFactory<Mod64Bit>(*funcs, "mod", true);
7512 addScalarFactory<FRem64Bit>(*funcs);
7513
7514 addScalarFactory<Modf64Bit>(*funcs);
7515 addScalarFactory<ModfStruct64Bit>(*funcs);
7516 addScalarFactory<Min<Signature<double, double, double>>>(*funcs);
7517 addScalarFactory<Max<Signature<double, double, double>>>(*funcs);
7518 addScalarFactory<Clamp<Signature<double, double, double, double>>>(*funcs);
7519 addScalarFactory<Mix64Bit>(*funcs);
7520 addScalarFactory<Step<Signature<double, double, double>>>(*funcs);
7521 addScalarFactory<SmoothStep<Signature<double, double, double, double>>>(*funcs);
7522
7523 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, double>()));
7524 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, double>()));
7525 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, double>()));
7526 funcs->addFactory(createSimpleFuncCaseFactory<Cross64Bit>());
7527 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, double>()));
7528 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, double>()));
7529 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, double>()));
7530 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, double>()));
7531
7532 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<MatrixCompMult, double>()));
7533 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, double>()));
7534 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, double>()));
7535 #ifndef CTS_USES_VULKANSC
7536 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant64bit>()));
7537 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse64bit>()));
7538 #endif // CTS_USES_VULKANSC
7539
7540 addScalarFactory<Frexp64Bit>(*funcs);
7541 addScalarFactory<FrexpStruct64Bit>(*funcs);
7542 addScalarFactory<LdExp<Signature<double, double, int>>>(*funcs);
7543 addScalarFactory<Fma<Signature<double, double, double, double>>>(*funcs);
7544
7545 return MovePtr<const CaseFactories>(funcs.release());
7546 }
7547
createBuiltinCases16Bit(void)7548 MovePtr<const CaseFactories> createBuiltinCases16Bit(void)
7549 {
7550 MovePtr<BuiltinFuncs> funcs(new BuiltinFuncs());
7551
7552 addScalarFactory<Comparison<Signature<int, deFloat16, deFloat16>>>(*funcs);
7553 addScalarFactory<Add<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7554 addScalarFactory<Sub<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7555 addScalarFactory<Mul<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7556 addScalarFactory<Div<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7557
7558 addScalarFactory<Radians16>(*funcs);
7559 addScalarFactory<Degrees16>(*funcs);
7560
7561 addScalarFactory<Sin<Signature<deFloat16, deFloat16>>>(*funcs);
7562 addScalarFactory<Cos<Signature<deFloat16, deFloat16>>>(*funcs);
7563 addScalarFactory<Tan16Bit>(*funcs);
7564 addScalarFactory<ASin16Bit>(*funcs);
7565 addScalarFactory<ACos16Bit>(*funcs);
7566 addScalarFactory<ATan2<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs, "atan2");
7567 addScalarFactory<ATan<Signature<deFloat16, deFloat16>>>(*funcs);
7568
7569 addScalarFactory<Sinh16Bit>(*funcs);
7570 addScalarFactory<Cosh16Bit>(*funcs);
7571 addScalarFactory<Tanh16Bit>(*funcs);
7572 addScalarFactory<ASinh16Bit>(*funcs);
7573 addScalarFactory<ACosh16Bit>(*funcs);
7574 addScalarFactory<ATanh16Bit>(*funcs);
7575
7576 addScalarFactory<Pow16>(*funcs);
7577 addScalarFactory<Exp<Signature<deFloat16, deFloat16>>>(*funcs);
7578 addScalarFactory<Log<Signature<deFloat16, deFloat16>>>(*funcs);
7579 addScalarFactory<Exp2<Signature<deFloat16, deFloat16>>>(*funcs);
7580 addScalarFactory<Log2<Signature<deFloat16, deFloat16>>>(*funcs);
7581 addScalarFactory<Sqrt16Bit>(*funcs);
7582 addScalarFactory<InverseSqrt16Bit>(*funcs);
7583
7584 addScalarFactory<Abs<Signature<deFloat16, deFloat16>>>(*funcs);
7585 addScalarFactory<Sign<Signature<deFloat16, deFloat16>>>(*funcs);
7586 addScalarFactory<Floor16Bit>(*funcs);
7587 addScalarFactory<Trunc16Bit>(*funcs);
7588 addScalarFactory<Round<Signature<deFloat16, deFloat16>>>(*funcs);
7589 addScalarFactory<RoundEven<Signature<deFloat16, deFloat16>>>(*funcs);
7590 addScalarFactory<Ceil<Signature<deFloat16, deFloat16>>>(*funcs);
7591 addScalarFactory<Fract16Bit>(*funcs);
7592
7593 addScalarFactory<Mod16Bit>(*funcs, "mod", true);
7594 addScalarFactory<FRem16Bit>(*funcs);
7595
7596 addScalarFactory<Modf16Bit>(*funcs);
7597 addScalarFactory<ModfStruct16Bit>(*funcs);
7598 addScalarFactory<Min<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7599 addScalarFactory<Max<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7600 addScalarFactory<Clamp<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7601 addScalarFactory<Mix16Bit>(*funcs);
7602 addScalarFactory<Step<Signature<deFloat16, deFloat16, deFloat16>>>(*funcs);
7603 addScalarFactory<SmoothStep<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7604
7605 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Length, deFloat16>()));
7606 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Distance, deFloat16>()));
7607 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Dot, deFloat16>()));
7608 funcs->addFactory(createSimpleFuncCaseFactory<Cross16Bit>());
7609 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Normalize, deFloat16>()));
7610 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<FaceForward, deFloat16>()));
7611 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Reflect, deFloat16>()));
7612 funcs->addFactory(SharedPtr<const CaseFactory>(new TemplateFuncCaseFactory<Refract, deFloat16>()));
7613
7614 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<OuterProduct, deFloat16>()));
7615 funcs->addFactory(SharedPtr<const CaseFactory>(new MatrixFuncCaseFactory<Transpose, deFloat16>()));
7616 #ifndef CTS_USES_VULKANSC
7617 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Determinant16bit>()));
7618 funcs->addFactory(SharedPtr<const CaseFactory>(new SquareMatrixFuncCaseFactory<Inverse16bit>()));
7619 #endif // CTS_USES_VULKANSC
7620
7621 addScalarFactory<Frexp16Bit>(*funcs);
7622 addScalarFactory<FrexpStruct16Bit>(*funcs);
7623 addScalarFactory<LdExp<Signature<deFloat16, deFloat16, int>>>(*funcs);
7624 addScalarFactory<Fma<Signature<deFloat16, deFloat16, deFloat16, deFloat16>>>(*funcs);
7625
7626 return MovePtr<const CaseFactories>(funcs.release());
7627 }
7628
createFuncGroup(TestContext & ctx,const CaseFactory & factory,int numRandoms)7629 TestCaseGroup *createFuncGroup(TestContext &ctx, const CaseFactory &factory, int numRandoms)
7630 {
7631 TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7632 const FloatFormat highp(-126, 127, 23, true,
7633 tcu::MAYBE, // subnormals
7634 tcu::YES, // infinities
7635 tcu::MAYBE); // NaN
7636 const FloatFormat mediump(-14, 13, 10, false, tcu::MAYBE);
7637
7638 for (int precNdx = glu::PRECISION_MEDIUMP; precNdx < glu::PRECISION_LAST; ++precNdx)
7639 {
7640 const Precision precision = Precision(precNdx);
7641 const string precName(glu::getPrecisionName(precision));
7642 const FloatFormat &fmt = precNdx == glu::PRECISION_MEDIUMP ? mediump : highp;
7643
7644 const CaseContext caseCtx(precName, ctx, fmt, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms);
7645
7646 group->addChild(factory.createCase(caseCtx).release());
7647 }
7648
7649 return group;
7650 }
7651
createFuncGroupDouble(TestContext & ctx,const CaseFactory & factory,int numRandoms)7652 TestCaseGroup *createFuncGroupDouble(TestContext &ctx, const CaseFactory &factory, int numRandoms)
7653 {
7654 TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7655 const Precision precision = Precision(glu::PRECISION_LAST);
7656 const FloatFormat highp(-1022, 1023, 52, true,
7657 tcu::MAYBE, // subnormals
7658 tcu::YES, // infinities
7659 tcu::MAYBE); // NaN
7660
7661 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_64BIT_SHADER_FLOAT;
7662
7663 const CaseContext caseCtx("compute", ctx, highp, highp, precision, glu::SHADERTYPE_COMPUTE, numRandoms,
7664 precisionTestFeatures, false, true);
7665 group->addChild(factory.createCase(caseCtx).release());
7666
7667 return group;
7668 }
7669
createFuncGroup16Bit(TestContext & ctx,const CaseFactory & factory,int numRandoms,bool storage32)7670 TestCaseGroup *createFuncGroup16Bit(TestContext &ctx, const CaseFactory &factory, int numRandoms, bool storage32)
7671 {
7672 TestCaseGroup *const group = new TestCaseGroup(ctx, factory.getName().c_str());
7673 const Precision precision = Precision(glu::PRECISION_LAST);
7674 const FloatFormat float16(-14, 15, 10, true, tcu::MAYBE);
7675
7676 PrecisionTestFeatures precisionTestFeatures = PRECISION_TEST_FEATURES_16BIT_SHADER_FLOAT;
7677 if (!storage32)
7678 precisionTestFeatures |= PRECISION_TEST_FEATURES_16BIT_UNIFORM_AND_STORAGE_BUFFER_ACCESS;
7679
7680 const CaseContext caseCtx("compute", ctx, float16, float16, precision, glu::SHADERTYPE_COMPUTE, numRandoms,
7681 precisionTestFeatures, storage32);
7682 group->addChild(factory.createCase(caseCtx).release());
7683
7684 return group;
7685 }
7686
7687 const int defRandoms = 16384;
7688
addBuiltinPrecisionTests(TestContext & ctx,TestCaseGroup & dstGroup,const bool test16Bit=false,const bool storage32Bit=false)7689 void addBuiltinPrecisionTests(TestContext &ctx, TestCaseGroup &dstGroup, const bool test16Bit = false,
7690 const bool storage32Bit = false)
7691 {
7692 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7693 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7694
7695 MovePtr<const CaseFactories> cases =
7696 (test16Bit && !storage32Bit) ? createBuiltinCases16Bit() : createBuiltinCases();
7697 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7698 {
7699 if (!test16Bit)
7700 dstGroup.addChild(createFuncGroup(ctx, *cases->getFactories()[ndx], numRandoms));
7701 else
7702 dstGroup.addChild(createFuncGroup16Bit(ctx, *cases->getFactories()[ndx], numRandoms, storage32Bit));
7703 }
7704 }
7705
addBuiltinPrecisionDoubleTests(TestContext & ctx,TestCaseGroup & dstGroup)7706 void addBuiltinPrecisionDoubleTests(TestContext &ctx, TestCaseGroup &dstGroup)
7707 {
7708 const int userRandoms = ctx.getCommandLine().getTestIterationCount();
7709 const int numRandoms = userRandoms > 0 ? userRandoms : defRandoms;
7710
7711 MovePtr<const CaseFactories> cases = createBuiltinDoubleCases();
7712 for (size_t ndx = 0; ndx < cases->getFactories().size(); ++ndx)
7713 {
7714 dstGroup.addChild(createFuncGroupDouble(ctx, *cases->getFactories()[ndx], numRandoms));
7715 }
7716 }
7717
BuiltinPrecisionTests(tcu::TestContext & testCtx)7718 BuiltinPrecisionTests::BuiltinPrecisionTests(tcu::TestContext &testCtx) : tcu::TestCaseGroup(testCtx, "precision")
7719 {
7720 }
7721
~BuiltinPrecisionTests(void)7722 BuiltinPrecisionTests::~BuiltinPrecisionTests(void)
7723 {
7724 }
7725
init(void)7726 void BuiltinPrecisionTests::init(void)
7727 {
7728 addBuiltinPrecisionTests(m_testCtx, *this);
7729 }
7730
BuiltinPrecisionDoubleTests(tcu::TestContext & testCtx)7731 BuiltinPrecisionDoubleTests::BuiltinPrecisionDoubleTests(tcu::TestContext &testCtx)
7732 : tcu::TestCaseGroup(testCtx, "precision_double")
7733 {
7734 }
7735
~BuiltinPrecisionDoubleTests(void)7736 BuiltinPrecisionDoubleTests::~BuiltinPrecisionDoubleTests(void)
7737 {
7738 }
7739
init(void)7740 void BuiltinPrecisionDoubleTests::init(void)
7741 {
7742 addBuiltinPrecisionDoubleTests(m_testCtx, *this);
7743 }
7744
BuiltinPrecision16BitTests(tcu::TestContext & testCtx)7745 BuiltinPrecision16BitTests::BuiltinPrecision16BitTests(tcu::TestContext &testCtx)
7746 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage16b")
7747 {
7748 }
7749
~BuiltinPrecision16BitTests(void)7750 BuiltinPrecision16BitTests::~BuiltinPrecision16BitTests(void)
7751 {
7752 }
7753
init(void)7754 void BuiltinPrecision16BitTests::init(void)
7755 {
7756 addBuiltinPrecisionTests(m_testCtx, *this, true);
7757 }
7758
BuiltinPrecision16Storage32BitTests(tcu::TestContext & testCtx)7759 BuiltinPrecision16Storage32BitTests::BuiltinPrecision16Storage32BitTests(tcu::TestContext &testCtx)
7760 : tcu::TestCaseGroup(testCtx, "precision_fp16_storage32b")
7761 {
7762 }
7763
~BuiltinPrecision16Storage32BitTests(void)7764 BuiltinPrecision16Storage32BitTests::~BuiltinPrecision16Storage32BitTests(void)
7765 {
7766 }
7767
init(void)7768 void BuiltinPrecision16Storage32BitTests::init(void)
7769 {
7770 addBuiltinPrecisionTests(m_testCtx, *this, true, true);
7771 }
7772
7773 } // namespace shaderexecutor
7774 } // namespace vkt
7775