xref: /aosp_15_r20/external/OpenCL-CTS/test_conformance/subgroups/test_subgroup_ballot.cpp (revision 6467f958c7de8070b317fc65bcb0f6472e388d82)
1 //
2 // Copyright (c) 2021 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "subgroup_common_templates.h"
19 #include "harness/typeWrappers.h"
20 #include <bitset>
21 
22 namespace {
23 // Test for ballot functions
24 template <typename Ty> struct BALLOT
25 {
log_test__anoncae851e40111::BALLOT26     static void log_test(const WorkGroupParams &test_params,
27                          const char *extra_text)
28     {
29         log_info("  sub_group_ballot...%s\n", extra_text);
30     }
31 
gen__anoncae851e40111::BALLOT32     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
33     {
34         int gws = test_params.global_workgroup_size;
35         int lws = test_params.local_workgroup_size;
36         int sbs = test_params.subgroup_size;
37         int sb_number = (lws + sbs - 1) / sbs;
38         int non_uniform_size = gws % lws;
39         int wg_number = gws / lws;
40         wg_number = non_uniform_size ? wg_number + 1 : wg_number;
41         int last_subgroup_size = 0;
42 
43         for (int wg_id = 0; wg_id < wg_number; ++wg_id)
44         { // for each work_group
45             if (non_uniform_size && wg_id == wg_number - 1)
46             {
47                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
48                                           last_subgroup_size);
49             }
50             for (int sb_id = 0; sb_id < sb_number; ++sb_id)
51             { // for each subgroup
52                 int wg_offset = sb_id * sbs;
53                 int current_sbs;
54                 if (last_subgroup_size && sb_id == sb_number - 1)
55                 {
56                     current_sbs = last_subgroup_size;
57                 }
58                 else
59                 {
60                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
61                 }
62 
63                 for (int wi_id = 0; wi_id < current_sbs; wi_id++)
64                 {
65                     cl_uint v;
66                     if (genrand_bool(gMTdata))
67                     {
68                         v = genrand_bool(gMTdata);
69                     }
70                     else if (genrand_bool(gMTdata))
71                     {
72                         v = 1U << ((genrand_int32(gMTdata) % 31) + 1);
73                     }
74                     else
75                     {
76                         v = genrand_int32(gMTdata);
77                     }
78                     cl_uint4 v4 = { { v, 0, 0, 0 } };
79                     t[wi_id + wg_offset] = v4;
80                 }
81             }
82             // Now map into work group using map from device
83             for (int wi_id = 0; wi_id < lws; ++wi_id)
84             {
85                 x[wi_id] = t[wi_id];
86             }
87             x += lws;
88             m += 4 * lws;
89         }
90     }
91 
chk__anoncae851e40111::BALLOT92     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
93                            const WorkGroupParams &test_params)
94     {
95         int gws = test_params.global_workgroup_size;
96         int lws = test_params.local_workgroup_size;
97         int sbs = test_params.subgroup_size;
98         int sb_number = (lws + sbs - 1) / sbs;
99         int non_uniform_size = gws % lws;
100         int wg_number = gws / lws;
101         wg_number = non_uniform_size ? wg_number + 1 : wg_number;
102         int last_subgroup_size = 0;
103 
104         for (int wg_id = 0; wg_id < wg_number; ++wg_id)
105         { // for each work_group
106             if (non_uniform_size && wg_id == wg_number - 1)
107             {
108                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
109                                           last_subgroup_size);
110             }
111             for (int wi_id = 0; wi_id < lws; ++wi_id)
112             { // inside the work_group
113                 mx[wi_id] = x[wi_id]; // read host inputs for work_group
114                 my[wi_id] = y[wi_id]; // read device outputs for work_group
115             }
116 
117             for (int sb_id = 0; sb_id < sb_number; ++sb_id)
118             { // for each subgroup
119                 int wg_offset = sb_id * sbs;
120                 int current_sbs;
121                 if (last_subgroup_size && sb_id == sb_number - 1)
122                 {
123                     current_sbs = last_subgroup_size;
124                 }
125                 else
126                 {
127                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
128                 }
129 
130                 bs128 expected_result_bs = 0;
131 
132                 std::set<int> active_work_items;
133                 for (int wi_id = 0; wi_id < current_sbs; ++wi_id)
134                 {
135                     if (test_params.work_items_mask.test(wi_id))
136                     {
137                         bool predicate = (mx[wg_offset + wi_id].s0 != 0);
138                         expected_result_bs |= (bs128(predicate) << wi_id);
139                         active_work_items.insert(wi_id);
140                     }
141                 }
142                 if (active_work_items.empty())
143                 {
144                     continue;
145                 }
146 
147                 cl_uint4 expected_result =
148                     bs128_to_cl_uint4(expected_result_bs);
149                 for (const int &active_work_item : active_work_items)
150                 {
151                     int wi_id = active_work_item;
152 
153                     cl_uint4 device_result = my[wg_offset + wi_id];
154                     bs128 device_result_bs = cl_uint4_to_bs128(device_result);
155 
156                     if (device_result_bs != expected_result_bs)
157                     {
158                         log_error(
159                             "ERROR: sub_group_ballot mismatch for local id "
160                             "%d in sub group %d in group %d obtained {%d, %d, "
161                             "%d, %d}, expected {%d, %d, %d, %d}\n",
162                             wi_id, sb_id, wg_id, device_result.s0,
163                             device_result.s1, device_result.s2,
164                             device_result.s3, expected_result.s0,
165                             expected_result.s1, expected_result.s2,
166                             expected_result.s3);
167                         return TEST_FAIL;
168                     }
169                 }
170             }
171 
172             x += lws;
173             y += lws;
174             m += 4 * lws;
175         }
176 
177         return TEST_PASS;
178     }
179 };
180 
181 // Test for bit extract ballot functions
182 template <typename Ty, BallotOp operation> struct BALLOT_BIT_EXTRACT
183 {
log_test__anoncae851e40111::BALLOT_BIT_EXTRACT184     static void log_test(const WorkGroupParams &test_params,
185                          const char *extra_text)
186     {
187         log_info("  sub_group_ballot_%s(%s)...%s\n", operation_names(operation),
188                  TypeManager<Ty>::name(), extra_text);
189     }
190 
gen__anoncae851e40111::BALLOT_BIT_EXTRACT191     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
192     {
193         int wi_id, sb_id, wg_id;
194         int gws = test_params.global_workgroup_size;
195         int lws = test_params.local_workgroup_size;
196         int sbs = test_params.subgroup_size;
197         int sb_number = (lws + sbs - 1) / sbs;
198         int wg_number = gws / lws;
199         int limit_sbs = sbs > 100 ? 100 : sbs;
200 
201         for (wg_id = 0; wg_id < wg_number; ++wg_id)
202         { // for each work_group
203             for (sb_id = 0; sb_id < sb_number; ++sb_id)
204             { // for each subgroup
205                 int wg_offset = sb_id * sbs;
206                 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
207                 // rand index to bit extract
208                 int index_for_odd = (int)(genrand_int32(gMTdata) & 0x7fffffff)
209                     % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
210                 int index_for_even = (int)(genrand_int32(gMTdata) & 0x7fffffff)
211                     % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
212                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
213                 {
214                     // index of the third element int the vector.
215                     int midx = 4 * wg_offset + 4 * wi_id + 2;
216                     // storing information about index to bit extract
217                     m[midx] = (cl_int)index_for_odd;
218                     m[++midx] = (cl_int)index_for_even;
219                 }
220                 set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
221             }
222 
223             // Now map into work group using map from device
224             for (wi_id = 0; wi_id < lws; ++wi_id)
225             {
226                 x[wi_id] = t[wi_id];
227             }
228 
229             x += lws;
230             m += 4 * lws;
231         }
232     }
233 
chk__anoncae851e40111::BALLOT_BIT_EXTRACT234     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
235                            const WorkGroupParams &test_params)
236     {
237         int wi_id, wg_id, sb_id;
238         int gws = test_params.global_workgroup_size;
239         int lws = test_params.local_workgroup_size;
240         int sbs = test_params.subgroup_size;
241         int sb_number = (lws + sbs - 1) / sbs;
242         int wg_number = gws / lws;
243         cl_uint4 expected_result, device_result;
244         int last_subgroup_size = 0;
245         int current_sbs = 0;
246         int non_uniform_size = gws % lws;
247 
248         for (wg_id = 0; wg_id < wg_number; ++wg_id)
249         { // for each work_group
250             if (non_uniform_size && wg_id == wg_number - 1)
251             {
252                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
253                                           last_subgroup_size);
254             }
255             // Map to array indexed to array indexed by local ID and sub group
256             for (wi_id = 0; wi_id < lws; ++wi_id)
257             { // inside the work_group
258                 // read host inputs for work_group
259                 mx[wi_id] = x[wi_id];
260                 // read device outputs for work_group
261                 my[wi_id] = y[wi_id];
262             }
263 
264             for (sb_id = 0; sb_id < sb_number; ++sb_id)
265             { // for each subgroup
266                 int wg_offset = sb_id * sbs;
267                 if (last_subgroup_size && sb_id == sb_number - 1)
268                 {
269                     current_sbs = last_subgroup_size;
270                 }
271                 else
272                 {
273                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
274                 }
275                 // take index of array where info which work_item will
276                 // be broadcast its value is stored
277                 int midx = 4 * wg_offset + 2;
278                 // take subgroup local id of this work_item
279                 int index_for_odd = (int)m[midx];
280                 int index_for_even = (int)m[++midx];
281 
282                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
283                 { // for each subgroup
284                     int bit_value = 0;
285                     // from which value of bitfield bit
286                     // verification will be done
287                     int take_shift =
288                         (wi_id & 1) ? index_for_odd % 32 : index_for_even % 32;
289                     int bit_mask = 1 << take_shift;
290 
291                     if (wi_id < 32)
292                         (mx[wg_offset + wi_id].s0 & bit_mask) > 0
293                             ? bit_value = 1
294                             : bit_value = 0;
295                     if (wi_id >= 32 && wi_id < 64)
296                         (mx[wg_offset + wi_id].s1 & bit_mask) > 0
297                             ? bit_value = 1
298                             : bit_value = 0;
299                     if (wi_id >= 64 && wi_id < 96)
300                         (mx[wg_offset + wi_id].s2 & bit_mask) > 0
301                             ? bit_value = 1
302                             : bit_value = 0;
303                     if (wi_id >= 96 && wi_id < 128)
304                         (mx[wg_offset + wi_id].s3 & bit_mask) > 0
305                             ? bit_value = 1
306                             : bit_value = 0;
307 
308                     if (wi_id & 1)
309                     {
310                         bit_value ? expected_result = { { 1, 0, 0, 1 } }
311                                   : expected_result = { { 0, 0, 0, 1 } };
312                     }
313                     else
314                     {
315                         bit_value ? expected_result = { { 1, 0, 0, 2 } }
316                                   : expected_result = { { 0, 0, 0, 2 } };
317                     }
318 
319                     device_result = my[wg_offset + wi_id];
320                     if (!compare(device_result, expected_result))
321                     {
322                         log_error(
323                             "ERROR: sub_group_%s mismatch for local id %d in "
324                             "sub group %d in group %d obtained {%d, %d, %d, "
325                             "%d}, expected {%d, %d, %d, %d}\n",
326                             operation_names(operation), wi_id, sb_id, wg_id,
327                             device_result.s0, device_result.s1,
328                             device_result.s2, device_result.s3,
329                             expected_result.s0, expected_result.s1,
330                             expected_result.s2, expected_result.s3);
331                         return TEST_FAIL;
332                     }
333                 }
334             }
335             x += lws;
336             y += lws;
337             m += 4 * lws;
338         }
339         return TEST_PASS;
340     }
341 };
342 
343 template <typename Ty, BallotOp operation> struct BALLOT_INVERSE
344 {
log_test__anoncae851e40111::BALLOT_INVERSE345     static void log_test(const WorkGroupParams &test_params,
346                          const char *extra_text)
347     {
348         log_info("  sub_group_inverse_ballot...%s\n", extra_text);
349     }
350 
gen__anoncae851e40111::BALLOT_INVERSE351     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
352     {
353         // no work here
354     }
355 
chk__anoncae851e40111::BALLOT_INVERSE356     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
357                            const WorkGroupParams &test_params)
358     {
359         int wi_id, wg_id, sb_id;
360         int gws = test_params.global_workgroup_size;
361         int lws = test_params.local_workgroup_size;
362         int sbs = test_params.subgroup_size;
363         int sb_number = (lws + sbs - 1) / sbs;
364         cl_uint4 expected_result, device_result;
365         int non_uniform_size = gws % lws;
366         int wg_number = gws / lws;
367         int last_subgroup_size = 0;
368         int current_sbs = 0;
369         if (non_uniform_size) wg_number++;
370 
371         for (wg_id = 0; wg_id < wg_number; ++wg_id)
372         { // for each work_group
373             if (non_uniform_size && wg_id == wg_number - 1)
374             {
375                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
376                                           last_subgroup_size);
377             }
378             // Map to array indexed to array indexed by local ID and sub group
379             for (wi_id = 0; wi_id < lws; ++wi_id)
380             { // inside the work_group
381                 mx[wi_id] = x[wi_id]; // read host inputs for work_group
382                 my[wi_id] = y[wi_id]; // read device outputs for work_group
383             }
384 
385             for (sb_id = 0; sb_id < sb_number; ++sb_id)
386             { // for each subgroup
387                 int wg_offset = sb_id * sbs;
388                 if (last_subgroup_size && sb_id == sb_number - 1)
389                 {
390                     current_sbs = last_subgroup_size;
391                 }
392                 else
393                 {
394                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
395                 }
396                 // take subgroup local id of this work_item
397                 // Check result
398                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
399                 { // for each subgroup work item
400 
401                     wi_id & 1 ? expected_result = { { 1, 0, 0, 1 } }
402                               : expected_result = { { 1, 0, 0, 2 } };
403 
404                     device_result = my[wg_offset + wi_id];
405                     if (!compare(device_result, expected_result))
406                     {
407                         log_error(
408                             "ERROR: sub_group_%s mismatch for local id %d in "
409                             "sub group %d in group %d obtained {%d, %d, %d, "
410                             "%d}, expected {%d, %d, %d, %d}\n",
411                             operation_names(operation), wi_id, sb_id, wg_id,
412                             device_result.s0, device_result.s1,
413                             device_result.s2, device_result.s3,
414                             expected_result.s0, expected_result.s1,
415                             expected_result.s2, expected_result.s3);
416                         return TEST_FAIL;
417                     }
418                 }
419             }
420             x += lws;
421             y += lws;
422             m += 4 * lws;
423         }
424 
425         return TEST_PASS;
426     }
427 };
428 
429 
430 // Test for bit count/inclusive and exclusive scan/ find lsb msb ballot function
431 template <typename Ty, BallotOp operation> struct BALLOT_COUNT_SCAN_FIND
432 {
log_test__anoncae851e40111::BALLOT_COUNT_SCAN_FIND433     static void log_test(const WorkGroupParams &test_params,
434                          const char *extra_text)
435     {
436         log_info("  sub_group_%s(%s)...%s\n", operation_names(operation),
437                  TypeManager<Ty>::name(), extra_text);
438     }
439 
gen__anoncae851e40111::BALLOT_COUNT_SCAN_FIND440     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
441     {
442         int wi_id, wg_id, sb_id;
443         int gws = test_params.global_workgroup_size;
444         int lws = test_params.local_workgroup_size;
445         int sbs = test_params.subgroup_size;
446         int sb_number = (lws + sbs - 1) / sbs;
447         int non_uniform_size = gws % lws;
448         int wg_number = gws / lws;
449         int last_subgroup_size = 0;
450         int current_sbs = 0;
451 
452         if (non_uniform_size)
453         {
454             wg_number++;
455         }
456         for (wg_id = 0; wg_id < wg_number; ++wg_id)
457         { // for each work_group
458             if (non_uniform_size && wg_id == wg_number - 1)
459             {
460                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
461                                           last_subgroup_size);
462             }
463             for (sb_id = 0; sb_id < sb_number; ++sb_id)
464             { // for each subgroup
465                 int wg_offset = sb_id * sbs;
466                 if (last_subgroup_size && sb_id == sb_number - 1)
467                 {
468                     current_sbs = last_subgroup_size;
469                 }
470                 else
471                 {
472                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
473                 }
474                 if (operation == BallotOp::ballot_bit_count
475                     || operation == BallotOp::ballot_inclusive_scan
476                     || operation == BallotOp::ballot_exclusive_scan)
477                 {
478                     set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
479                 }
480                 else if (operation == BallotOp::ballot_find_lsb
481                          || operation == BallotOp::ballot_find_msb)
482                 {
483                     // Regarding to the spec, find lsb and find msb result is
484                     // undefined behavior if input value is zero, so generate
485                     // only non-zero values.
486                     for (wi_id = 0; wi_id < current_sbs; ++wi_id)
487                     {
488                         char x = (genrand_int32(gMTdata)) & 0xff;
489                         // undefined behaviour in case of 0;
490                         x = x ? x : 1;
491                         memset(&t[wg_offset + wi_id], x, sizeof(Ty));
492                     }
493                 }
494                 else
495                 {
496                     log_error("Unknown operation...\n");
497                 }
498             }
499 
500             // Now map into work group using map from device
501             for (wi_id = 0; wi_id < lws; ++wi_id)
502             {
503                 x[wi_id] = t[wi_id];
504             }
505 
506             x += lws;
507             m += 4 * lws;
508         }
509     }
510 
getImportantBits__anoncae851e40111::BALLOT_COUNT_SCAN_FIND511     static bs128 getImportantBits(cl_uint sub_group_local_id,
512                                   cl_uint sub_group_size)
513     {
514         bs128 mask;
515         if (operation == BallotOp::ballot_bit_count
516             || operation == BallotOp::ballot_find_lsb
517             || operation == BallotOp::ballot_find_msb)
518         {
519             for (cl_uint i = 0; i < sub_group_size; ++i) mask.set(i);
520         }
521         else if (operation == BallotOp::ballot_inclusive_scan
522                  || operation == BallotOp::ballot_exclusive_scan)
523         {
524             for (cl_uint i = 0; i < sub_group_local_id; ++i) mask.set(i);
525             if (operation == BallotOp::ballot_inclusive_scan)
526                 mask.set(sub_group_local_id);
527         }
528         return mask;
529     }
530 
chk__anoncae851e40111::BALLOT_COUNT_SCAN_FIND531     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
532                            const WorkGroupParams &test_params)
533     {
534         int wi_id, wg_id, sb_id;
535         int gws = test_params.global_workgroup_size;
536         int lws = test_params.local_workgroup_size;
537         int sbs = test_params.subgroup_size;
538         int sb_number = (lws + sbs - 1) / sbs;
539         int non_uniform_size = gws % lws;
540         int wg_number = gws / lws;
541         wg_number = non_uniform_size ? wg_number + 1 : wg_number;
542         cl_uint expected_result, device_result;
543         int last_subgroup_size = 0;
544         int current_sbs = 0;
545 
546         for (wg_id = 0; wg_id < wg_number; ++wg_id)
547         { // for each work_group
548             if (non_uniform_size && wg_id == wg_number - 1)
549             {
550                 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
551                                           last_subgroup_size);
552             }
553             // Map to array indexed to array indexed by local ID and sub group
554             for (wi_id = 0; wi_id < lws; ++wi_id)
555             { // inside the work_group
556                 // read host inputs for work_group
557                 mx[wi_id] = x[wi_id];
558                 // read device outputs for work_group
559                 my[wi_id] = y[wi_id];
560             }
561 
562             for (sb_id = 0; sb_id < sb_number; ++sb_id)
563             { // for each subgroup
564                 int wg_offset = sb_id * sbs;
565                 if (last_subgroup_size && sb_id == sb_number - 1)
566                 {
567                     current_sbs = last_subgroup_size;
568                 }
569                 else
570                 {
571                     current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
572                 }
573                 // Check result
574                 expected_result = 0;
575                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
576                 { // for subgroup element
577                     bs128 bs;
578                     // convert cl_uint4 input into std::bitset<128>
579                     bs |= bs128(mx[wg_offset + wi_id].s0)
580                         | (bs128(mx[wg_offset + wi_id].s1) << 32)
581                         | (bs128(mx[wg_offset + wi_id].s2) << 64)
582                         | (bs128(mx[wg_offset + wi_id].s3) << 96);
583                     bs &= getImportantBits(wi_id, sbs);
584                     device_result = my[wg_offset + wi_id].s0;
585                     if (operation == BallotOp::ballot_inclusive_scan
586                         || operation == BallotOp::ballot_exclusive_scan
587                         || operation == BallotOp::ballot_bit_count)
588                     {
589                         expected_result = bs.count();
590                         if (!compare(device_result, expected_result))
591                         {
592                             log_error("ERROR: sub_group_%s "
593                                       "mismatch for local id %d in sub group "
594                                       "%d in group %d obtained %d, "
595                                       "expected %d\n",
596                                       operation_names(operation), wi_id, sb_id,
597                                       wg_id, device_result, expected_result);
598                             return TEST_FAIL;
599                         }
600                     }
601                     else if (operation == BallotOp::ballot_find_lsb)
602                     {
603                         if (bs.none())
604                         {
605                             // Return value is undefined when no bits are set,
606                             // so skip validation:
607                             continue;
608                         }
609                         for (int id = 0; id < sbs; ++id)
610                         {
611                             if (bs.test(id))
612                             {
613                                 expected_result = id;
614                                 break;
615                             }
616                         }
617                         if (!compare(device_result, expected_result))
618                         {
619                             log_error("ERROR: sub_group_ballot_find_lsb "
620                                       "mismatch for local id %d in sub group "
621                                       "%d in group %d obtained %d, "
622                                       "expected %d\n",
623                                       wi_id, sb_id, wg_id, device_result,
624                                       expected_result);
625                             return TEST_FAIL;
626                         }
627                     }
628                     else if (operation == BallotOp::ballot_find_msb)
629                     {
630                         if (bs.none())
631                         {
632                             // Return value is undefined when no bits are set,
633                             // so skip validation:
634                             continue;
635                         }
636                         for (int id = sbs - 1; id >= 0; --id)
637                         {
638                             if (bs.test(id))
639                             {
640                                 expected_result = id;
641                                 break;
642                             }
643                         }
644                         if (!compare(device_result, expected_result))
645                         {
646                             log_error("ERROR: sub_group_ballot_find_msb "
647                                       "mismatch for local id %d in sub group "
648                                       "%d in group %d obtained %d, "
649                                       "expected %d\n",
650                                       wi_id, sb_id, wg_id, device_result,
651                                       expected_result);
652                             return TEST_FAIL;
653                         }
654                     }
655                 }
656             }
657             x += lws;
658             y += lws;
659             m += 4 * lws;
660         }
661         return TEST_PASS;
662     }
663 };
664 
665 // test mask functions
666 template <typename Ty, BallotOp operation> struct SMASK
667 {
log_test__anoncae851e40111::SMASK668     static void log_test(const WorkGroupParams &test_params,
669                          const char *extra_text)
670     {
671         log_info("  get_sub_group_%s_mask...%s\n", operation_names(operation),
672                  extra_text);
673     }
674 
gen__anoncae851e40111::SMASK675     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
676     {
677         int wi_id, wg_id, sb_id;
678         int gws = test_params.global_workgroup_size;
679         int lws = test_params.local_workgroup_size;
680         int sbs = test_params.subgroup_size;
681         int sb_number = (lws + sbs - 1) / sbs;
682         int wg_number = gws / lws;
683         for (wg_id = 0; wg_id < wg_number; ++wg_id)
684         { // for each work_group
685             for (sb_id = 0; sb_id < sb_number; ++sb_id)
686             { // for each subgroup
687                 int wg_offset = sb_id * sbs;
688                 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
689                 // Produce expected masks for each work item in the subgroup
690                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
691                 {
692                     int midx = 4 * wg_offset + 4 * wi_id;
693                     cl_uint max_sub_group_size = m[midx + 2];
694                     cl_uint4 expected_mask = { { 0 } };
695                     expected_mask = generate_bit_mask(
696                         wi_id, operation_names(operation), max_sub_group_size);
697                     set_value(t[wg_offset + wi_id], expected_mask);
698                 }
699             }
700 
701             // Now map into work group using map from device
702             for (wi_id = 0; wi_id < lws; ++wi_id)
703             {
704                 x[wi_id] = t[wi_id];
705             }
706             x += lws;
707             m += 4 * lws;
708         }
709     }
710 
chk__anoncae851e40111::SMASK711     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
712                            const WorkGroupParams &test_params)
713     {
714         int wi_id, wg_id, sb_id;
715         int gws = test_params.global_workgroup_size;
716         int lws = test_params.local_workgroup_size;
717         int sbs = test_params.subgroup_size;
718         int sb_number = (lws + sbs - 1) / sbs;
719         Ty expected_result, device_result;
720         int wg_number = gws / lws;
721 
722         for (wg_id = 0; wg_id < wg_number; ++wg_id)
723         { // for each work_group
724             for (wi_id = 0; wi_id < lws; ++wi_id)
725             { // inside the work_group
726                 mx[wi_id] = x[wi_id]; // read host inputs for work_group
727                 my[wi_id] = y[wi_id]; // read device outputs for work_group
728             }
729 
730             for (sb_id = 0; sb_id < sb_number; ++sb_id)
731             {
732                 int wg_offset = sb_id * sbs;
733                 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
734 
735                 // Check result
736                 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
737                 { // inside the subgroup
738                     expected_result =
739                         mx[wg_offset + wi_id]; // read host input for subgroup
740                     device_result =
741                         my[wg_offset
742                            + wi_id]; // read device outputs for subgroup
743                     if (!compare(device_result, expected_result))
744                     {
745                         log_error("ERROR:  get_sub_group_%s_mask... mismatch "
746                                   "for local id %d in sub group %d in group "
747                                   "%d, %s\n",
748                                   operation_names(operation), wi_id, sb_id,
749                                   wg_id,
750                                   print_expected_obtained(expected_result,
751                                                           device_result)
752                                       .c_str());
753                         return TEST_FAIL;
754                     }
755                 }
756             }
757             x += lws;
758             y += lws;
759             m += 4 * lws;
760         }
761         return TEST_PASS;
762     }
763 };
764 
// OpenCL C source for the sub_group_non_uniform_broadcast test. Each work item
// broadcasts its input value from the sub-group local id held in xy[gid].z
// when xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS, otherwise from xy[gid].w, and
// stores the result in out[gid]. Type, XY and NR_OF_ACTIVE_WORK_ITEMS are not
// defined in this string; presumably the harness injects them when building
// the kernel (TODO confirm).
std::string sub_group_non_uniform_broadcast_source = R"(
__kernel void test_sub_group_non_uniform_broadcast(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    Type x = in[gid];
    if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
        out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);
    } else {
        out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);
    }
}
)";
// OpenCL C source for the sub_group_broadcast_first test. Both branches of the
// xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS split make the same call and write the
// result to out[gid]. The split presumably exists so the builtin is invoked
// from divergent control flow (TODO confirm). The doubled ';;' is just an
// empty statement.
std::string sub_group_broadcast_first_source = R"(
__kernel void test_sub_group_broadcast_first(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    Type x = in[gid];
    if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
        out[gid] = sub_group_broadcast_first(x);;
    } else {
        out[gid] = sub_group_broadcast_first(x);;
    }
}
)";
// OpenCL C source template for the ballot bit-count, inclusive/exclusive scan
// and find-lsb/msb tests. Both %s placeholders presumably receive the tested
// builtin's name (e.g. sub_group_ballot_bit_count), yielding a kernel named
// test_<name> (TODO confirm). The builtin's scalar result is written to the
// .x component of a uint4; the other components are 0.
std::string sub_group_ballot_bit_scan_find_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    Type x = in[gid];
    uint4 value = (uint4)(0,0,0,0);
    value = (uint4)(%s(x),0,0,0);
    out[gid] = value;
}
)";
// OpenCL C source template for the get_sub_group_*_mask tests. %s presumably
// receives the query name (e.g. get_sub_group_eq_mask) (TODO confirm). Besides
// writing the mask to out[gid], the kernel overwrites xy[gid].z (after XY has
// run) with get_max_sub_group_size(). That value is presumably what the
// host-side mask generator reads back from the map buffer (TODO confirm).
std::string sub_group_ballot_mask_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    xy[gid].z = get_max_sub_group_size();
    Type x = in[gid];
    uint4 mask = %s();
    out[gid] = mask;
}
)";
// OpenCL C source for the sub_group_ballot test. Each work item picks the
// 32-bit word of the uint4 kernel argument work_item_mask_vector that covers
// its sub-group local id. Only work items whose bit is set in that word call
// sub_group_ballot(in[gid].s0); every other work item writes (0,0,0,0).
// work_item_mask is zero-initialised so that a local id >= 128 (covered by no
// branch) reads a defined value and skips the ballot instead of reading an
// uninitialised variable. The election bit uses an unsigned shift (1u) so that
// bit 31 is produced without shifting into the sign bit of a signed int.
std::string sub_group_ballot_source = R"(
__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
    uint gid = get_global_id(0);
    XY(xy,gid);
    uint subgroup_local_id = get_sub_group_local_id();
    uint elect_work_item = 1u << (subgroup_local_id % 32);
    uint work_item_mask = 0;
    if (subgroup_local_id < 32) {
        work_item_mask = work_item_mask_vector.x;
    } else if(subgroup_local_id < 64) {
        work_item_mask = work_item_mask_vector.y;
    } else if(subgroup_local_id < 96) {
        work_item_mask = work_item_mask_vector.z;
    } else if(subgroup_local_id < 128) {
        work_item_mask = work_item_mask_vector.w;
    }
    uint4 value = (uint4)(0, 0, 0, 0);
    if (elect_work_item & work_item_mask) {
        value = sub_group_ballot(in[gid].s0);
    }
    out[gid] = value;
}
)";
// OpenCL C source for the sub_group_inverse_ballot test. Work items with an
// odd sub-group local id query a mask with every odd bit set (0xAAAAAAAA per
// component). Even ids query a mask with every even bit set (0x55555555). Each
// work item therefore asks about a mask that contains its own bit.
// out[gid].x holds the inverse-ballot result (1 or 0), and .w records which
// branch ran (1 = odd id, 2 = even id). The initial (10,0,0,0) is always
// overwritten before the store.
std::string sub_group_inverse_ballot_source = R"(
__kernel void test_sub_group_inverse_ballot(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    Type x = in[gid];
    uint4 value = (uint4)(10,0,0,0);
    if (get_sub_group_local_id() & 1) {
        uint4 partial_ballot_mask = (uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);
        if (sub_group_inverse_ballot(partial_ballot_mask)) {
            value = (uint4)(1,0,0,1);
        } else {
            value = (uint4)(0,0,0,1);
        }
    } else {
        uint4 partial_ballot_mask = (uint4)(0x55555555,0x55555555,0x55555555,0x55555555);
        if (sub_group_inverse_ballot(partial_ballot_mask)) {
            value = (uint4)(1,0,0,2);
        } else {
            value = (uint4)(0,0,0,2);
        }
    }
    out[gid] = value;
}
)";
// OpenCL C source for the sub_group_ballot_bit_extract test. Work items with
// an odd sub-group local id extract the bit of x at index xy[gid].z; even ids
// extract the bit at index xy[gid].w. out[gid].x holds the result (1 or 0) and
// .w the branch taken (1 = odd id, 2 = even id).
// NOTE(review): the local `index` is never used, and the source text has a
// leading space before __kernel; both are harmless.
std::string sub_group_ballot_bit_extract_source = R"(
 __kernel void test_sub_group_ballot_bit_extract(const __global Type *in, __global int4 *xy, __global Type *out) {
    int gid = get_global_id(0);
    XY(xy,gid);
    Type x = in[gid];
    uint index = xy[gid].z;
    uint4 value = (uint4)(10,0,0,0);
    if (get_sub_group_local_id() & 1) {
        if (sub_group_ballot_bit_extract(x, xy[gid].z)) {
            value = (uint4)(1,0,0,1);
        } else {
            value = (uint4)(0,0,0,1);
        }
    } else {
        if (sub_group_ballot_bit_extract(x, xy[gid].w)) {
            value = (uint4)(1,0,0,2);
        } else {
            value = (uint4)(0,0,0,2);
        }
    }
    out[gid] = value;
}
)";
879 
run_non_uniform_broadcast_for_type(RunTestForType rft)880 template <typename T> int run_non_uniform_broadcast_for_type(RunTestForType rft)
881 {
882     int error =
883         rft.run_impl<T, BC<T, SubgroupsBroadcastOp::non_uniform_broadcast>>(
884             "sub_group_non_uniform_broadcast");
885     return error;
886 }
887 
888 
889 }
890 
test_subgroup_functions_ballot(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)891 int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
892                                    cl_command_queue queue, int num_elements)
893 {
894     if (!is_extension_available(device, "cl_khr_subgroup_ballot"))
895     {
896         log_info("cl_khr_subgroup_ballot is not supported on this device, "
897                  "skipping test.\n");
898         return TEST_SKIPPED_ITSELF;
899     }
900 
901     constexpr size_t global_work_size = 170;
902     constexpr size_t local_work_size = 64;
903     WorkGroupParams test_params(global_work_size, local_work_size);
904     test_params.save_kernel_source(sub_group_ballot_mask_source);
905     test_params.save_kernel_source(sub_group_non_uniform_broadcast_source,
906                                    "sub_group_non_uniform_broadcast");
907     test_params.save_kernel_source(sub_group_broadcast_first_source,
908                                    "sub_group_broadcast_first");
909     RunTestForType rft(device, context, queue, num_elements, test_params);
910 
911     // non uniform broadcast functions
912     int error = run_non_uniform_broadcast_for_type<cl_int>(rft);
913     error |= run_non_uniform_broadcast_for_type<cl_int2>(rft);
914     error |= run_non_uniform_broadcast_for_type<subgroups::cl_int3>(rft);
915     error |= run_non_uniform_broadcast_for_type<cl_int4>(rft);
916     error |= run_non_uniform_broadcast_for_type<cl_int8>(rft);
917     error |= run_non_uniform_broadcast_for_type<cl_int16>(rft);
918 
919     error |= run_non_uniform_broadcast_for_type<cl_uint>(rft);
920     error |= run_non_uniform_broadcast_for_type<cl_uint2>(rft);
921     error |= run_non_uniform_broadcast_for_type<subgroups::cl_uint3>(rft);
922     error |= run_non_uniform_broadcast_for_type<cl_uint4>(rft);
923     error |= run_non_uniform_broadcast_for_type<cl_uint8>(rft);
924     error |= run_non_uniform_broadcast_for_type<cl_uint16>(rft);
925 
926     error |= run_non_uniform_broadcast_for_type<cl_char>(rft);
927     error |= run_non_uniform_broadcast_for_type<cl_char2>(rft);
928     error |= run_non_uniform_broadcast_for_type<subgroups::cl_char3>(rft);
929     error |= run_non_uniform_broadcast_for_type<cl_char4>(rft);
930     error |= run_non_uniform_broadcast_for_type<cl_char8>(rft);
931     error |= run_non_uniform_broadcast_for_type<cl_char16>(rft);
932 
933     error |= run_non_uniform_broadcast_for_type<cl_uchar>(rft);
934     error |= run_non_uniform_broadcast_for_type<cl_uchar2>(rft);
935     error |= run_non_uniform_broadcast_for_type<subgroups::cl_uchar3>(rft);
936     error |= run_non_uniform_broadcast_for_type<cl_uchar4>(rft);
937     error |= run_non_uniform_broadcast_for_type<cl_uchar8>(rft);
938     error |= run_non_uniform_broadcast_for_type<cl_uchar16>(rft);
939 
940     error |= run_non_uniform_broadcast_for_type<cl_short>(rft);
941     error |= run_non_uniform_broadcast_for_type<cl_short2>(rft);
942     error |= run_non_uniform_broadcast_for_type<subgroups::cl_short3>(rft);
943     error |= run_non_uniform_broadcast_for_type<cl_short4>(rft);
944     error |= run_non_uniform_broadcast_for_type<cl_short8>(rft);
945     error |= run_non_uniform_broadcast_for_type<cl_short16>(rft);
946 
947     error |= run_non_uniform_broadcast_for_type<cl_ushort>(rft);
948     error |= run_non_uniform_broadcast_for_type<cl_ushort2>(rft);
949     error |= run_non_uniform_broadcast_for_type<subgroups::cl_ushort3>(rft);
950     error |= run_non_uniform_broadcast_for_type<cl_ushort4>(rft);
951     error |= run_non_uniform_broadcast_for_type<cl_ushort8>(rft);
952     error |= run_non_uniform_broadcast_for_type<cl_ushort16>(rft);
953 
954     error |= run_non_uniform_broadcast_for_type<cl_long>(rft);
955     error |= run_non_uniform_broadcast_for_type<cl_long2>(rft);
956     error |= run_non_uniform_broadcast_for_type<subgroups::cl_long3>(rft);
957     error |= run_non_uniform_broadcast_for_type<cl_long4>(rft);
958     error |= run_non_uniform_broadcast_for_type<cl_long8>(rft);
959     error |= run_non_uniform_broadcast_for_type<cl_long16>(rft);
960 
961     error |= run_non_uniform_broadcast_for_type<cl_ulong>(rft);
962     error |= run_non_uniform_broadcast_for_type<cl_ulong2>(rft);
963     error |= run_non_uniform_broadcast_for_type<subgroups::cl_ulong3>(rft);
964     error |= run_non_uniform_broadcast_for_type<cl_ulong4>(rft);
965     error |= run_non_uniform_broadcast_for_type<cl_ulong8>(rft);
966     error |= run_non_uniform_broadcast_for_type<cl_ulong16>(rft);
967 
968     error |= run_non_uniform_broadcast_for_type<cl_float>(rft);
969     error |= run_non_uniform_broadcast_for_type<cl_float2>(rft);
970     error |= run_non_uniform_broadcast_for_type<subgroups::cl_float3>(rft);
971     error |= run_non_uniform_broadcast_for_type<cl_float4>(rft);
972     error |= run_non_uniform_broadcast_for_type<cl_float8>(rft);
973     error |= run_non_uniform_broadcast_for_type<cl_float16>(rft);
974 
975     error |= run_non_uniform_broadcast_for_type<cl_double>(rft);
976     error |= run_non_uniform_broadcast_for_type<cl_double2>(rft);
977     error |= run_non_uniform_broadcast_for_type<subgroups::cl_double3>(rft);
978     error |= run_non_uniform_broadcast_for_type<cl_double4>(rft);
979     error |= run_non_uniform_broadcast_for_type<cl_double8>(rft);
980     error |= run_non_uniform_broadcast_for_type<cl_double16>(rft);
981 
982     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half>(rft);
983     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half2>(rft);
984     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half3>(rft);
985     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half4>(rft);
986     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half8>(rft);
987     error |= run_non_uniform_broadcast_for_type<subgroups::cl_half16>(rft);
988 
989     // broadcast first functions
990     error |=
991         rft.run_impl<cl_int, BC<cl_int, SubgroupsBroadcastOp::broadcast_first>>(
992             "sub_group_broadcast_first");
993     error |= rft.run_impl<cl_uint,
994                           BC<cl_uint, SubgroupsBroadcastOp::broadcast_first>>(
995         "sub_group_broadcast_first");
996     error |= rft.run_impl<cl_long,
997                           BC<cl_long, SubgroupsBroadcastOp::broadcast_first>>(
998         "sub_group_broadcast_first");
999     error |= rft.run_impl<cl_ulong,
1000                           BC<cl_ulong, SubgroupsBroadcastOp::broadcast_first>>(
1001         "sub_group_broadcast_first");
1002     error |= rft.run_impl<cl_short,
1003                           BC<cl_short, SubgroupsBroadcastOp::broadcast_first>>(
1004         "sub_group_broadcast_first");
1005     error |= rft.run_impl<cl_ushort,
1006                           BC<cl_ushort, SubgroupsBroadcastOp::broadcast_first>>(
1007         "sub_group_broadcast_first");
1008     error |= rft.run_impl<cl_char,
1009                           BC<cl_char, SubgroupsBroadcastOp::broadcast_first>>(
1010         "sub_group_broadcast_first");
1011     error |= rft.run_impl<cl_uchar,
1012                           BC<cl_uchar, SubgroupsBroadcastOp::broadcast_first>>(
1013         "sub_group_broadcast_first");
1014     error |= rft.run_impl<cl_float,
1015                           BC<cl_float, SubgroupsBroadcastOp::broadcast_first>>(
1016         "sub_group_broadcast_first");
1017     error |= rft.run_impl<cl_double,
1018                           BC<cl_double, SubgroupsBroadcastOp::broadcast_first>>(
1019         "sub_group_broadcast_first");
1020     error |= rft.run_impl<
1021         subgroups::cl_half,
1022         BC<subgroups::cl_half, SubgroupsBroadcastOp::broadcast_first>>(
1023         "sub_group_broadcast_first");
1024 
1025     // mask functions
1026     error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::eq_mask>>(
1027         "get_sub_group_eq_mask");
1028     error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::ge_mask>>(
1029         "get_sub_group_ge_mask");
1030     error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::gt_mask>>(
1031         "get_sub_group_gt_mask");
1032     error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::le_mask>>(
1033         "get_sub_group_le_mask");
1034     error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>(
1035         "get_sub_group_lt_mask");
1036 
1037     // sub_group_ballot function
1038     WorkGroupParams test_params_ballot(global_work_size, local_work_size, 3);
1039     test_params_ballot.save_kernel_source(sub_group_ballot_source);
1040     RunTestForType rft_ballot(device, context, queue, num_elements,
1041                               test_params_ballot);
1042     error |=
1043         rft_ballot.run_impl<cl_uint4, BALLOT<cl_uint4>>("sub_group_ballot");
1044 
1045     // ballot arithmetic functions
1046     WorkGroupParams test_params_arith(global_work_size, local_work_size);
1047     test_params_arith.save_kernel_source(sub_group_ballot_bit_scan_find_source);
1048     test_params_arith.save_kernel_source(sub_group_inverse_ballot_source,
1049                                          "sub_group_inverse_ballot");
1050     test_params_arith.save_kernel_source(sub_group_ballot_bit_extract_source,
1051                                          "sub_group_ballot_bit_extract");
1052     RunTestForType rft_arith(device, context, queue, num_elements,
1053                              test_params_arith);
1054     error |=
1055         rft_arith.run_impl<cl_uint4,
1056                            BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
1057             "sub_group_inverse_ballot");
1058     error |= rft_arith.run_impl<
1059         cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>(
1060         "sub_group_ballot_bit_extract");
1061     error |= rft_arith.run_impl<
1062         cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>(
1063         "sub_group_ballot_bit_count");
1064     error |= rft_arith.run_impl<
1065         cl_uint4,
1066         BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>(
1067         "sub_group_ballot_inclusive_scan");
1068     error |= rft_arith.run_impl<
1069         cl_uint4,
1070         BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>(
1071         "sub_group_ballot_exclusive_scan");
1072     error |= rft_arith.run_impl<
1073         cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>(
1074         "sub_group_ballot_find_lsb");
1075     error |= rft_arith.run_impl<
1076         cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>(
1077         "sub_group_ballot_find_msb");
1078 
1079     return error;
1080 }
1081