xref: /aosp_15_r20/external/OpenCL-CTS/test_conformance/subgroups/subhelpers.h (revision 6467f958c7de8070b317fc65bcb0f6472e388d82)
1 //
2 // Copyright (c) 2017 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBHELPERS_H
17 #define SUBHELPERS_H
18 
19 #include "testHarness.h"
20 #include "kernelHelpers.h"
21 #include "typeWrappers.h"
22 #include "imageHelpers.h"
23 
24 #include <limits>
25 #include <vector>
26 #include <type_traits>
27 #include <bitset>
28 #include <regex>
29 #include <map>
30 
31 #define NR_OF_ACTIVE_WORK_ITEMS 4
32 
33 extern MTdata gMTdata;
34 typedef std::bitset<128> bs128;
35 extern cl_half_rounding_mode g_rounding_mode;
36 
37 bs128 cl_uint4_to_bs128(cl_uint4 v);
38 cl_uint4 bs128_to_cl_uint4(bs128 v);
39 cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
40                            const std::string &mask_type,
41                            cl_uint max_sub_group_size);
42 
43 // limit possible input values to avoid arithmetic rounding/overflow issues.
44 // for each subgroup values defined different values
45 // for rest of workitems set 1 shuffle values
46 void fill_and_shuffle_safe_values(std::vector<cl_ulong> &safe_values,
47                                   size_t sb_size);
48 
49 struct WorkGroupParams
50 {
51 
52     WorkGroupParams(size_t gws, size_t lws, int dm_arg = -1, int cs_arg = -1)
global_workgroup_sizeWorkGroupParams53         : global_workgroup_size(gws), local_workgroup_size(lws),
54           divergence_mask_arg(dm_arg), cluster_size_arg(cs_arg)
55     {
56         subgroup_size = 0;
57         cluster_size = 0;
58         work_items_mask = 0;
59         use_core_subgroups = true;
60         dynsc = 0;
61         load_masks();
62     }
63     size_t global_workgroup_size;
64     size_t local_workgroup_size;
65     size_t subgroup_size;
66     cl_uint cluster_size;
67     bs128 work_items_mask;
68     size_t dynsc;
69     bool use_core_subgroups;
70     std::vector<bs128> all_work_item_masks;
71     int divergence_mask_arg;
72     int cluster_size_arg;
73     void save_kernel_source(const std::string &source, std::string name = "")
74     {
75         if (name == "")
76         {
77             name = "default";
78         }
79         if (kernel_function_name.find(name) != kernel_function_name.end())
80         {
81             log_info("Kernel definition duplication. Source will be "
82                      "overwritten for function name %s\n",
83                      name.c_str());
84         }
85         kernel_function_name[name] = source;
86     };
87     // return specific defined kernel or default.
get_kernel_sourceWorkGroupParams88     std::string get_kernel_source(std::string name)
89     {
90         if (kernel_function_name.find(name) == kernel_function_name.end())
91         {
92             return kernel_function_name["default"];
93         }
94         return kernel_function_name[name];
95     }
96 
97 
98 private:
99     std::map<std::string, std::string> kernel_function_name;
load_masksWorkGroupParams100     void load_masks()
101     {
102         if (divergence_mask_arg != -1)
103         {
104             // 1 in string will be set 1, 0 will be set 0
105             bs128 mask_0xf0f0f0f0("11110000111100001111000011110000"
106                                   "11110000111100001111000011110000"
107                                   "11110000111100001111000011110000"
108                                   "11110000111100001111000011110000",
109                                   128, '0', '1');
110             all_work_item_masks.push_back(mask_0xf0f0f0f0);
111             // 1 in string will be set 0, 0 will be set 1
112             bs128 mask_0x0f0f0f0f("11110000111100001111000011110000"
113                                   "11110000111100001111000011110000"
114                                   "11110000111100001111000011110000"
115                                   "11110000111100001111000011110000",
116                                   128, '1', '0');
117             all_work_item_masks.push_back(mask_0x0f0f0f0f);
118             bs128 mask_0x5555aaaa("10101010101010101010101010101010"
119                                   "10101010101010101010101010101010"
120                                   "10101010101010101010101010101010"
121                                   "10101010101010101010101010101010",
122                                   128, '0', '1');
123             all_work_item_masks.push_back(mask_0x5555aaaa);
124             bs128 mask_0xaaaa5555("10101010101010101010101010101010"
125                                   "10101010101010101010101010101010"
126                                   "10101010101010101010101010101010"
127                                   "10101010101010101010101010101010",
128                                   128, '1', '0');
129             all_work_item_masks.push_back(mask_0xaaaa5555);
130             // 0x0f0ff0f0
131             bs128 mask_0x0f0ff0f0("00001111000011111111000011110000"
132                                   "00001111000011111111000011110000"
133                                   "00001111000011111111000011110000"
134                                   "00001111000011111111000011110000",
135                                   128, '0', '1');
136             all_work_item_masks.push_back(mask_0x0f0ff0f0);
137             // 0xff0000ff
138             bs128 mask_0xff0000ff("11111111000000000000000011111111"
139                                   "11111111000000000000000011111111"
140                                   "11111111000000000000000011111111"
141                                   "11111111000000000000000011111111",
142                                   128, '0', '1');
143             all_work_item_masks.push_back(mask_0xff0000ff);
144             // 0xff00ff00
145             bs128 mask_0xff00ff00("11111111000000001111111100000000"
146                                   "11111111000000001111111100000000"
147                                   "11111111000000001111111100000000"
148                                   "11111111000000001111111100000000",
149                                   128, '0', '1');
150             all_work_item_masks.push_back(mask_0xff00ff00);
151             // 0x00ffff00
152             bs128 mask_0x00ffff00("00000000111111111111111100000000"
153                                   "00000000111111111111111100000000"
154                                   "00000000111111111111111100000000"
155                                   "00000000111111111111111100000000",
156                                   128, '0', '1');
157             all_work_item_masks.push_back(mask_0x00ffff00);
158             // 0x80 1 workitem highest id for 8 subgroup size
159             bs128 mask_0x80808080("10000000100000001000000010000000"
160                                   "10000000100000001000000010000000"
161                                   "10000000100000001000000010000000"
162                                   "10000000100000001000000010000000",
163                                   128, '0', '1');
164 
165             all_work_item_masks.push_back(mask_0x80808080);
166             // 0x8000 1 workitem highest id for 16 subgroup size
167             bs128 mask_0x80008000("10000000000000001000000000000000"
168                                   "10000000000000001000000000000000"
169                                   "10000000000000001000000000000000"
170                                   "10000000000000001000000000000000",
171                                   128, '0', '1');
172             all_work_item_masks.push_back(mask_0x80008000);
173             // 0x80000000 1 workitem highest id for 32 subgroup size
174             bs128 mask_0x80000000("10000000000000000000000000000000"
175                                   "10000000000000000000000000000000"
176                                   "10000000000000000000000000000000"
177                                   "10000000000000000000000000000000",
178                                   128, '0', '1');
179             all_work_item_masks.push_back(mask_0x80000000);
180             // 0x80000000 00000000 1 workitem highest id for 64 subgroup size
181             // 0x80000000 1 workitem highest id for 32 subgroup size
182             bs128 mask_0x8000000000000000("10000000000000000000000000000000"
183                                           "00000000000000000000000000000000"
184                                           "10000000000000000000000000000000"
185                                           "00000000000000000000000000000000",
186                                           128, '0', '1');
187 
188             all_work_item_masks.push_back(mask_0x8000000000000000);
189             // 0x80000000 00000000 00000000 00000000 1 workitem highest id for
190             // 128 subgroup size
191             bs128 mask_0x80000000000000000000000000000000(
192                 "10000000000000000000000000000000"
193                 "00000000000000000000000000000000"
194                 "00000000000000000000000000000000"
195                 "00000000000000000000000000000000",
196                 128, '0', '1');
197             all_work_item_masks.push_back(
198                 mask_0x80000000000000000000000000000000);
199 
200             bs128 mask_0xffffffff("11111111111111111111111111111111"
201                                   "11111111111111111111111111111111"
202                                   "11111111111111111111111111111111"
203                                   "11111111111111111111111111111111",
204                                   128, '0', '1');
205             all_work_item_masks.push_back(mask_0xffffffff);
206         }
207     }
208 };
209 
// Identifies a broadcast operation; operation_names() has an overload that
// takes this enum.
enum class SubgroupsBroadcastOp
{
    broadcast, // 0
    broadcast_first, // 1
    non_uniform_broadcast, // 2
};
216 
// Identifies a non-uniform vote operation; operation_names() has an overload
// that takes this enum.
enum class NonUniformVoteOp
{
    elect, // 0
    all, // 1
    any, // 2
    all_equal, // 3
};
224 
// Identifies a ballot operation; operation_names() has an overload that
// takes this enum.
enum class BallotOp
{
    ballot, // 0
    inverse_ballot, // 1
    ballot_bit_extract, // 2
    ballot_bit_count, // 3
    ballot_inclusive_scan, // 4
    ballot_exclusive_scan, // 5
    ballot_find_lsb, // 6
    ballot_find_msb, // 7
    eq_mask, // 8
    ge_mask, // 9
    gt_mask, // 10
    le_mask, // 11
    lt_mask // 12
};
241 
// Identifies a shuffle/rotate operation; operation_names() has an overload
// that takes this enum.
enum class ShuffleOp
{
    shuffle, // 0
    shuffle_up, // 1
    shuffle_down, // 2
    shuffle_xor, // 3
    rotate, // 4
    clustered_rotate // 5
};
251 
// Identifies an arithmetic, bitwise or logical operation. The trailing
// underscores on some enumerators keep them distinct from C++ alternative
// tokens such as `and`, `or` and `xor`.
enum class ArithmeticOp
{
    add_, // 0
    max_, // 1
    min_, // 2
    mul_, // 3
    and_, // 4
    or_, // 5
    xor_, // 6
    logical_and, // 7
    logical_or, // 8
    logical_xor, // 9
};
265 
266 const char *const operation_names(ArithmeticOp operation);
267 const char *const operation_names(BallotOp operation);
268 const char *const operation_names(ShuffleOp operation);
269 const char *const operation_names(NonUniformVoteOp operation);
270 const char *const operation_names(SubgroupsBroadcastOp operation);
271 
272 class subgroupsAPI {
273 public:
subgroupsAPI(cl_platform_id platform,bool use_core_subgroups)274     subgroupsAPI(cl_platform_id platform, bool use_core_subgroups)
275     {
276         static_assert(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE
277                           == CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE_KHR,
278                       "Enums have to be the same");
279         static_assert(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE
280                           == CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE_KHR,
281                       "Enums have to be the same");
282         if (use_core_subgroups)
283         {
284             _clGetKernelSubGroupInfo_ptr = &clGetKernelSubGroupInfo;
285             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfo";
286         }
287         else
288         {
289             _clGetKernelSubGroupInfo_ptr = (clGetKernelSubGroupInfoKHR_fn)
290                 clGetExtensionFunctionAddressForPlatform(
291                     platform, "clGetKernelSubGroupInfoKHR");
292             clGetKernelSubGroupInfo_name = "clGetKernelSubGroupInfoKHR";
293         }
294     }
clGetKernelSubGroupInfo_ptr()295     clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr()
296     {
297         return _clGetKernelSubGroupInfo_ptr;
298     }
299     const char *clGetKernelSubGroupInfo_name;
300 
301 private:
302     clGetKernelSubGroupInfoKHR_fn _clGetKernelSubGroupInfo_ptr;
303 };
304 
305 // Need to defined custom type for vector size = 3 and half type. This is
306 // because of 3-component types are otherwise indistinguishable from the
307 // 4-component types, and because the half type is indistinguishable from some
308 // other 16-bit type (ushort)
309 namespace subgroups {
310 struct cl_char3
311 {
312     ::cl_char3 data;
313 };
314 struct cl_uchar3
315 {
316     ::cl_uchar3 data;
317 };
318 struct cl_short3
319 {
320     ::cl_short3 data;
321 };
322 struct cl_ushort3
323 {
324     ::cl_ushort3 data;
325 };
326 struct cl_int3
327 {
328     ::cl_int3 data;
329 };
330 struct cl_uint3
331 {
332     ::cl_uint3 data;
333 };
334 struct cl_long3
335 {
336     ::cl_long3 data;
337 };
338 struct cl_ulong3
339 {
340     ::cl_ulong3 data;
341 };
342 struct cl_float3
343 {
344     ::cl_float3 data;
345 };
346 struct cl_double3
347 {
348     ::cl_double3 data;
349 };
350 struct cl_half
351 {
352     ::cl_half data;
353 };
354 struct cl_half2
355 {
356     ::cl_half2 data;
357 };
358 struct cl_half3
359 {
360     ::cl_half3 data;
361 };
362 struct cl_half4
363 {
364     ::cl_half4 data;
365 };
366 struct cl_half8
367 {
368     ::cl_half8 data;
369 };
370 struct cl_half16
371 {
372     ::cl_half16 data;
373 };
374 }
375 
376 // Declare operator<< for cl_ types, accessing the .s member.
377 #define OP_OSTREAM(Ty, VecSize)                                                \
378     std::ostream &operator<<(std::ostream &os, const Ty##VecSize &val);
379 
380 // Declare operator<< for subgroups::cl_ types, accessing the .data member and
381 // forwarding to operator<< for the cl_ types.
382 #define OP_OSTREAM_SUBGROUP(Ty, VecSize)                                       \
383     std::ostream &operator<<(std::ostream &os, const Ty##VecSize &val);
384 
385 // Declare operator<< for all vector sizes.
386 #define OP_OSTREAM_ALL_VEC(Ty)                                                 \
387     OP_OSTREAM(Ty, 2)                                                          \
388     OP_OSTREAM(Ty, 4)                                                          \
389     OP_OSTREAM(Ty, 8)                                                          \
390     OP_OSTREAM(Ty, 16)                                                         \
391     OP_OSTREAM_SUBGROUP(subgroups::Ty, 3)
392 
393 OP_OSTREAM_ALL_VEC(cl_char)
OP_OSTREAM_ALL_VEC(cl_uchar)394 OP_OSTREAM_ALL_VEC(cl_uchar)
395 OP_OSTREAM_ALL_VEC(cl_short)
396 OP_OSTREAM_ALL_VEC(cl_ushort)
397 OP_OSTREAM_ALL_VEC(cl_int)
398 OP_OSTREAM_ALL_VEC(cl_uint)
399 OP_OSTREAM_ALL_VEC(cl_long)
400 OP_OSTREAM_ALL_VEC(cl_ulong)
401 OP_OSTREAM_ALL_VEC(cl_float)
402 OP_OSTREAM_ALL_VEC(cl_double)
403 OP_OSTREAM_ALL_VEC(cl_half)
404 OP_OSTREAM_SUBGROUP(subgroups::cl_half, )
405 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 2)
406 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 4)
407 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 8)
408 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 16)
409 
410 #undef OP_OSTREAM
411 #undef OP_OSTREAM_SUBGROUP
412 #undef OP_OSTREAM_ALL_VEC
413 
// Formats an expected/obtained pair as "Expected: <e> Obtained: <o>".
// Ty must be streamable with operator<<.
template <typename Ty>
std::string print_expected_obtained(const Ty &expected, const Ty &obtained)
{
    std::ostringstream message;
    message << "Expected: " << expected;
    message << " Obtained: " << obtained;
    return message.str();
}
421 
int64_ok(cl_device_id device)422 static bool int64_ok(cl_device_id device)
423 {
424     char profile[128];
425     int error;
426 
427     error = clGetDeviceInfo(device, CL_DEVICE_PROFILE, sizeof(profile),
428                             (void *)&profile, NULL);
429     if (error)
430     {
431         log_info("clGetDeviceInfo failed with CL_DEVICE_PROFILE\n");
432         return false;
433     }
434 
435     if (strcmp(profile, "EMBEDDED_PROFILE") == 0)
436         return is_extension_available(device, "cles_khr_int64");
437 
438     return true;
439 }
440 
double_ok(cl_device_id device)441 static bool double_ok(cl_device_id device)
442 {
443     int error;
444     cl_device_fp_config c;
445     error = clGetDeviceInfo(device, CL_DEVICE_DOUBLE_FP_CONFIG, sizeof(c),
446                             (void *)&c, NULL);
447     if (error)
448     {
449         log_info("clGetDeviceInfo failed with CL_DEVICE_DOUBLE_FP_CONFIG\n");
450         return false;
451     }
452     return c != 0;
453 }
454 
half_ok(cl_device_id device)455 static bool half_ok(cl_device_id device)
456 {
457     int error;
458     cl_device_fp_config c;
459     error = clGetDeviceInfo(device, CL_DEVICE_HALF_FP_CONFIG, sizeof(c),
460                             (void *)&c, NULL);
461     if (error)
462     {
463         log_info("clGetDeviceInfo failed with CL_DEVICE_HALF_FP_CONFIG\n");
464         return false;
465     }
466     return c != 0;
467 }
468 
469 template <typename Ty> struct CommonTypeManager
470 {
471 
nameCommonTypeManager472     static const char *name() { return ""; }
add_typedefCommonTypeManager473     static const char *add_typedef() { return "\n"; }
474     typedef std::false_type is_vector_type;
475     typedef std::false_type is_sb_vector_size3;
476     typedef std::false_type is_sb_vector_type;
477     typedef std::false_type is_sb_scalar_type;
type_supportedCommonTypeManager478     static const bool type_supported(cl_device_id) { return true; }
identify_limitsCommonTypeManager479     static const Ty identify_limits(ArithmeticOp operation)
480     {
481         switch (operation)
482         {
483             case ArithmeticOp::add_: return (Ty)0;
484             case ArithmeticOp::max_: return (std::numeric_limits<Ty>::min)();
485             case ArithmeticOp::min_: return (std::numeric_limits<Ty>::max)();
486             case ArithmeticOp::mul_: return (Ty)1;
487             case ArithmeticOp::and_: return (Ty)~0;
488             case ArithmeticOp::or_: return (Ty)0;
489             case ArithmeticOp::xor_: return (Ty)0;
490             default: log_error("Unknown operation request\n"); break;
491         }
492         return 0;
493     }
494 };
495 
496 template <typename> struct TypeManager;
497 
498 template <> struct TypeManager<cl_int> : public CommonTypeManager<cl_int>
499 {
500     static const char *name() { return "int"; }
501     static const char *add_typedef() { return "typedef int Type;\n"; }
502     static cl_int identify_limits(ArithmeticOp operation)
503     {
504         switch (operation)
505         {
506             case ArithmeticOp::add_: return (cl_int)0;
507             case ArithmeticOp::max_:
508                 return (std::numeric_limits<cl_int>::min)();
509             case ArithmeticOp::min_:
510                 return (std::numeric_limits<cl_int>::max)();
511             case ArithmeticOp::mul_: return (cl_int)1;
512             case ArithmeticOp::and_: return (cl_int)~0;
513             case ArithmeticOp::or_: return (cl_int)0;
514             case ArithmeticOp::xor_: return (cl_int)0;
515             case ArithmeticOp::logical_and: return (cl_int)1;
516             case ArithmeticOp::logical_or: return (cl_int)0;
517             case ArithmeticOp::logical_xor: return (cl_int)0;
518             default: log_error("Unknown operation request\n"); break;
519         }
520         return 0;
521     }
522 };
523 template <> struct TypeManager<cl_int2> : public CommonTypeManager<cl_int2>
524 {
525     static const char *name() { return "int2"; }
526     static const char *add_typedef() { return "typedef int2 Type;\n"; }
527     typedef std::true_type is_vector_type;
528     using scalar_type = cl_int;
529 };
530 template <>
531 struct TypeManager<subgroups::cl_int3>
532     : public CommonTypeManager<subgroups::cl_int3>
533 {
534     static const char *name() { return "int3"; }
535     static const char *add_typedef() { return "typedef int3 Type;\n"; }
536     typedef std::true_type is_sb_vector_size3;
537     using scalar_type = cl_int;
538 };
539 template <> struct TypeManager<cl_int4> : public CommonTypeManager<cl_int4>
540 {
541     static const char *name() { return "int4"; }
542     static const char *add_typedef() { return "typedef int4 Type;\n"; }
543     using scalar_type = cl_int;
544     typedef std::true_type is_vector_type;
545 };
546 template <> struct TypeManager<cl_int8> : public CommonTypeManager<cl_int8>
547 {
548     static const char *name() { return "int8"; }
549     static const char *add_typedef() { return "typedef int8 Type;\n"; }
550     using scalar_type = cl_int;
551     typedef std::true_type is_vector_type;
552 };
553 template <> struct TypeManager<cl_int16> : public CommonTypeManager<cl_int16>
554 {
555     static const char *name() { return "int16"; }
556     static const char *add_typedef() { return "typedef int16 Type;\n"; }
557     using scalar_type = cl_int;
558     typedef std::true_type is_vector_type;
559 };
560 // cl_uint
561 template <> struct TypeManager<cl_uint> : public CommonTypeManager<cl_uint>
562 {
563     static const char *name() { return "uint"; }
564     static const char *add_typedef() { return "typedef uint Type;\n"; }
565 };
566 template <> struct TypeManager<cl_uint2> : public CommonTypeManager<cl_uint2>
567 {
568     static const char *name() { return "uint2"; }
569     static const char *add_typedef() { return "typedef uint2 Type;\n"; }
570     using scalar_type = cl_uint;
571     typedef std::true_type is_vector_type;
572 };
573 template <>
574 struct TypeManager<subgroups::cl_uint3>
575     : public CommonTypeManager<subgroups::cl_uint3>
576 {
577     static const char *name() { return "uint3"; }
578     static const char *add_typedef() { return "typedef uint3 Type;\n"; }
579     typedef std::true_type is_sb_vector_size3;
580     using scalar_type = cl_uint;
581 };
582 template <> struct TypeManager<cl_uint4> : public CommonTypeManager<cl_uint4>
583 {
584     static const char *name() { return "uint4"; }
585     static const char *add_typedef() { return "typedef uint4 Type;\n"; }
586     using scalar_type = cl_uint;
587     typedef std::true_type is_vector_type;
588 };
589 template <> struct TypeManager<cl_uint8> : public CommonTypeManager<cl_uint8>
590 {
591     static const char *name() { return "uint8"; }
592     static const char *add_typedef() { return "typedef uint8 Type;\n"; }
593     using scalar_type = cl_uint;
594     typedef std::true_type is_vector_type;
595 };
596 template <> struct TypeManager<cl_uint16> : public CommonTypeManager<cl_uint16>
597 {
598     static const char *name() { return "uint16"; }
599     static const char *add_typedef() { return "typedef uint16 Type;\n"; }
600     using scalar_type = cl_uint;
601     typedef std::true_type is_vector_type;
602 };
603 // cl_short
604 template <> struct TypeManager<cl_short> : public CommonTypeManager<cl_short>
605 {
606     static const char *name() { return "short"; }
607     static const char *add_typedef() { return "typedef short Type;\n"; }
608 };
609 template <> struct TypeManager<cl_short2> : public CommonTypeManager<cl_short2>
610 {
611     static const char *name() { return "short2"; }
612     static const char *add_typedef() { return "typedef short2 Type;\n"; }
613     using scalar_type = cl_short;
614     typedef std::true_type is_vector_type;
615 };
616 template <>
617 struct TypeManager<subgroups::cl_short3>
618     : public CommonTypeManager<subgroups::cl_short3>
619 {
620     static const char *name() { return "short3"; }
621     static const char *add_typedef() { return "typedef short3 Type;\n"; }
622     typedef std::true_type is_sb_vector_size3;
623     using scalar_type = cl_short;
624 };
625 template <> struct TypeManager<cl_short4> : public CommonTypeManager<cl_short4>
626 {
627     static const char *name() { return "short4"; }
628     static const char *add_typedef() { return "typedef short4 Type;\n"; }
629     using scalar_type = cl_short;
630     typedef std::true_type is_vector_type;
631 };
632 template <> struct TypeManager<cl_short8> : public CommonTypeManager<cl_short8>
633 {
634     static const char *name() { return "short8"; }
635     static const char *add_typedef() { return "typedef short8 Type;\n"; }
636     using scalar_type = cl_short;
637     typedef std::true_type is_vector_type;
638 };
639 template <>
640 struct TypeManager<cl_short16> : public CommonTypeManager<cl_short16>
641 {
642     static const char *name() { return "short16"; }
643     static const char *add_typedef() { return "typedef short16 Type;\n"; }
644     using scalar_type = cl_short;
645     typedef std::true_type is_vector_type;
646 };
647 // cl_ushort
648 template <> struct TypeManager<cl_ushort> : public CommonTypeManager<cl_ushort>
649 {
650     static const char *name() { return "ushort"; }
651     static const char *add_typedef() { return "typedef ushort Type;\n"; }
652 };
653 template <>
654 struct TypeManager<cl_ushort2> : public CommonTypeManager<cl_ushort2>
655 {
656     static const char *name() { return "ushort2"; }
657     static const char *add_typedef() { return "typedef ushort2 Type;\n"; }
658     using scalar_type = cl_ushort;
659     typedef std::true_type is_vector_type;
660 };
661 template <>
662 struct TypeManager<subgroups::cl_ushort3>
663     : public CommonTypeManager<subgroups::cl_ushort3>
664 {
665     static const char *name() { return "ushort3"; }
666     static const char *add_typedef() { return "typedef ushort3 Type;\n"; }
667     typedef std::true_type is_sb_vector_size3;
668     using scalar_type = cl_ushort;
669 };
670 template <>
671 struct TypeManager<cl_ushort4> : public CommonTypeManager<cl_ushort4>
672 {
673     static const char *name() { return "ushort4"; }
674     static const char *add_typedef() { return "typedef ushort4 Type;\n"; }
675     using scalar_type = cl_ushort;
676     typedef std::true_type is_vector_type;
677 };
678 template <>
679 struct TypeManager<cl_ushort8> : public CommonTypeManager<cl_ushort8>
680 {
681     static const char *name() { return "ushort8"; }
682     static const char *add_typedef() { return "typedef ushort8 Type;\n"; }
683     using scalar_type = cl_ushort;
684     typedef std::true_type is_vector_type;
685 };
686 template <>
687 struct TypeManager<cl_ushort16> : public CommonTypeManager<cl_ushort16>
688 {
689     static const char *name() { return "ushort16"; }
690     static const char *add_typedef() { return "typedef ushort16 Type;\n"; }
691     using scalar_type = cl_ushort;
692     typedef std::true_type is_vector_type;
693 };
694 // cl_char
695 template <> struct TypeManager<cl_char> : public CommonTypeManager<cl_char>
696 {
697     static const char *name() { return "char"; }
698     static const char *add_typedef() { return "typedef char Type;\n"; }
699 };
700 template <> struct TypeManager<cl_char2> : public CommonTypeManager<cl_char2>
701 {
702     static const char *name() { return "char2"; }
703     static const char *add_typedef() { return "typedef char2 Type;\n"; }
704     using scalar_type = cl_char;
705     typedef std::true_type is_vector_type;
706 };
707 template <>
708 struct TypeManager<subgroups::cl_char3>
709     : public CommonTypeManager<subgroups::cl_char3>
710 {
711     static const char *name() { return "char3"; }
712     static const char *add_typedef() { return "typedef char3 Type;\n"; }
713     typedef std::true_type is_sb_vector_size3;
714     using scalar_type = cl_char;
715 };
716 template <> struct TypeManager<cl_char4> : public CommonTypeManager<cl_char4>
717 {
718     static const char *name() { return "char4"; }
719     static const char *add_typedef() { return "typedef char4 Type;\n"; }
720     using scalar_type = cl_char;
721     typedef std::true_type is_vector_type;
722 };
723 template <> struct TypeManager<cl_char8> : public CommonTypeManager<cl_char8>
724 {
725     static const char *name() { return "char8"; }
726     static const char *add_typedef() { return "typedef char8 Type;\n"; }
727     using scalar_type = cl_char;
728     typedef std::true_type is_vector_type;
729 };
730 template <> struct TypeManager<cl_char16> : public CommonTypeManager<cl_char16>
731 {
732     static const char *name() { return "char16"; }
733     static const char *add_typedef() { return "typedef char16 Type;\n"; }
734     using scalar_type = cl_char;
735     typedef std::true_type is_vector_type;
736 };
737 // cl_uchar
738 template <> struct TypeManager<cl_uchar> : public CommonTypeManager<cl_uchar>
739 {
740     static const char *name() { return "uchar"; }
741     static const char *add_typedef() { return "typedef uchar Type;\n"; }
742 };
743 template <> struct TypeManager<cl_uchar2> : public CommonTypeManager<cl_uchar2>
744 {
745     static const char *name() { return "uchar2"; }
746     static const char *add_typedef() { return "typedef uchar2 Type;\n"; }
747     using scalar_type = cl_uchar;
748     typedef std::true_type is_vector_type;
749 };
750 template <>
751 struct TypeManager<subgroups::cl_uchar3>
752     : public CommonTypeManager<subgroups::cl_char3>
753 {
754     static const char *name() { return "uchar3"; }
755     static const char *add_typedef() { return "typedef uchar3 Type;\n"; }
756     typedef std::true_type is_sb_vector_size3;
757     using scalar_type = cl_uchar;
758 };
759 template <> struct TypeManager<cl_uchar4> : public CommonTypeManager<cl_uchar4>
760 {
761     static const char *name() { return "uchar4"; }
762     static const char *add_typedef() { return "typedef uchar4 Type;\n"; }
763     using scalar_type = cl_uchar;
764     typedef std::true_type is_vector_type;
765 };
766 template <> struct TypeManager<cl_uchar8> : public CommonTypeManager<cl_uchar8>
767 {
768     static const char *name() { return "uchar8"; }
769     static const char *add_typedef() { return "typedef uchar8 Type;\n"; }
770     using scalar_type = cl_uchar;
771     typedef std::true_type is_vector_type;
772 };
773 template <>
774 struct TypeManager<cl_uchar16> : public CommonTypeManager<cl_uchar16>
775 {
776     static const char *name() { return "uchar16"; }
777     static const char *add_typedef() { return "typedef uchar16 Type;\n"; }
778     using scalar_type = cl_uchar;
779     typedef std::true_type is_vector_type;
780 };
781 // cl_long
782 template <> struct TypeManager<cl_long> : public CommonTypeManager<cl_long>
783 {
784     static const char *name() { return "long"; }
785     static const char *add_typedef() { return "typedef long Type;\n"; }
786     static const bool type_supported(cl_device_id device)
787     {
788         return int64_ok(device);
789     }
790 };
791 template <> struct TypeManager<cl_long2> : public CommonTypeManager<cl_long2>
792 {
793     static const char *name() { return "long2"; }
794     static const char *add_typedef() { return "typedef long2 Type;\n"; }
795     using scalar_type = cl_long;
796     typedef std::true_type is_vector_type;
797     static const bool type_supported(cl_device_id device)
798     {
799         return int64_ok(device);
800     }
801 };
802 template <>
803 struct TypeManager<subgroups::cl_long3>
804     : public CommonTypeManager<subgroups::cl_long3>
805 {
806     static const char *name() { return "long3"; }
807     static const char *add_typedef() { return "typedef long3 Type;\n"; }
808     typedef std::true_type is_sb_vector_size3;
809     using scalar_type = cl_long;
810     static const bool type_supported(cl_device_id device)
811     {
812         return int64_ok(device);
813     }
814 };
815 template <> struct TypeManager<cl_long4> : public CommonTypeManager<cl_long4>
816 {
817     static const char *name() { return "long4"; }
818     static const char *add_typedef() { return "typedef long4 Type;\n"; }
819     using scalar_type = cl_long;
820     typedef std::true_type is_vector_type;
821     static const bool type_supported(cl_device_id device)
822     {
823         return int64_ok(device);
824     }
825 };
826 template <> struct TypeManager<cl_long8> : public CommonTypeManager<cl_long8>
827 {
828     static const char *name() { return "long8"; }
829     static const char *add_typedef() { return "typedef long8 Type;\n"; }
830     using scalar_type = cl_long;
831     typedef std::true_type is_vector_type;
832     static const bool type_supported(cl_device_id device)
833     {
834         return int64_ok(device);
835     }
836 };
837 template <> struct TypeManager<cl_long16> : public CommonTypeManager<cl_long16>
838 {
839     static const char *name() { return "long16"; }
840     static const char *add_typedef() { return "typedef long16 Type;\n"; }
841     using scalar_type = cl_long;
842     typedef std::true_type is_vector_type;
843     static const bool type_supported(cl_device_id device)
844     {
845         return int64_ok(device);
846     }
847 };
848 // cl_ulong
849 template <> struct TypeManager<cl_ulong> : public CommonTypeManager<cl_ulong>
850 {
851     static const char *name() { return "ulong"; }
852     static const char *add_typedef() { return "typedef ulong Type;\n"; }
853     static const bool type_supported(cl_device_id device)
854     {
855         return int64_ok(device);
856     }
857 };
858 template <> struct TypeManager<cl_ulong2> : public CommonTypeManager<cl_ulong2>
859 {
860     static const char *name() { return "ulong2"; }
861     static const char *add_typedef() { return "typedef ulong2 Type;\n"; }
862     using scalar_type = cl_ulong;
863     typedef std::true_type is_vector_type;
864     static const bool type_supported(cl_device_id device)
865     {
866         return int64_ok(device);
867     }
868 };
869 template <>
870 struct TypeManager<subgroups::cl_ulong3>
871     : public CommonTypeManager<subgroups::cl_ulong3>
872 {
873     static const char *name() { return "ulong3"; }
874     static const char *add_typedef() { return "typedef ulong3 Type;\n"; }
875     typedef std::true_type is_sb_vector_size3;
876     using scalar_type = cl_ulong;
877     static const bool type_supported(cl_device_id device)
878     {
879         return int64_ok(device);
880     }
881 };
882 template <> struct TypeManager<cl_ulong4> : public CommonTypeManager<cl_ulong4>
883 {
884     static const char *name() { return "ulong4"; }
885     static const char *add_typedef() { return "typedef ulong4 Type;\n"; }
886     using scalar_type = cl_ulong;
887     typedef std::true_type is_vector_type;
888     static const bool type_supported(cl_device_id device)
889     {
890         return int64_ok(device);
891     }
892 };
893 template <> struct TypeManager<cl_ulong8> : public CommonTypeManager<cl_ulong8>
894 {
895     static const char *name() { return "ulong8"; }
896     static const char *add_typedef() { return "typedef ulong8 Type;\n"; }
897     using scalar_type = cl_ulong;
898     typedef std::true_type is_vector_type;
899     static const bool type_supported(cl_device_id device)
900     {
901         return int64_ok(device);
902     }
903 };
904 template <>
905 struct TypeManager<cl_ulong16> : public CommonTypeManager<cl_ulong16>
906 {
907     static const char *name() { return "ulong16"; }
908     static const char *add_typedef() { return "typedef ulong16 Type;\n"; }
909     using scalar_type = cl_ulong;
910     typedef std::true_type is_vector_type;
911     static const bool type_supported(cl_device_id device)
912     {
913         return int64_ok(device);
914     }
915 };
916 
917 // cl_float
918 template <> struct TypeManager<cl_float> : public CommonTypeManager<cl_float>
919 {
920     static const char *name() { return "float"; }
921     static const char *add_typedef() { return "typedef float Type;\n"; }
922     static cl_float identify_limits(ArithmeticOp operation)
923     {
924         switch (operation)
925         {
926             case ArithmeticOp::add_: return 0.0f;
927             case ArithmeticOp::max_:
928                 return -std::numeric_limits<float>::infinity();
929             case ArithmeticOp::min_:
930                 return std::numeric_limits<float>::infinity();
931             case ArithmeticOp::mul_: return (cl_float)1;
932             default: log_error("Unknown operation request\n"); break;
933         }
934         return 0;
935     }
936 };
937 template <> struct TypeManager<cl_float2> : public CommonTypeManager<cl_float2>
938 {
939     static const char *name() { return "float2"; }
940     static const char *add_typedef() { return "typedef float2 Type;\n"; }
941     using scalar_type = cl_float;
942     typedef std::true_type is_vector_type;
943 };
944 template <>
945 struct TypeManager<subgroups::cl_float3>
946     : public CommonTypeManager<subgroups::cl_float3>
947 {
948     static const char *name() { return "float3"; }
949     static const char *add_typedef() { return "typedef float3 Type;\n"; }
950     typedef std::true_type is_sb_vector_size3;
951     using scalar_type = cl_float;
952 };
953 template <> struct TypeManager<cl_float4> : public CommonTypeManager<cl_float4>
954 {
955     static const char *name() { return "float4"; }
956     static const char *add_typedef() { return "typedef float4 Type;\n"; }
957     using scalar_type = cl_float;
958     typedef std::true_type is_vector_type;
959 };
960 template <> struct TypeManager<cl_float8> : public CommonTypeManager<cl_float8>
961 {
962     static const char *name() { return "float8"; }
963     static const char *add_typedef() { return "typedef float8 Type;\n"; }
964     using scalar_type = cl_float;
965     typedef std::true_type is_vector_type;
966 };
967 template <>
968 struct TypeManager<cl_float16> : public CommonTypeManager<cl_float16>
969 {
970     static const char *name() { return "float16"; }
971     static const char *add_typedef() { return "typedef float16 Type;\n"; }
972     using scalar_type = cl_float;
973     typedef std::true_type is_vector_type;
974 };
975 
976 // cl_double
977 template <> struct TypeManager<cl_double> : public CommonTypeManager<cl_double>
978 {
979     static const char *name() { return "double"; }
980     static const char *add_typedef() { return "typedef double Type;\n"; }
981     static cl_double identify_limits(ArithmeticOp operation)
982     {
983         switch (operation)
984         {
985             case ArithmeticOp::add_: return 0.0;
986             case ArithmeticOp::max_:
987                 return -std::numeric_limits<double>::infinity();
988             case ArithmeticOp::min_:
989                 return std::numeric_limits<double>::infinity();
990             case ArithmeticOp::mul_: return (cl_double)1;
991             default: log_error("Unknown operation request\n"); break;
992         }
993         return 0;
994     }
995     static const bool type_supported(cl_device_id device)
996     {
997         return double_ok(device);
998     }
999 };
1000 template <>
1001 struct TypeManager<cl_double2> : public CommonTypeManager<cl_double2>
1002 {
1003     static const char *name() { return "double2"; }
1004     static const char *add_typedef() { return "typedef double2 Type;\n"; }
1005     using scalar_type = cl_double;
1006     typedef std::true_type is_vector_type;
1007     static const bool type_supported(cl_device_id device)
1008     {
1009         return double_ok(device);
1010     }
1011 };
1012 template <>
1013 struct TypeManager<subgroups::cl_double3>
1014     : public CommonTypeManager<subgroups::cl_double3>
1015 {
1016     static const char *name() { return "double3"; }
1017     static const char *add_typedef() { return "typedef double3 Type;\n"; }
1018     typedef std::true_type is_sb_vector_size3;
1019     using scalar_type = cl_double;
1020     static const bool type_supported(cl_device_id device)
1021     {
1022         return double_ok(device);
1023     }
1024 };
1025 template <>
1026 struct TypeManager<cl_double4> : public CommonTypeManager<cl_double4>
1027 {
1028     static const char *name() { return "double4"; }
1029     static const char *add_typedef() { return "typedef double4 Type;\n"; }
1030     using scalar_type = cl_double;
1031     typedef std::true_type is_vector_type;
1032     static const bool type_supported(cl_device_id device)
1033     {
1034         return double_ok(device);
1035     }
1036 };
1037 template <>
1038 struct TypeManager<cl_double8> : public CommonTypeManager<cl_double8>
1039 {
1040     static const char *name() { return "double8"; }
1041     static const char *add_typedef() { return "typedef double8 Type;\n"; }
1042     using scalar_type = cl_double;
1043     typedef std::true_type is_vector_type;
1044     static const bool type_supported(cl_device_id device)
1045     {
1046         return double_ok(device);
1047     }
1048 };
1049 template <>
1050 struct TypeManager<cl_double16> : public CommonTypeManager<cl_double16>
1051 {
1052     static const char *name() { return "double16"; }
1053     static const char *add_typedef() { return "typedef double16 Type;\n"; }
1054     using scalar_type = cl_double;
1055     typedef std::true_type is_vector_type;
1056     static const bool type_supported(cl_device_id device)
1057     {
1058         return double_ok(device);
1059     }
1060 };
1061 
1062 // cl_half
1063 template <>
1064 struct TypeManager<subgroups::cl_half>
1065     : public CommonTypeManager<subgroups::cl_half>
1066 {
1067     static const char *name() { return "half"; }
1068     static const char *add_typedef() { return "typedef half Type;\n"; }
1069     typedef std::true_type is_sb_scalar_type;
1070     static subgroups::cl_half identify_limits(ArithmeticOp operation)
1071     {
1072         switch (operation)
1073         {
1074             case ArithmeticOp::add_: return { 0x0000 };
1075             case ArithmeticOp::max_: return { 0xfc00 };
1076             case ArithmeticOp::min_: return { 0x7c00 };
1077             case ArithmeticOp::mul_: return { 0x3c00 };
1078             default: log_error("Unknown operation request\n"); break;
1079         }
1080         return { 0 };
1081     }
1082     static const bool type_supported(cl_device_id device)
1083     {
1084         return half_ok(device);
1085     }
1086 };
1087 template <>
1088 struct TypeManager<subgroups::cl_half2>
1089     : public CommonTypeManager<subgroups::cl_half2>
1090 {
1091     static const char *name() { return "half2"; }
1092     static const char *add_typedef() { return "typedef half2 Type;\n"; }
1093     using scalar_type = subgroups::cl_half;
1094     typedef std::true_type is_sb_vector_type;
1095     static const bool type_supported(cl_device_id device)
1096     {
1097         return half_ok(device);
1098     }
1099 };
1100 template <>
1101 struct TypeManager<subgroups::cl_half3>
1102     : public CommonTypeManager<subgroups::cl_half3>
1103 {
1104     static const char *name() { return "half3"; }
1105     static const char *add_typedef() { return "typedef half3 Type;\n"; }
1106     typedef std::true_type is_sb_vector_size3;
1107     using scalar_type = subgroups::cl_half;
1108 
1109     static const bool type_supported(cl_device_id device)
1110     {
1111         return half_ok(device);
1112     }
1113 };
1114 template <>
1115 struct TypeManager<subgroups::cl_half4>
1116     : public CommonTypeManager<subgroups::cl_half4>
1117 {
1118     static const char *name() { return "half4"; }
1119     static const char *add_typedef() { return "typedef half4 Type;\n"; }
1120     using scalar_type = subgroups::cl_half;
1121     typedef std::true_type is_sb_vector_type;
1122     static const bool type_supported(cl_device_id device)
1123     {
1124         return half_ok(device);
1125     }
1126 };
1127 template <>
1128 struct TypeManager<subgroups::cl_half8>
1129     : public CommonTypeManager<subgroups::cl_half8>
1130 {
1131     static const char *name() { return "half8"; }
1132     static const char *add_typedef() { return "typedef half8 Type;\n"; }
1133     using scalar_type = subgroups::cl_half;
1134     typedef std::true_type is_sb_vector_type;
1135 
1136     static const bool type_supported(cl_device_id device)
1137     {
1138         return half_ok(device);
1139     }
1140 };
1141 template <>
1142 struct TypeManager<subgroups::cl_half16>
1143     : public CommonTypeManager<subgroups::cl_half16>
1144 {
1145     static const char *name() { return "half16"; }
1146     static const char *add_typedef() { return "typedef half16 Type;\n"; }
1147     using scalar_type = subgroups::cl_half;
1148     typedef std::true_type is_sb_vector_type;
1149     static const bool type_supported(cl_device_id device)
1150     {
1151         return half_ok(device);
1152     }
1153 };
1154 
1155 // set scalar value to vector of halfs
1156 template <typename Ty, int N = 0>
1157 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value>::type
1158 set_value(Ty &lhs, const cl_ulong &rhs)
1159 {
1160     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1161     for (auto i = 0; i < size; ++i)
1162     {
1163         lhs.data.s[i] = rhs;
1164     }
1165 }
1166 
1167 
1168 // set scalar value to vector
1169 template <typename Ty>
1170 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1171 set_value(Ty &lhs, const cl_ulong &rhs)
1172 {
1173     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1174     for (auto i = 0; i < size; ++i)
1175     {
1176         lhs.s[i] = rhs;
1177     }
1178 }
1179 
1180 // set vector to vector value
1181 template <typename Ty>
1182 typename std::enable_if<TypeManager<Ty>::is_vector_type::value>::type
1183 set_value(Ty &lhs, const Ty &rhs)
1184 {
1185     lhs = rhs;
1186 }
1187 
1188 // set scalar value to vector size 3
1189 template <typename Ty, int N = 0>
1190 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value>::type
1191 set_value(Ty &lhs, const cl_ulong &rhs)
1192 {
1193     for (auto i = 0; i < 3; ++i)
1194     {
1195         lhs.data.s[i] = rhs;
1196     }
1197 }
1198 
1199 // set scalar value to scalar
1200 template <typename Ty>
1201 typename std::enable_if<std::is_scalar<Ty>::value>::type
1202 set_value(Ty &lhs, const cl_ulong &rhs)
1203 {
1204     lhs = static_cast<Ty>(rhs);
1205 }
1206 
1207 // set scalar value to half scalar
1208 template <typename Ty>
1209 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value>::type
1210 set_value(Ty &lhs, const cl_ulong &rhs)
1211 {
1212     lhs.data = cl_half_from_float(static_cast<cl_float>(rhs), g_rounding_mode);
1213 }
1214 
1215 // compare for common vectors
1216 template <typename Ty>
1217 typename std::enable_if<TypeManager<Ty>::is_vector_type::value, bool>::type
1218 compare(const Ty &lhs, const Ty &rhs)
1219 {
1220     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1221     for (auto i = 0; i < size; ++i)
1222     {
1223         if (lhs.s[i] != rhs.s[i])
1224         {
1225             return false;
1226         }
1227     }
1228     return true;
1229 }
1230 
1231 // compare for vectors 3
1232 template <typename Ty>
1233 typename std::enable_if<TypeManager<Ty>::is_sb_vector_size3::value, bool>::type
1234 compare(const Ty &lhs, const Ty &rhs)
1235 {
1236     for (auto i = 0; i < 3; ++i)
1237     {
1238         if (lhs.data.s[i] != rhs.data.s[i])
1239         {
1240             return false;
1241         }
1242     }
1243     return true;
1244 }
1245 
1246 // compare for half vectors
1247 template <typename Ty>
1248 typename std::enable_if<TypeManager<Ty>::is_sb_vector_type::value, bool>::type
1249 compare(const Ty &lhs, const Ty &rhs)
1250 {
1251     const int size = sizeof(Ty) / sizeof(typename TypeManager<Ty>::scalar_type);
1252     for (auto i = 0; i < size; ++i)
1253     {
1254         if (lhs.data.s[i] != rhs.data.s[i])
1255         {
1256             return false;
1257         }
1258     }
1259     return true;
1260 }
1261 
// Equality for scalar types (any Ty for which std::is_scalar is true),
// using the built-in `==`.
template <typename Ty>
typename std::enable_if<std::is_scalar<Ty>::value, bool>::type
compare(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1269 
1270 // compare for scalar halfs
1271 template <typename Ty>
1272 typename std::enable_if<TypeManager<Ty>::is_sb_scalar_type::value, bool>::type
1273 compare(const Ty &lhs, const Ty &rhs)
1274 {
1275     return lhs.data == rhs.data;
1276 }
1277 
// Generic value comparison: true when `lhs == rhs` for the given type.
template <typename Ty> inline bool compare_ordered(const Ty &lhs, const Ty &rhs)
{
    const bool same = (lhs == rhs);
    return same;
}
1282 
1283 template <>
1284 inline bool compare_ordered(const subgroups::cl_half &lhs,
1285                             const subgroups::cl_half &rhs)
1286 {
1287     return cl_half_to_float(lhs.data) == cl_half_to_float(rhs.data);
1288 }
1289 
1290 template <typename Ty>
1291 inline bool compare_ordered(const subgroups::cl_half &lhs, const int &rhs)
1292 {
1293     return cl_half_to_float(lhs.data) == rhs;
1294 }
1295 
1296 template <typename Ty, typename Fns> class KernelExecutor {
1297 public:
1298     KernelExecutor(cl_context c, cl_command_queue q, cl_kernel k, size_t g,
1299                    size_t l, Ty *id, size_t is, Ty *mid, Ty *mod, cl_int *md,
1300                    size_t ms, Ty *od, size_t os, size_t ts = 0)
1301         : context(c), queue(q), kernel(k), global(g), local(l), idata(id),
1302           isize(is), mapin_data(mid), mapout_data(mod), mdata(md), msize(ms),
1303           odata(od), osize(os), tsize(ts)
1304     {
1305         has_status = false;
1306         run_failed = false;
1307     }
1308     cl_context context;
1309     cl_command_queue queue;
1310     cl_kernel kernel;
1311     size_t global;
1312     size_t local;
1313     Ty *idata;
1314     size_t isize;
1315     Ty *mapin_data;
1316     Ty *mapout_data;
1317     cl_int *mdata;
1318     size_t msize;
1319     Ty *odata;
1320     size_t osize;
1321     size_t tsize;
1322     bool run_failed;
1323 
1324 private:
1325     bool has_status;
1326     test_status status;
1327 
1328 public:
1329     // Run a test kernel to compute the result of a built-in on an input
1330     int run()
1331     {
1332         clMemWrapper in;
1333         clMemWrapper xy;
1334         clMemWrapper out;
1335         clMemWrapper tmp;
1336         int error;
1337 
1338         in = clCreateBuffer(context, CL_MEM_READ_ONLY, isize, NULL, &error);
1339         test_error(error, "clCreateBuffer failed");
1340 
1341         xy = clCreateBuffer(context, CL_MEM_WRITE_ONLY, msize, NULL, &error);
1342         test_error(error, "clCreateBuffer failed");
1343 
1344         out = clCreateBuffer(context, CL_MEM_WRITE_ONLY, osize, NULL, &error);
1345         test_error(error, "clCreateBuffer failed");
1346 
1347         if (tsize)
1348         {
1349             tmp = clCreateBuffer(context,
1350                                  CL_MEM_READ_WRITE | CL_MEM_HOST_NO_ACCESS,
1351                                  tsize, NULL, &error);
1352             test_error(error, "clCreateBuffer failed");
1353         }
1354 
1355         error = clSetKernelArg(kernel, 0, sizeof(in), (void *)&in);
1356         test_error(error, "clSetKernelArg failed");
1357 
1358         error = clSetKernelArg(kernel, 1, sizeof(xy), (void *)&xy);
1359         test_error(error, "clSetKernelArg failed");
1360 
1361         error = clSetKernelArg(kernel, 2, sizeof(out), (void *)&out);
1362         test_error(error, "clSetKernelArg failed");
1363 
1364         if (tsize)
1365         {
1366             error = clSetKernelArg(kernel, 3, sizeof(tmp), (void *)&tmp);
1367             test_error(error, "clSetKernelArg failed");
1368         }
1369 
1370         error = clEnqueueWriteBuffer(queue, in, CL_FALSE, 0, isize, idata, 0,
1371                                      NULL, NULL);
1372         test_error(error, "clEnqueueWriteBuffer failed");
1373 
1374         error = clEnqueueWriteBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1375                                      NULL, NULL);
1376         test_error(error, "clEnqueueWriteBuffer failed");
1377         error = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global, &local,
1378                                        0, NULL, NULL);
1379         test_error(error, "clEnqueueNDRangeKernel failed");
1380 
1381         error = clEnqueueReadBuffer(queue, xy, CL_FALSE, 0, msize, mdata, 0,
1382                                     NULL, NULL);
1383         test_error(error, "clEnqueueReadBuffer failed");
1384 
1385         error = clEnqueueReadBuffer(queue, out, CL_FALSE, 0, osize, odata, 0,
1386                                     NULL, NULL);
1387         test_error(error, "clEnqueueReadBuffer failed");
1388 
1389         error = clFinish(queue);
1390         test_error(error, "clFinish failed");
1391 
1392         return error;
1393     }
1394 
1395 private:
1396     test_status
1397     run_and_check_with_cluster_size(const WorkGroupParams &test_params)
1398     {
1399         cl_int error = run();
1400         if (error != CL_SUCCESS)
1401         {
1402             print_error(error, "Failed to run subgroup test kernel");
1403             status = TEST_FAIL;
1404             run_failed = true;
1405             return status;
1406         }
1407 
1408         test_status tmp_status =
1409             Fns::chk(idata, odata, mapin_data, mapout_data, mdata, test_params);
1410 
1411         if (!has_status || tmp_status == TEST_FAIL
1412             || (tmp_status == TEST_PASS && status != TEST_FAIL))
1413         {
1414             status = tmp_status;
1415             has_status = true;
1416         }
1417 
1418         return status;
1419     }
1420 
1421 public:
1422     test_status run_and_check(WorkGroupParams &test_params)
1423     {
1424         test_status tmp_status = TEST_SKIPPED_ITSELF;
1425 
1426         if (test_params.cluster_size_arg != -1)
1427         {
1428             for (cl_uint cluster_size = 1;
1429                  cluster_size <= test_params.subgroup_size; cluster_size *= 2)
1430             {
1431                 test_params.cluster_size = cluster_size;
1432                 cl_int error =
1433                     clSetKernelArg(kernel, test_params.cluster_size_arg,
1434                                    sizeof(cl_uint), &cluster_size);
1435                 test_error_fail(error, "Unable to set cluster size");
1436 
1437                 tmp_status = run_and_check_with_cluster_size(test_params);
1438 
1439                 if (tmp_status == TEST_FAIL) break;
1440             }
1441         }
1442         else
1443         {
1444             tmp_status = run_and_check_with_cluster_size(test_params);
1445         }
1446 
1447         return tmp_status;
1448     }
1449 };
1450 
1451 // Driver for testing a single built in function
1452 template <typename Ty, typename Fns, size_t TSIZE = 0> struct test
1453 {
1454     static test_status run(cl_device_id device, cl_context context,
1455                            cl_command_queue queue, int num_elements,
1456                            const char *kname, const char *src,
1457                            WorkGroupParams test_params)
1458     {
1459         size_t tmp;
1460         cl_int error;
1461         size_t subgroup_size, num_subgroups;
1462         size_t global = test_params.global_workgroup_size;
1463         size_t local = test_params.local_workgroup_size;
1464         clProgramWrapper program;
1465         clKernelWrapper kernel;
1466         cl_platform_id platform;
1467         std::vector<cl_int> sgmap;
1468         sgmap.resize(4 * global);
1469         std::vector<Ty> mapin;
1470         mapin.resize(local);
1471         std::vector<Ty> mapout;
1472         mapout.resize(local);
1473         std::stringstream kernel_sstr;
1474 
1475         Fns::log_test(test_params, "");
1476 
1477         kernel_sstr << "#define NR_OF_ACTIVE_WORK_ITEMS ";
1478         kernel_sstr << NR_OF_ACTIVE_WORK_ITEMS << "\n";
1479         // Make sure a test of type Ty is supported by the device
1480         if (!TypeManager<Ty>::type_supported(device))
1481         {
1482             log_info("Data type not supported : %s\n", TypeManager<Ty>::name());
1483             return TEST_SKIPPED_ITSELF;
1484         }
1485 
1486         if (strstr(TypeManager<Ty>::name(), "double"))
1487         {
1488             kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp64: enable\n";
1489         }
1490         else if (strstr(TypeManager<Ty>::name(), "half"))
1491         {
1492             kernel_sstr << "#pragma OPENCL EXTENSION cl_khr_fp16: enable\n";
1493         }
1494 
1495         error = clGetDeviceInfo(device, CL_DEVICE_PLATFORM, sizeof(platform),
1496                                 (void *)&platform, NULL);
1497         test_error_fail(error, "clGetDeviceInfo failed for CL_DEVICE_PLATFORM");
1498         if (test_params.use_core_subgroups)
1499         {
1500             kernel_sstr
1501                 << "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
1502         }
1503         kernel_sstr << "#define XY(M,I) M[I].x = get_sub_group_local_id(); "
1504                        "M[I].y = get_sub_group_id();\n";
1505         kernel_sstr << TypeManager<Ty>::add_typedef();
1506         kernel_sstr << src;
1507         const std::string &kernel_str = kernel_sstr.str();
1508         const char *kernel_src = kernel_str.c_str();
1509 
1510         error = create_single_kernel_helper(context, &program, &kernel, 1,
1511                                             &kernel_src, kname);
1512         if (error != CL_SUCCESS) return TEST_FAIL;
1513 
1514         // Determine some local dimensions to use for the test.
1515         error = get_max_common_work_group_size(
1516             context, kernel, test_params.global_workgroup_size, &local);
1517         test_error_fail(error, "get_max_common_work_group_size failed");
1518 
1519         // Limit it a bit so we have muliple work groups
1520         // Ideally this will still be large enough to give us multiple
1521         if (local > test_params.local_workgroup_size)
1522             local = test_params.local_workgroup_size;
1523 
1524 
1525         // Get the sub group info
1526         subgroupsAPI subgroupsApiSet(platform, test_params.use_core_subgroups);
1527         clGetKernelSubGroupInfoKHR_fn clGetKernelSubGroupInfo_ptr =
1528             subgroupsApiSet.clGetKernelSubGroupInfo_ptr();
1529         if (clGetKernelSubGroupInfo_ptr == NULL)
1530         {
1531             log_error("ERROR: %s function not available\n",
1532                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1533             return TEST_FAIL;
1534         }
1535         error = clGetKernelSubGroupInfo_ptr(
1536             kernel, device, CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
1537             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1538         if (error != CL_SUCCESS)
1539         {
1540             log_error("ERROR: %s function error for "
1541                       "CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE\n",
1542                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1543             return TEST_FAIL;
1544         }
1545 
1546         subgroup_size = tmp;
1547 
1548         error = clGetKernelSubGroupInfo_ptr(
1549             kernel, device, CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE,
1550             sizeof(local), (void *)&local, sizeof(tmp), (void *)&tmp, NULL);
1551         if (error != CL_SUCCESS)
1552         {
1553             log_error("ERROR: %s function error for "
1554                       "CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE\n",
1555                       subgroupsApiSet.clGetKernelSubGroupInfo_name);
1556             return TEST_FAIL;
1557         }
1558 
1559         num_subgroups = tmp;
1560         // Make sure the number of sub groups is what we expect
1561         if (num_subgroups != (local + subgroup_size - 1) / subgroup_size)
1562         {
1563             log_error("ERROR: unexpected number of subgroups (%zu) returned\n",
1564                       num_subgroups);
1565             return TEST_FAIL;
1566         }
1567 
1568         std::vector<Ty> idata;
1569         std::vector<Ty> odata;
1570         size_t input_array_size = global;
1571         size_t output_array_size = global;
1572         size_t dynscl = test_params.dynsc;
1573 
1574         if (dynscl != 0)
1575         {
1576             input_array_size = global / local * num_subgroups * dynscl;
1577             output_array_size = global / local * dynscl;
1578         }
1579 
1580         idata.resize(input_array_size);
1581         odata.resize(output_array_size);
1582 
1583         if (test_params.divergence_mask_arg != -1)
1584         {
1585             cl_uint4 mask_vector;
1586             mask_vector.x = 0xffffffffU;
1587             mask_vector.y = 0xffffffffU;
1588             mask_vector.z = 0xffffffffU;
1589             mask_vector.w = 0xffffffffU;
1590             error = clSetKernelArg(kernel, test_params.divergence_mask_arg,
1591                                    sizeof(cl_uint4), &mask_vector);
1592             test_error_fail(error, "Unable to set divergence mask argument");
1593         }
1594 
1595         if (test_params.cluster_size_arg != -1)
1596         {
1597             cl_uint dummy_cluster_size = 1;
1598             error = clSetKernelArg(kernel, test_params.cluster_size_arg,
1599                                    sizeof(cl_uint), &dummy_cluster_size);
1600             test_error_fail(error, "Unable to set dummy cluster size");
1601         }
1602 
1603         KernelExecutor<Ty, Fns> executor(
1604             context, queue, kernel, global, local, idata.data(),
1605             input_array_size * sizeof(Ty), mapin.data(), mapout.data(),
1606             sgmap.data(), global * sizeof(cl_int4), odata.data(),
1607             output_array_size * sizeof(Ty), TSIZE * sizeof(Ty));
1608 
1609         // Run the kernel once on zeroes to get the map
1610         memset(idata.data(), 0, input_array_size * sizeof(Ty));
1611         error = executor.run();
1612         test_error_fail(error, "Running kernel first time failed");
1613 
1614         // Generate the desired input for the kernel
1615         test_params.subgroup_size = subgroup_size;
1616         Fns::gen(idata.data(), mapin.data(), sgmap.data(), test_params);
1617 
1618         test_status status;
1619 
1620         if (test_params.divergence_mask_arg != -1)
1621         {
1622             for (auto &mask : test_params.all_work_item_masks)
1623             {
1624                 test_params.work_items_mask = mask;
1625                 cl_uint4 mask_vector = bs128_to_cl_uint4(mask);
1626                 clSetKernelArg(kernel, test_params.divergence_mask_arg,
1627                                sizeof(cl_uint4), &mask_vector);
1628 
1629                 status = executor.run_and_check(test_params);
1630 
1631                 if (status == TEST_FAIL) break;
1632             }
1633         }
1634         else
1635         {
1636             status = executor.run_and_check(test_params);
1637         }
1638         // Detailed failure and skip messages should be logged by
1639         // run_and_check.
1640         if (status == TEST_PASS)
1641         {
1642             Fns::log_test(test_params, " passed");
1643         }
1644         else if (!executor.run_failed && status == TEST_FAIL)
1645         {
1646             test_fail("Data verification failed\n");
1647         }
1648         return status;
1649     }
1650 };
1651 
// Declaration only; the definition is not part of this header section.
// number_of_subgroups, workgroup_size and last_subgroup_size are taken by
// non-const reference, so the function is able to write values back through
// them; non_uniform_size and subgroup_size are passed by value.
// NOTE(review): judging by the name, this presumably derives the sub-group
// layout of the last (possibly partial) work-group when a non-uniform size is
// used -- confirm against the definition.
1652 void set_last_workgroup_params(int non_uniform_size, int &number_of_subgroups,
1653                                int subgroup_size, int &workgroup_size,
1654                                int &last_subgroup_size);
1655 
1656 template <typename Ty>
1657 static void set_randomdata_for_subgroup(Ty *workgroup, int wg_offset,
1658                                         int current_sbs)
1659 {
1660     int randomize_data = (int)(genrand_int32(gMTdata) % 3);
1661     // Initialize data matrix indexed by local id and sub group id
1662     switch (randomize_data)
1663     {
1664         case 0:
1665             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1666             break;
1667         case 1: {
1668             memset(&workgroup[wg_offset], 0, current_sbs * sizeof(Ty));
1669             int wi_id = (int)(genrand_int32(gMTdata) % (cl_uint)current_sbs);
1670             set_value(workgroup[wg_offset + wi_id], 41);
1671         }
1672         break;
1673         case 2:
1674             memset(&workgroup[wg_offset], 0xff, current_sbs * sizeof(Ty));
1675             break;
1676     }
1677 }
1678 
1679 struct RunTestForType
1680 {
1681     RunTestForType(cl_device_id device, cl_context context,
1682                    cl_command_queue queue, int num_elements,
1683                    WorkGroupParams test_params)
1684         : device_(device), context_(context), queue_(queue),
1685           num_elements_(num_elements), test_params_(test_params)
1686     {}
1687     template <typename T, typename U>
1688     int run_impl(const std::string &function_name)
1689     {
1690         int error = TEST_PASS;
1691         std::string source =
1692             std::regex_replace(test_params_.get_kernel_source(function_name),
1693                                std::regex("\\%s"), function_name);
1694         std::string kernel_name = "test_" + function_name;
1695         error =
1696             test<T, U>::run(device_, context_, queue_, num_elements_,
1697                             kernel_name.c_str(), source.c_str(), test_params_);
1698 
1699         // If we return TEST_SKIPPED_ITSELF here, then an entire suite may be
1700         // reported as having been skipped even if some tests within it
1701         // passed, as the status codes are erroneously ORed together:
1702         return error == TEST_FAIL ? TEST_FAIL : TEST_PASS;
1703     }
1704 
1705 private:
1706     cl_device_id device_;
1707     cl_context context_;
1708     cl_command_queue queue_;
1709     int num_elements_;
1710     WorkGroupParams test_params_;
1711 };
1712 
1713 #endif
1714