1 //
2 // Copyright (c) 2021 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "subgroup_common_templates.h"
19 #include "harness/typeWrappers.h"
20 #include <bitset>
21
22 namespace {
23 // Test for ballot functions
24 template <typename Ty> struct BALLOT
25 {
log_test__anoncae851e40111::BALLOT26 static void log_test(const WorkGroupParams &test_params,
27 const char *extra_text)
28 {
29 log_info(" sub_group_ballot...%s\n", extra_text);
30 }
31
gen__anoncae851e40111::BALLOT32 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
33 {
34 int gws = test_params.global_workgroup_size;
35 int lws = test_params.local_workgroup_size;
36 int sbs = test_params.subgroup_size;
37 int sb_number = (lws + sbs - 1) / sbs;
38 int non_uniform_size = gws % lws;
39 int wg_number = gws / lws;
40 wg_number = non_uniform_size ? wg_number + 1 : wg_number;
41 int last_subgroup_size = 0;
42
43 for (int wg_id = 0; wg_id < wg_number; ++wg_id)
44 { // for each work_group
45 if (non_uniform_size && wg_id == wg_number - 1)
46 {
47 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
48 last_subgroup_size);
49 }
50 for (int sb_id = 0; sb_id < sb_number; ++sb_id)
51 { // for each subgroup
52 int wg_offset = sb_id * sbs;
53 int current_sbs;
54 if (last_subgroup_size && sb_id == sb_number - 1)
55 {
56 current_sbs = last_subgroup_size;
57 }
58 else
59 {
60 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
61 }
62
63 for (int wi_id = 0; wi_id < current_sbs; wi_id++)
64 {
65 cl_uint v;
66 if (genrand_bool(gMTdata))
67 {
68 v = genrand_bool(gMTdata);
69 }
70 else if (genrand_bool(gMTdata))
71 {
72 v = 1U << ((genrand_int32(gMTdata) % 31) + 1);
73 }
74 else
75 {
76 v = genrand_int32(gMTdata);
77 }
78 cl_uint4 v4 = { { v, 0, 0, 0 } };
79 t[wi_id + wg_offset] = v4;
80 }
81 }
82 // Now map into work group using map from device
83 for (int wi_id = 0; wi_id < lws; ++wi_id)
84 {
85 x[wi_id] = t[wi_id];
86 }
87 x += lws;
88 m += 4 * lws;
89 }
90 }
91
chk__anoncae851e40111::BALLOT92 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
93 const WorkGroupParams &test_params)
94 {
95 int gws = test_params.global_workgroup_size;
96 int lws = test_params.local_workgroup_size;
97 int sbs = test_params.subgroup_size;
98 int sb_number = (lws + sbs - 1) / sbs;
99 int non_uniform_size = gws % lws;
100 int wg_number = gws / lws;
101 wg_number = non_uniform_size ? wg_number + 1 : wg_number;
102 int last_subgroup_size = 0;
103
104 for (int wg_id = 0; wg_id < wg_number; ++wg_id)
105 { // for each work_group
106 if (non_uniform_size && wg_id == wg_number - 1)
107 {
108 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
109 last_subgroup_size);
110 }
111 for (int wi_id = 0; wi_id < lws; ++wi_id)
112 { // inside the work_group
113 mx[wi_id] = x[wi_id]; // read host inputs for work_group
114 my[wi_id] = y[wi_id]; // read device outputs for work_group
115 }
116
117 for (int sb_id = 0; sb_id < sb_number; ++sb_id)
118 { // for each subgroup
119 int wg_offset = sb_id * sbs;
120 int current_sbs;
121 if (last_subgroup_size && sb_id == sb_number - 1)
122 {
123 current_sbs = last_subgroup_size;
124 }
125 else
126 {
127 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
128 }
129
130 bs128 expected_result_bs = 0;
131
132 std::set<int> active_work_items;
133 for (int wi_id = 0; wi_id < current_sbs; ++wi_id)
134 {
135 if (test_params.work_items_mask.test(wi_id))
136 {
137 bool predicate = (mx[wg_offset + wi_id].s0 != 0);
138 expected_result_bs |= (bs128(predicate) << wi_id);
139 active_work_items.insert(wi_id);
140 }
141 }
142 if (active_work_items.empty())
143 {
144 continue;
145 }
146
147 cl_uint4 expected_result =
148 bs128_to_cl_uint4(expected_result_bs);
149 for (const int &active_work_item : active_work_items)
150 {
151 int wi_id = active_work_item;
152
153 cl_uint4 device_result = my[wg_offset + wi_id];
154 bs128 device_result_bs = cl_uint4_to_bs128(device_result);
155
156 if (device_result_bs != expected_result_bs)
157 {
158 log_error(
159 "ERROR: sub_group_ballot mismatch for local id "
160 "%d in sub group %d in group %d obtained {%d, %d, "
161 "%d, %d}, expected {%d, %d, %d, %d}\n",
162 wi_id, sb_id, wg_id, device_result.s0,
163 device_result.s1, device_result.s2,
164 device_result.s3, expected_result.s0,
165 expected_result.s1, expected_result.s2,
166 expected_result.s3);
167 return TEST_FAIL;
168 }
169 }
170 }
171
172 x += lws;
173 y += lws;
174 m += 4 * lws;
175 }
176
177 return TEST_PASS;
178 }
179 };
180
181 // Test for bit extract ballot functions
182 template <typename Ty, BallotOp operation> struct BALLOT_BIT_EXTRACT
183 {
log_test__anoncae851e40111::BALLOT_BIT_EXTRACT184 static void log_test(const WorkGroupParams &test_params,
185 const char *extra_text)
186 {
187 log_info(" sub_group_ballot_%s(%s)...%s\n", operation_names(operation),
188 TypeManager<Ty>::name(), extra_text);
189 }
190
gen__anoncae851e40111::BALLOT_BIT_EXTRACT191 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
192 {
193 int wi_id, sb_id, wg_id;
194 int gws = test_params.global_workgroup_size;
195 int lws = test_params.local_workgroup_size;
196 int sbs = test_params.subgroup_size;
197 int sb_number = (lws + sbs - 1) / sbs;
198 int wg_number = gws / lws;
199 int limit_sbs = sbs > 100 ? 100 : sbs;
200
201 for (wg_id = 0; wg_id < wg_number; ++wg_id)
202 { // for each work_group
203 for (sb_id = 0; sb_id < sb_number; ++sb_id)
204 { // for each subgroup
205 int wg_offset = sb_id * sbs;
206 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
207 // rand index to bit extract
208 int index_for_odd = (int)(genrand_int32(gMTdata) & 0x7fffffff)
209 % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
210 int index_for_even = (int)(genrand_int32(gMTdata) & 0x7fffffff)
211 % (limit_sbs > current_sbs ? current_sbs : limit_sbs);
212 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
213 {
214 // index of the third element int the vector.
215 int midx = 4 * wg_offset + 4 * wi_id + 2;
216 // storing information about index to bit extract
217 m[midx] = (cl_int)index_for_odd;
218 m[++midx] = (cl_int)index_for_even;
219 }
220 set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
221 }
222
223 // Now map into work group using map from device
224 for (wi_id = 0; wi_id < lws; ++wi_id)
225 {
226 x[wi_id] = t[wi_id];
227 }
228
229 x += lws;
230 m += 4 * lws;
231 }
232 }
233
chk__anoncae851e40111::BALLOT_BIT_EXTRACT234 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
235 const WorkGroupParams &test_params)
236 {
237 int wi_id, wg_id, sb_id;
238 int gws = test_params.global_workgroup_size;
239 int lws = test_params.local_workgroup_size;
240 int sbs = test_params.subgroup_size;
241 int sb_number = (lws + sbs - 1) / sbs;
242 int wg_number = gws / lws;
243 cl_uint4 expected_result, device_result;
244 int last_subgroup_size = 0;
245 int current_sbs = 0;
246 int non_uniform_size = gws % lws;
247
248 for (wg_id = 0; wg_id < wg_number; ++wg_id)
249 { // for each work_group
250 if (non_uniform_size && wg_id == wg_number - 1)
251 {
252 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
253 last_subgroup_size);
254 }
255 // Map to array indexed to array indexed by local ID and sub group
256 for (wi_id = 0; wi_id < lws; ++wi_id)
257 { // inside the work_group
258 // read host inputs for work_group
259 mx[wi_id] = x[wi_id];
260 // read device outputs for work_group
261 my[wi_id] = y[wi_id];
262 }
263
264 for (sb_id = 0; sb_id < sb_number; ++sb_id)
265 { // for each subgroup
266 int wg_offset = sb_id * sbs;
267 if (last_subgroup_size && sb_id == sb_number - 1)
268 {
269 current_sbs = last_subgroup_size;
270 }
271 else
272 {
273 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
274 }
275 // take index of array where info which work_item will
276 // be broadcast its value is stored
277 int midx = 4 * wg_offset + 2;
278 // take subgroup local id of this work_item
279 int index_for_odd = (int)m[midx];
280 int index_for_even = (int)m[++midx];
281
282 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
283 { // for each subgroup
284 int bit_value = 0;
285 // from which value of bitfield bit
286 // verification will be done
287 int take_shift =
288 (wi_id & 1) ? index_for_odd % 32 : index_for_even % 32;
289 int bit_mask = 1 << take_shift;
290
291 if (wi_id < 32)
292 (mx[wg_offset + wi_id].s0 & bit_mask) > 0
293 ? bit_value = 1
294 : bit_value = 0;
295 if (wi_id >= 32 && wi_id < 64)
296 (mx[wg_offset + wi_id].s1 & bit_mask) > 0
297 ? bit_value = 1
298 : bit_value = 0;
299 if (wi_id >= 64 && wi_id < 96)
300 (mx[wg_offset + wi_id].s2 & bit_mask) > 0
301 ? bit_value = 1
302 : bit_value = 0;
303 if (wi_id >= 96 && wi_id < 128)
304 (mx[wg_offset + wi_id].s3 & bit_mask) > 0
305 ? bit_value = 1
306 : bit_value = 0;
307
308 if (wi_id & 1)
309 {
310 bit_value ? expected_result = { { 1, 0, 0, 1 } }
311 : expected_result = { { 0, 0, 0, 1 } };
312 }
313 else
314 {
315 bit_value ? expected_result = { { 1, 0, 0, 2 } }
316 : expected_result = { { 0, 0, 0, 2 } };
317 }
318
319 device_result = my[wg_offset + wi_id];
320 if (!compare(device_result, expected_result))
321 {
322 log_error(
323 "ERROR: sub_group_%s mismatch for local id %d in "
324 "sub group %d in group %d obtained {%d, %d, %d, "
325 "%d}, expected {%d, %d, %d, %d}\n",
326 operation_names(operation), wi_id, sb_id, wg_id,
327 device_result.s0, device_result.s1,
328 device_result.s2, device_result.s3,
329 expected_result.s0, expected_result.s1,
330 expected_result.s2, expected_result.s3);
331 return TEST_FAIL;
332 }
333 }
334 }
335 x += lws;
336 y += lws;
337 m += 4 * lws;
338 }
339 return TEST_PASS;
340 }
341 };
342
343 template <typename Ty, BallotOp operation> struct BALLOT_INVERSE
344 {
log_test__anoncae851e40111::BALLOT_INVERSE345 static void log_test(const WorkGroupParams &test_params,
346 const char *extra_text)
347 {
348 log_info(" sub_group_inverse_ballot...%s\n", extra_text);
349 }
350
gen__anoncae851e40111::BALLOT_INVERSE351 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
352 {
353 // no work here
354 }
355
chk__anoncae851e40111::BALLOT_INVERSE356 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
357 const WorkGroupParams &test_params)
358 {
359 int wi_id, wg_id, sb_id;
360 int gws = test_params.global_workgroup_size;
361 int lws = test_params.local_workgroup_size;
362 int sbs = test_params.subgroup_size;
363 int sb_number = (lws + sbs - 1) / sbs;
364 cl_uint4 expected_result, device_result;
365 int non_uniform_size = gws % lws;
366 int wg_number = gws / lws;
367 int last_subgroup_size = 0;
368 int current_sbs = 0;
369 if (non_uniform_size) wg_number++;
370
371 for (wg_id = 0; wg_id < wg_number; ++wg_id)
372 { // for each work_group
373 if (non_uniform_size && wg_id == wg_number - 1)
374 {
375 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
376 last_subgroup_size);
377 }
378 // Map to array indexed to array indexed by local ID and sub group
379 for (wi_id = 0; wi_id < lws; ++wi_id)
380 { // inside the work_group
381 mx[wi_id] = x[wi_id]; // read host inputs for work_group
382 my[wi_id] = y[wi_id]; // read device outputs for work_group
383 }
384
385 for (sb_id = 0; sb_id < sb_number; ++sb_id)
386 { // for each subgroup
387 int wg_offset = sb_id * sbs;
388 if (last_subgroup_size && sb_id == sb_number - 1)
389 {
390 current_sbs = last_subgroup_size;
391 }
392 else
393 {
394 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
395 }
396 // take subgroup local id of this work_item
397 // Check result
398 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
399 { // for each subgroup work item
400
401 wi_id & 1 ? expected_result = { { 1, 0, 0, 1 } }
402 : expected_result = { { 1, 0, 0, 2 } };
403
404 device_result = my[wg_offset + wi_id];
405 if (!compare(device_result, expected_result))
406 {
407 log_error(
408 "ERROR: sub_group_%s mismatch for local id %d in "
409 "sub group %d in group %d obtained {%d, %d, %d, "
410 "%d}, expected {%d, %d, %d, %d}\n",
411 operation_names(operation), wi_id, sb_id, wg_id,
412 device_result.s0, device_result.s1,
413 device_result.s2, device_result.s3,
414 expected_result.s0, expected_result.s1,
415 expected_result.s2, expected_result.s3);
416 return TEST_FAIL;
417 }
418 }
419 }
420 x += lws;
421 y += lws;
422 m += 4 * lws;
423 }
424
425 return TEST_PASS;
426 }
427 };
428
429
430 // Test for bit count/inclusive and exclusive scan/ find lsb msb ballot function
431 template <typename Ty, BallotOp operation> struct BALLOT_COUNT_SCAN_FIND
432 {
log_test__anoncae851e40111::BALLOT_COUNT_SCAN_FIND433 static void log_test(const WorkGroupParams &test_params,
434 const char *extra_text)
435 {
436 log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
437 TypeManager<Ty>::name(), extra_text);
438 }
439
gen__anoncae851e40111::BALLOT_COUNT_SCAN_FIND440 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
441 {
442 int wi_id, wg_id, sb_id;
443 int gws = test_params.global_workgroup_size;
444 int lws = test_params.local_workgroup_size;
445 int sbs = test_params.subgroup_size;
446 int sb_number = (lws + sbs - 1) / sbs;
447 int non_uniform_size = gws % lws;
448 int wg_number = gws / lws;
449 int last_subgroup_size = 0;
450 int current_sbs = 0;
451
452 if (non_uniform_size)
453 {
454 wg_number++;
455 }
456 for (wg_id = 0; wg_id < wg_number; ++wg_id)
457 { // for each work_group
458 if (non_uniform_size && wg_id == wg_number - 1)
459 {
460 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
461 last_subgroup_size);
462 }
463 for (sb_id = 0; sb_id < sb_number; ++sb_id)
464 { // for each subgroup
465 int wg_offset = sb_id * sbs;
466 if (last_subgroup_size && sb_id == sb_number - 1)
467 {
468 current_sbs = last_subgroup_size;
469 }
470 else
471 {
472 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
473 }
474 if (operation == BallotOp::ballot_bit_count
475 || operation == BallotOp::ballot_inclusive_scan
476 || operation == BallotOp::ballot_exclusive_scan)
477 {
478 set_randomdata_for_subgroup<Ty>(t, wg_offset, current_sbs);
479 }
480 else if (operation == BallotOp::ballot_find_lsb
481 || operation == BallotOp::ballot_find_msb)
482 {
483 // Regarding to the spec, find lsb and find msb result is
484 // undefined behavior if input value is zero, so generate
485 // only non-zero values.
486 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
487 {
488 char x = (genrand_int32(gMTdata)) & 0xff;
489 // undefined behaviour in case of 0;
490 x = x ? x : 1;
491 memset(&t[wg_offset + wi_id], x, sizeof(Ty));
492 }
493 }
494 else
495 {
496 log_error("Unknown operation...\n");
497 }
498 }
499
500 // Now map into work group using map from device
501 for (wi_id = 0; wi_id < lws; ++wi_id)
502 {
503 x[wi_id] = t[wi_id];
504 }
505
506 x += lws;
507 m += 4 * lws;
508 }
509 }
510
getImportantBits__anoncae851e40111::BALLOT_COUNT_SCAN_FIND511 static bs128 getImportantBits(cl_uint sub_group_local_id,
512 cl_uint sub_group_size)
513 {
514 bs128 mask;
515 if (operation == BallotOp::ballot_bit_count
516 || operation == BallotOp::ballot_find_lsb
517 || operation == BallotOp::ballot_find_msb)
518 {
519 for (cl_uint i = 0; i < sub_group_size; ++i) mask.set(i);
520 }
521 else if (operation == BallotOp::ballot_inclusive_scan
522 || operation == BallotOp::ballot_exclusive_scan)
523 {
524 for (cl_uint i = 0; i < sub_group_local_id; ++i) mask.set(i);
525 if (operation == BallotOp::ballot_inclusive_scan)
526 mask.set(sub_group_local_id);
527 }
528 return mask;
529 }
530
chk__anoncae851e40111::BALLOT_COUNT_SCAN_FIND531 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
532 const WorkGroupParams &test_params)
533 {
534 int wi_id, wg_id, sb_id;
535 int gws = test_params.global_workgroup_size;
536 int lws = test_params.local_workgroup_size;
537 int sbs = test_params.subgroup_size;
538 int sb_number = (lws + sbs - 1) / sbs;
539 int non_uniform_size = gws % lws;
540 int wg_number = gws / lws;
541 wg_number = non_uniform_size ? wg_number + 1 : wg_number;
542 cl_uint expected_result, device_result;
543 int last_subgroup_size = 0;
544 int current_sbs = 0;
545
546 for (wg_id = 0; wg_id < wg_number; ++wg_id)
547 { // for each work_group
548 if (non_uniform_size && wg_id == wg_number - 1)
549 {
550 set_last_workgroup_params(non_uniform_size, sb_number, sbs, lws,
551 last_subgroup_size);
552 }
553 // Map to array indexed to array indexed by local ID and sub group
554 for (wi_id = 0; wi_id < lws; ++wi_id)
555 { // inside the work_group
556 // read host inputs for work_group
557 mx[wi_id] = x[wi_id];
558 // read device outputs for work_group
559 my[wi_id] = y[wi_id];
560 }
561
562 for (sb_id = 0; sb_id < sb_number; ++sb_id)
563 { // for each subgroup
564 int wg_offset = sb_id * sbs;
565 if (last_subgroup_size && sb_id == sb_number - 1)
566 {
567 current_sbs = last_subgroup_size;
568 }
569 else
570 {
571 current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
572 }
573 // Check result
574 expected_result = 0;
575 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
576 { // for subgroup element
577 bs128 bs;
578 // convert cl_uint4 input into std::bitset<128>
579 bs |= bs128(mx[wg_offset + wi_id].s0)
580 | (bs128(mx[wg_offset + wi_id].s1) << 32)
581 | (bs128(mx[wg_offset + wi_id].s2) << 64)
582 | (bs128(mx[wg_offset + wi_id].s3) << 96);
583 bs &= getImportantBits(wi_id, sbs);
584 device_result = my[wg_offset + wi_id].s0;
585 if (operation == BallotOp::ballot_inclusive_scan
586 || operation == BallotOp::ballot_exclusive_scan
587 || operation == BallotOp::ballot_bit_count)
588 {
589 expected_result = bs.count();
590 if (!compare(device_result, expected_result))
591 {
592 log_error("ERROR: sub_group_%s "
593 "mismatch for local id %d in sub group "
594 "%d in group %d obtained %d, "
595 "expected %d\n",
596 operation_names(operation), wi_id, sb_id,
597 wg_id, device_result, expected_result);
598 return TEST_FAIL;
599 }
600 }
601 else if (operation == BallotOp::ballot_find_lsb)
602 {
603 if (bs.none())
604 {
605 // Return value is undefined when no bits are set,
606 // so skip validation:
607 continue;
608 }
609 for (int id = 0; id < sbs; ++id)
610 {
611 if (bs.test(id))
612 {
613 expected_result = id;
614 break;
615 }
616 }
617 if (!compare(device_result, expected_result))
618 {
619 log_error("ERROR: sub_group_ballot_find_lsb "
620 "mismatch for local id %d in sub group "
621 "%d in group %d obtained %d, "
622 "expected %d\n",
623 wi_id, sb_id, wg_id, device_result,
624 expected_result);
625 return TEST_FAIL;
626 }
627 }
628 else if (operation == BallotOp::ballot_find_msb)
629 {
630 if (bs.none())
631 {
632 // Return value is undefined when no bits are set,
633 // so skip validation:
634 continue;
635 }
636 for (int id = sbs - 1; id >= 0; --id)
637 {
638 if (bs.test(id))
639 {
640 expected_result = id;
641 break;
642 }
643 }
644 if (!compare(device_result, expected_result))
645 {
646 log_error("ERROR: sub_group_ballot_find_msb "
647 "mismatch for local id %d in sub group "
648 "%d in group %d obtained %d, "
649 "expected %d\n",
650 wi_id, sb_id, wg_id, device_result,
651 expected_result);
652 return TEST_FAIL;
653 }
654 }
655 }
656 }
657 x += lws;
658 y += lws;
659 m += 4 * lws;
660 }
661 return TEST_PASS;
662 }
663 };
664
665 // test mask functions
666 template <typename Ty, BallotOp operation> struct SMASK
667 {
log_test__anoncae851e40111::SMASK668 static void log_test(const WorkGroupParams &test_params,
669 const char *extra_text)
670 {
671 log_info(" get_sub_group_%s_mask...%s\n", operation_names(operation),
672 extra_text);
673 }
674
gen__anoncae851e40111::SMASK675 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
676 {
677 int wi_id, wg_id, sb_id;
678 int gws = test_params.global_workgroup_size;
679 int lws = test_params.local_workgroup_size;
680 int sbs = test_params.subgroup_size;
681 int sb_number = (lws + sbs - 1) / sbs;
682 int wg_number = gws / lws;
683 for (wg_id = 0; wg_id < wg_number; ++wg_id)
684 { // for each work_group
685 for (sb_id = 0; sb_id < sb_number; ++sb_id)
686 { // for each subgroup
687 int wg_offset = sb_id * sbs;
688 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
689 // Produce expected masks for each work item in the subgroup
690 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
691 {
692 int midx = 4 * wg_offset + 4 * wi_id;
693 cl_uint max_sub_group_size = m[midx + 2];
694 cl_uint4 expected_mask = { { 0 } };
695 expected_mask = generate_bit_mask(
696 wi_id, operation_names(operation), max_sub_group_size);
697 set_value(t[wg_offset + wi_id], expected_mask);
698 }
699 }
700
701 // Now map into work group using map from device
702 for (wi_id = 0; wi_id < lws; ++wi_id)
703 {
704 x[wi_id] = t[wi_id];
705 }
706 x += lws;
707 m += 4 * lws;
708 }
709 }
710
chk__anoncae851e40111::SMASK711 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
712 const WorkGroupParams &test_params)
713 {
714 int wi_id, wg_id, sb_id;
715 int gws = test_params.global_workgroup_size;
716 int lws = test_params.local_workgroup_size;
717 int sbs = test_params.subgroup_size;
718 int sb_number = (lws + sbs - 1) / sbs;
719 Ty expected_result, device_result;
720 int wg_number = gws / lws;
721
722 for (wg_id = 0; wg_id < wg_number; ++wg_id)
723 { // for each work_group
724 for (wi_id = 0; wi_id < lws; ++wi_id)
725 { // inside the work_group
726 mx[wi_id] = x[wi_id]; // read host inputs for work_group
727 my[wi_id] = y[wi_id]; // read device outputs for work_group
728 }
729
730 for (sb_id = 0; sb_id < sb_number; ++sb_id)
731 {
732 int wg_offset = sb_id * sbs;
733 int current_sbs = wg_offset + sbs > lws ? lws - wg_offset : sbs;
734
735 // Check result
736 for (wi_id = 0; wi_id < current_sbs; ++wi_id)
737 { // inside the subgroup
738 expected_result =
739 mx[wg_offset + wi_id]; // read host input for subgroup
740 device_result =
741 my[wg_offset
742 + wi_id]; // read device outputs for subgroup
743 if (!compare(device_result, expected_result))
744 {
745 log_error("ERROR: get_sub_group_%s_mask... mismatch "
746 "for local id %d in sub group %d in group "
747 "%d, %s\n",
748 operation_names(operation), wi_id, sb_id,
749 wg_id,
750 print_expected_obtained(expected_result,
751 device_result)
752 .c_str());
753 return TEST_FAIL;
754 }
755 }
756 }
757 x += lws;
758 y += lws;
759 m += 4 * lws;
760 }
761 return TEST_PASS;
762 }
763 };
764
// OpenCL C source of the sub_group_non_uniform_broadcast kernel.
// The `if` splits each sub-group:
//  - work-items whose xy[gid].x is below NR_OF_ACTIVE_WORK_ITEMS broadcast
//    from the lane given in xy[gid].z;
//  - the rest broadcast from the lane given in xy[gid].w.
// Each call therefore runs in non-uniform control flow.
// NOTE(review): XY() and NR_OF_ACTIVE_WORK_ITEMS come from the test harness
// and are not visible here. XY presumably records the sub-group local id in
// .x — confirm.
std::string sub_group_non_uniform_broadcast_source = R"(
__kernel void test_sub_group_non_uniform_broadcast(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].z);
} else {
out[gid] = sub_group_non_uniform_broadcast(x, xy[gid].w);
}
}
)";
// OpenCL C source of the sub_group_broadcast_first kernel. Both branches make
// the same call. The split on NR_OF_ACTIVE_WORK_ITEMS only exists so that
// each call executes in a divergent subset of the sub-group.
std::string sub_group_broadcast_first_source = R"(
__kernel void test_sub_group_broadcast_first(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
if (xy[gid].x < NR_OF_ACTIVE_WORK_ITEMS) {
out[gid] = sub_group_broadcast_first(x);;
} else {
out[gid] = sub_group_broadcast_first(x);;
}
}
)";
// printf-style template for the bit count / scan / find kernels. The first
// %s is the kernel-name suffix and the second %s is the builtin to call. The
// scalar result is written to s0 of the output; the other components are 0.
std::string sub_group_ballot_bit_scan_find_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint4 value = (uint4)(0,0,0,0);
value = (uint4)(%s(x),0,0,0);
out[gid] = value;
}
)";
// printf-style template for the get_sub_group_*_mask kernels. The first %s is
// the kernel-name suffix and the second %s is the mask builtin. The kernel
// also writes get_max_sub_group_size() to xy[gid].z, which SMASK::gen() reads
// back as m[4 * i + 2].
std::string sub_group_ballot_mask_source = R"(
__kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
xy[gid].z = get_max_sub_group_size();
Type x = in[gid];
uint4 mask = %s();
out[gid] = mask;
}
)";
// OpenCL C source of the sub_group_ballot kernel. A work-item calls
// sub_group_ballot(in[gid].s0) only when its bit is set in
// work_item_mask_vector. The bit is (local id % 32) of the component selected
// by local id / 32. All other work-items write zeros.
std::string sub_group_ballot_source = R"(
__kernel void test_sub_group_ballot(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
uint gid = get_global_id(0);
XY(xy,gid);
uint subgroup_local_id = get_sub_group_local_id();
uint elect_work_item = 1 << (subgroup_local_id % 32);
uint work_item_mask;
if (subgroup_local_id < 32) {
work_item_mask = work_item_mask_vector.x;
} else if(subgroup_local_id < 64) {
work_item_mask = work_item_mask_vector.y;
} else if(subgroup_local_id < 96) {
work_item_mask = work_item_mask_vector.z;
} else if(subgroup_local_id < 128) {
work_item_mask = work_item_mask_vector.w;
}
uint4 value = (uint4)(0, 0, 0, 0);
if (elect_work_item & work_item_mask) {
value = sub_group_ballot(in[gid].s0);
}
out[gid] = value;
}
)";
// OpenCL C source of the sub_group_inverse_ballot kernel.
//  - Odd lanes test 0xAAAAAAAA (bits set at odd positions).
//  - Even lanes test 0x55555555 (bits set at even positions).
// The result is {hit, 0, 0, 1} for odd lanes and {hit, 0, 0, 2} for even
// lanes.
std::string sub_group_inverse_ballot_source = R"(
__kernel void test_sub_group_inverse_ballot(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint4 value = (uint4)(10,0,0,0);
if (get_sub_group_local_id() & 1) {
uint4 partial_ballot_mask = (uint4)(0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA,0xAAAAAAAA);
if (sub_group_inverse_ballot(partial_ballot_mask)) {
value = (uint4)(1,0,0,1);
} else {
value = (uint4)(0,0,0,1);
}
} else {
uint4 partial_ballot_mask = (uint4)(0x55555555,0x55555555,0x55555555,0x55555555);
if (sub_group_inverse_ballot(partial_ballot_mask)) {
value = (uint4)(1,0,0,2);
} else {
value = (uint4)(0,0,0,2);
}
}
out[gid] = value;
}
)";
// OpenCL C source of the sub_group_ballot_bit_extract kernel. Odd lanes
// extract the bit index stored in xy[gid].z and even lanes the index in
// xy[gid].w; both are written by BALLOT_BIT_EXTRACT::gen(). The result is
// {bit, 0, 0, 1} for odd lanes and {bit, 0, 0, 2} for even lanes.
// NOTE(review): the local `index` is unused by the kernel.
std::string sub_group_ballot_bit_extract_source = R"(
__kernel void test_sub_group_ballot_bit_extract(const __global Type *in, __global int4 *xy, __global Type *out) {
int gid = get_global_id(0);
XY(xy,gid);
Type x = in[gid];
uint index = xy[gid].z;
uint4 value = (uint4)(10,0,0,0);
if (get_sub_group_local_id() & 1) {
if (sub_group_ballot_bit_extract(x, xy[gid].z)) {
value = (uint4)(1,0,0,1);
} else {
value = (uint4)(0,0,0,1);
}
} else {
if (sub_group_ballot_bit_extract(x, xy[gid].w)) {
value = (uint4)(1,0,0,2);
} else {
value = (uint4)(0,0,0,2);
}
}
out[gid] = value;
}
)";
879
run_non_uniform_broadcast_for_type(RunTestForType rft)880 template <typename T> int run_non_uniform_broadcast_for_type(RunTestForType rft)
881 {
882 int error =
883 rft.run_impl<T, BC<T, SubgroupsBroadcastOp::non_uniform_broadcast>>(
884 "sub_group_non_uniform_broadcast");
885 return error;
886 }
887
888
889 }
890
test_subgroup_functions_ballot(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)891 int test_subgroup_functions_ballot(cl_device_id device, cl_context context,
892 cl_command_queue queue, int num_elements)
893 {
894 if (!is_extension_available(device, "cl_khr_subgroup_ballot"))
895 {
896 log_info("cl_khr_subgroup_ballot is not supported on this device, "
897 "skipping test.\n");
898 return TEST_SKIPPED_ITSELF;
899 }
900
901 constexpr size_t global_work_size = 170;
902 constexpr size_t local_work_size = 64;
903 WorkGroupParams test_params(global_work_size, local_work_size);
904 test_params.save_kernel_source(sub_group_ballot_mask_source);
905 test_params.save_kernel_source(sub_group_non_uniform_broadcast_source,
906 "sub_group_non_uniform_broadcast");
907 test_params.save_kernel_source(sub_group_broadcast_first_source,
908 "sub_group_broadcast_first");
909 RunTestForType rft(device, context, queue, num_elements, test_params);
910
911 // non uniform broadcast functions
912 int error = run_non_uniform_broadcast_for_type<cl_int>(rft);
913 error |= run_non_uniform_broadcast_for_type<cl_int2>(rft);
914 error |= run_non_uniform_broadcast_for_type<subgroups::cl_int3>(rft);
915 error |= run_non_uniform_broadcast_for_type<cl_int4>(rft);
916 error |= run_non_uniform_broadcast_for_type<cl_int8>(rft);
917 error |= run_non_uniform_broadcast_for_type<cl_int16>(rft);
918
919 error |= run_non_uniform_broadcast_for_type<cl_uint>(rft);
920 error |= run_non_uniform_broadcast_for_type<cl_uint2>(rft);
921 error |= run_non_uniform_broadcast_for_type<subgroups::cl_uint3>(rft);
922 error |= run_non_uniform_broadcast_for_type<cl_uint4>(rft);
923 error |= run_non_uniform_broadcast_for_type<cl_uint8>(rft);
924 error |= run_non_uniform_broadcast_for_type<cl_uint16>(rft);
925
926 error |= run_non_uniform_broadcast_for_type<cl_char>(rft);
927 error |= run_non_uniform_broadcast_for_type<cl_char2>(rft);
928 error |= run_non_uniform_broadcast_for_type<subgroups::cl_char3>(rft);
929 error |= run_non_uniform_broadcast_for_type<cl_char4>(rft);
930 error |= run_non_uniform_broadcast_for_type<cl_char8>(rft);
931 error |= run_non_uniform_broadcast_for_type<cl_char16>(rft);
932
933 error |= run_non_uniform_broadcast_for_type<cl_uchar>(rft);
934 error |= run_non_uniform_broadcast_for_type<cl_uchar2>(rft);
935 error |= run_non_uniform_broadcast_for_type<subgroups::cl_uchar3>(rft);
936 error |= run_non_uniform_broadcast_for_type<cl_uchar4>(rft);
937 error |= run_non_uniform_broadcast_for_type<cl_uchar8>(rft);
938 error |= run_non_uniform_broadcast_for_type<cl_uchar16>(rft);
939
940 error |= run_non_uniform_broadcast_for_type<cl_short>(rft);
941 error |= run_non_uniform_broadcast_for_type<cl_short2>(rft);
942 error |= run_non_uniform_broadcast_for_type<subgroups::cl_short3>(rft);
943 error |= run_non_uniform_broadcast_for_type<cl_short4>(rft);
944 error |= run_non_uniform_broadcast_for_type<cl_short8>(rft);
945 error |= run_non_uniform_broadcast_for_type<cl_short16>(rft);
946
947 error |= run_non_uniform_broadcast_for_type<cl_ushort>(rft);
948 error |= run_non_uniform_broadcast_for_type<cl_ushort2>(rft);
949 error |= run_non_uniform_broadcast_for_type<subgroups::cl_ushort3>(rft);
950 error |= run_non_uniform_broadcast_for_type<cl_ushort4>(rft);
951 error |= run_non_uniform_broadcast_for_type<cl_ushort8>(rft);
952 error |= run_non_uniform_broadcast_for_type<cl_ushort16>(rft);
953
954 error |= run_non_uniform_broadcast_for_type<cl_long>(rft);
955 error |= run_non_uniform_broadcast_for_type<cl_long2>(rft);
956 error |= run_non_uniform_broadcast_for_type<subgroups::cl_long3>(rft);
957 error |= run_non_uniform_broadcast_for_type<cl_long4>(rft);
958 error |= run_non_uniform_broadcast_for_type<cl_long8>(rft);
959 error |= run_non_uniform_broadcast_for_type<cl_long16>(rft);
960
961 error |= run_non_uniform_broadcast_for_type<cl_ulong>(rft);
962 error |= run_non_uniform_broadcast_for_type<cl_ulong2>(rft);
963 error |= run_non_uniform_broadcast_for_type<subgroups::cl_ulong3>(rft);
964 error |= run_non_uniform_broadcast_for_type<cl_ulong4>(rft);
965 error |= run_non_uniform_broadcast_for_type<cl_ulong8>(rft);
966 error |= run_non_uniform_broadcast_for_type<cl_ulong16>(rft);
967
968 error |= run_non_uniform_broadcast_for_type<cl_float>(rft);
969 error |= run_non_uniform_broadcast_for_type<cl_float2>(rft);
970 error |= run_non_uniform_broadcast_for_type<subgroups::cl_float3>(rft);
971 error |= run_non_uniform_broadcast_for_type<cl_float4>(rft);
972 error |= run_non_uniform_broadcast_for_type<cl_float8>(rft);
973 error |= run_non_uniform_broadcast_for_type<cl_float16>(rft);
974
975 error |= run_non_uniform_broadcast_for_type<cl_double>(rft);
976 error |= run_non_uniform_broadcast_for_type<cl_double2>(rft);
977 error |= run_non_uniform_broadcast_for_type<subgroups::cl_double3>(rft);
978 error |= run_non_uniform_broadcast_for_type<cl_double4>(rft);
979 error |= run_non_uniform_broadcast_for_type<cl_double8>(rft);
980 error |= run_non_uniform_broadcast_for_type<cl_double16>(rft);
981
982 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half>(rft);
983 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half2>(rft);
984 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half3>(rft);
985 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half4>(rft);
986 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half8>(rft);
987 error |= run_non_uniform_broadcast_for_type<subgroups::cl_half16>(rft);
988
989 // broadcast first functions
990 error |=
991 rft.run_impl<cl_int, BC<cl_int, SubgroupsBroadcastOp::broadcast_first>>(
992 "sub_group_broadcast_first");
993 error |= rft.run_impl<cl_uint,
994 BC<cl_uint, SubgroupsBroadcastOp::broadcast_first>>(
995 "sub_group_broadcast_first");
996 error |= rft.run_impl<cl_long,
997 BC<cl_long, SubgroupsBroadcastOp::broadcast_first>>(
998 "sub_group_broadcast_first");
999 error |= rft.run_impl<cl_ulong,
1000 BC<cl_ulong, SubgroupsBroadcastOp::broadcast_first>>(
1001 "sub_group_broadcast_first");
1002 error |= rft.run_impl<cl_short,
1003 BC<cl_short, SubgroupsBroadcastOp::broadcast_first>>(
1004 "sub_group_broadcast_first");
1005 error |= rft.run_impl<cl_ushort,
1006 BC<cl_ushort, SubgroupsBroadcastOp::broadcast_first>>(
1007 "sub_group_broadcast_first");
1008 error |= rft.run_impl<cl_char,
1009 BC<cl_char, SubgroupsBroadcastOp::broadcast_first>>(
1010 "sub_group_broadcast_first");
1011 error |= rft.run_impl<cl_uchar,
1012 BC<cl_uchar, SubgroupsBroadcastOp::broadcast_first>>(
1013 "sub_group_broadcast_first");
1014 error |= rft.run_impl<cl_float,
1015 BC<cl_float, SubgroupsBroadcastOp::broadcast_first>>(
1016 "sub_group_broadcast_first");
1017 error |= rft.run_impl<cl_double,
1018 BC<cl_double, SubgroupsBroadcastOp::broadcast_first>>(
1019 "sub_group_broadcast_first");
1020 error |= rft.run_impl<
1021 subgroups::cl_half,
1022 BC<subgroups::cl_half, SubgroupsBroadcastOp::broadcast_first>>(
1023 "sub_group_broadcast_first");
1024
1025 // mask functions
1026 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::eq_mask>>(
1027 "get_sub_group_eq_mask");
1028 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::ge_mask>>(
1029 "get_sub_group_ge_mask");
1030 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::gt_mask>>(
1031 "get_sub_group_gt_mask");
1032 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::le_mask>>(
1033 "get_sub_group_le_mask");
1034 error |= rft.run_impl<cl_uint4, SMASK<cl_uint4, BallotOp::lt_mask>>(
1035 "get_sub_group_lt_mask");
1036
1037 // sub_group_ballot function
1038 WorkGroupParams test_params_ballot(global_work_size, local_work_size, 3);
1039 test_params_ballot.save_kernel_source(sub_group_ballot_source);
1040 RunTestForType rft_ballot(device, context, queue, num_elements,
1041 test_params_ballot);
1042 error |=
1043 rft_ballot.run_impl<cl_uint4, BALLOT<cl_uint4>>("sub_group_ballot");
1044
1045 // ballot arithmetic functions
1046 WorkGroupParams test_params_arith(global_work_size, local_work_size);
1047 test_params_arith.save_kernel_source(sub_group_ballot_bit_scan_find_source);
1048 test_params_arith.save_kernel_source(sub_group_inverse_ballot_source,
1049 "sub_group_inverse_ballot");
1050 test_params_arith.save_kernel_source(sub_group_ballot_bit_extract_source,
1051 "sub_group_ballot_bit_extract");
1052 RunTestForType rft_arith(device, context, queue, num_elements,
1053 test_params_arith);
1054 error |=
1055 rft_arith.run_impl<cl_uint4,
1056 BALLOT_INVERSE<cl_uint4, BallotOp::inverse_ballot>>(
1057 "sub_group_inverse_ballot");
1058 error |= rft_arith.run_impl<
1059 cl_uint4, BALLOT_BIT_EXTRACT<cl_uint4, BallotOp::ballot_bit_extract>>(
1060 "sub_group_ballot_bit_extract");
1061 error |= rft_arith.run_impl<
1062 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_bit_count>>(
1063 "sub_group_ballot_bit_count");
1064 error |= rft_arith.run_impl<
1065 cl_uint4,
1066 BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_inclusive_scan>>(
1067 "sub_group_ballot_inclusive_scan");
1068 error |= rft_arith.run_impl<
1069 cl_uint4,
1070 BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_exclusive_scan>>(
1071 "sub_group_ballot_exclusive_scan");
1072 error |= rft_arith.run_impl<
1073 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_lsb>>(
1074 "sub_group_ballot_find_lsb");
1075 error |= rft_arith.run_impl<
1076 cl_uint4, BALLOT_COUNT_SCAN_FIND<cl_uint4, BallotOp::ballot_find_msb>>(
1077 "sub_group_ballot_find_msb");
1078
1079 return error;
1080 }
1081