1 //
2 // Copyright (c) 2020 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBGROUPCOMMONTEMPLATES_H
17 #define SUBGROUPCOMMONTEMPLATES_H
18
19 #include "typeWrappers.h"
20 #include "CL/cl_half.h"
21 #include "subhelpers.h"
22 #include <set>
23 #include <algorithm>
24
// DESCRIPTION:
// sub_group_broadcast - each work_item registers its own value.
// All work_items in the subgroup take the value from one (any) work_item.
// sub_group_broadcast_first - same as sub_group_broadcast, except that all
// work_items in the subgroup take the value from the one chosen work_item
// (the one with the smallest subgroup local ID).
// sub_group_non_uniform_broadcast - same as sub_group_broadcast, but
// only 4 work_items from the subgroup enter the code (are active).
33 template <typename Ty, SubgroupsBroadcastOp operation> struct BC
34 {
log_testBC35 static void log_test(const WorkGroupParams &test_params,
36 const char *extra_text)
37 {
38 log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
39 TypeManager<Ty>::name(), extra_text);
40 }
41
genBC42 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
43 {
44 int i, ii, j, k, n;
45 int ng = test_params.global_workgroup_size;
46 int nw = test_params.local_workgroup_size;
47 int ns = test_params.subgroup_size;
48 int nj = (nw + ns - 1) / ns;
49 int d = ns > 100 ? 100 : ns;
50 int non_uniform_size = ng % nw;
51 ng = ng / nw;
52 int last_subgroup_size = 0;
53 ii = 0;
54
55 if (non_uniform_size)
56 {
57 ng++;
58 }
59 for (k = 0; k < ng; ++k)
60 { // for each work_group
61 if (non_uniform_size && k == ng - 1)
62 {
63 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
64 last_subgroup_size);
65 }
66 for (j = 0; j < nj; ++j)
67 { // for each subgroup
68 ii = j * ns;
69 if (last_subgroup_size && j == nj - 1)
70 {
71 n = last_subgroup_size;
72 }
73 else
74 {
75 n = ii + ns > nw ? nw - ii : ns;
76 }
77 int bcast_if = 0;
78 int bcast_elseif = 0;
79 int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
80 % (d > n ? n : d);
81 // l - calculate subgroup local id from which value will be
82 // broadcasted (one the same value for whole subgroup)
83 if (operation != SubgroupsBroadcastOp::broadcast)
84 {
85 // reduce brodcasting index in case of non_uniform and
86 // last workgroup last subgroup
87 if (last_subgroup_size && j == nj - 1
88 && last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
89 {
90 bcast_if = bcast_index % last_subgroup_size;
91 bcast_elseif = bcast_if;
92 }
93 else
94 {
95 bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
96 bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
97 + bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
98 }
99 }
100
101 for (i = 0; i < n; ++i)
102 {
103 if (operation == SubgroupsBroadcastOp::broadcast)
104 {
105 int midx = 4 * ii + 4 * i + 2;
106 m[midx] = (cl_int)bcast_index;
107 }
108 else
109 {
110 if (i < NR_OF_ACTIVE_WORK_ITEMS)
111 {
112 // index of the third
113 // element int the vector.
114 int midx = 4 * ii + 4 * i + 2;
115 // storing information about
116 // broadcasting index -
117 // earlier calculated
118 m[midx] = (cl_int)bcast_if;
119 }
120 else
121 { // index of the third
122 // element int the vector.
123 int midx = 4 * ii + 4 * i + 3;
124 m[midx] = (cl_int)bcast_elseif;
125 }
126 }
127
128 // calculate value for broadcasting
129 cl_ulong number = genrand_int64(gMTdata);
130 set_value(t[ii + i], number);
131 }
132 }
133 // Now map into work group using map from device
134 for (j = 0; j < nw; ++j)
135 { // for each element in work_group
136 // calculate index as number of subgroup
137 // plus subgroup local id
138 x[j] = t[j];
139 }
140 x += nw;
141 m += 4 * nw;
142 }
143 }
144
chkBC145 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
146 const WorkGroupParams &test_params)
147 {
148 int ii, i, j, k, l, n;
149 int ng = test_params.global_workgroup_size;
150 int nw = test_params.local_workgroup_size;
151 int ns = test_params.subgroup_size;
152 int nj = (nw + ns - 1) / ns;
153 Ty tr, rr;
154 int non_uniform_size = ng % nw;
155 ng = ng / nw;
156 int last_subgroup_size = 0;
157 if (non_uniform_size) ng++;
158
159 for (k = 0; k < ng; ++k)
160 { // for each work_group
161 if (non_uniform_size && k == ng - 1)
162 {
163 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
164 last_subgroup_size);
165 }
166 for (j = 0; j < nw; ++j)
167 { // inside the work_group
168 mx[j] = x[j]; // read host inputs for work_group
169 my[j] = y[j]; // read device outputs for work_group
170 }
171
172 for (j = 0; j < nj; ++j)
173 { // for each subgroup
174 ii = j * ns;
175 if (last_subgroup_size && j == nj - 1)
176 {
177 n = last_subgroup_size;
178 }
179 else
180 {
181 n = ii + ns > nw ? nw - ii : ns;
182 }
183
184 // Check result
185 if (operation == SubgroupsBroadcastOp::broadcast_first)
186 {
187 int lowest_active_id = -1;
188 for (i = 0; i < n; ++i)
189 {
190
191 lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
192 ? 0
193 : NR_OF_ACTIVE_WORK_ITEMS;
194 // findout if broadcasted
195 // value is the same
196 tr = mx[ii + lowest_active_id];
197 // findout if broadcasted to all
198 rr = my[ii + i];
199
200 if (!compare(rr, tr))
201 {
202 log_error(
203 "ERROR: sub_group_broadcast_first(%s) "
204 "mismatch "
205 "for local id %d in sub group %d in group "
206 "%d\n",
207 TypeManager<Ty>::name(), i, j, k);
208 return TEST_FAIL;
209 }
210 }
211 }
212 else
213 {
214 for (i = 0; i < n; ++i)
215 {
216 if (operation == SubgroupsBroadcastOp::broadcast)
217 {
218 int midx = 4 * ii + 4 * i + 2;
219 l = (int)m[midx];
220 tr = mx[ii + l];
221 }
222 else
223 {
224 if (i < NR_OF_ACTIVE_WORK_ITEMS)
225 { // take index of array where info
226 // which work_item will be
227 // broadcast its value is stored
228 int midx = 4 * ii + 4 * i + 2;
229 // take subgroup local id of
230 // this work_item
231 l = (int)m[midx];
232 // take value generated on host
233 // for this work_item
234 tr = mx[ii + l];
235 }
236 else
237 {
238 int midx = 4 * ii + 4 * i + 3;
239 l = (int)m[midx];
240 tr = mx[ii + l];
241 }
242 }
243 rr = my[ii + i]; // read device outputs for
244 // work_item in the subgroup
245
246 if (!compare(rr, tr))
247 {
248 log_error("ERROR: sub_group_%s(%s) "
249 "mismatch for local id %d in sub "
250 "group %d in group %d - %s\n",
251 operation_names(operation),
252 TypeManager<Ty>::name(), i, j, k,
253 print_expected_obtained(tr, rr).c_str());
254 return TEST_FAIL;
255 }
256 }
257 }
258 }
259 x += nw;
260 y += nw;
261 m += 4 * nw;
262 }
263 return TEST_PASS;
264 }
265 };
266
to_float(subgroups::cl_half x)267 static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
268
to_half(float x)269 static subgroups::cl_half to_half(float x)
270 {
271 subgroups::cl_half value;
272 value.data = cl_half_from_float(x, g_rounding_mode);
273 return value;
274 }
275
276 // for integer types
calculate(Ty a,Ty b,ArithmeticOp operation)277 template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
278 {
279 switch (operation)
280 {
281 case ArithmeticOp::add_: return a + b;
282 case ArithmeticOp::max_: return a > b ? a : b;
283 case ArithmeticOp::min_: return a < b ? a : b;
284 case ArithmeticOp::mul_: return a * b;
285 case ArithmeticOp::and_: return a & b;
286 case ArithmeticOp::or_: return a | b;
287 case ArithmeticOp::xor_: return a ^ b;
288 case ArithmeticOp::logical_and: return a && b;
289 case ArithmeticOp::logical_or: return a || b;
290 case ArithmeticOp::logical_xor: return !a ^ !b;
291 default: log_error("Unknown operation request\n"); break;
292 }
293 return 0;
294 }
295 // Specialize for floating points.
296 template <>
calculate(cl_double a,cl_double b,ArithmeticOp operation)297 inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
298 {
299 switch (operation)
300 {
301 case ArithmeticOp::add_: {
302 return a + b;
303 }
304 case ArithmeticOp::max_: {
305 return a > b ? a : b;
306 }
307 case ArithmeticOp::min_: {
308 return a < b ? a : b;
309 }
310 case ArithmeticOp::mul_: {
311 return a * b;
312 }
313 default: log_error("Unknown operation request\n"); break;
314 }
315 return 0;
316 }
317
318 template <>
calculate(cl_float a,cl_float b,ArithmeticOp operation)319 inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
320 {
321 switch (operation)
322 {
323 case ArithmeticOp::add_: {
324 return a + b;
325 }
326 case ArithmeticOp::max_: {
327 return a > b ? a : b;
328 }
329 case ArithmeticOp::min_: {
330 return a < b ? a : b;
331 }
332 case ArithmeticOp::mul_: {
333 return a * b;
334 }
335 default: log_error("Unknown operation request\n"); break;
336 }
337 return 0;
338 }
339
340 template <>
calculate(subgroups::cl_half a,subgroups::cl_half b,ArithmeticOp operation)341 inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
342 ArithmeticOp operation)
343 {
344 switch (operation)
345 {
346 case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
347 case ArithmeticOp::max_:
348 return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
349 case ArithmeticOp::min_:
350 return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
351 case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
352 default: log_error("Unknown operation request\n"); break;
353 }
354 return to_half(0);
355 }
356
is_floating_point()357 template <typename Ty> bool is_floating_point()
358 {
359 return std::is_floating_point<Ty>::value
360 || std::is_same<Ty, subgroups::cl_half>::value;
361 }
362
363 template <typename Ty, ArithmeticOp operation>
generate_inputs(Ty * x,Ty * t,cl_int * m,int ns,int nw,int ng)364 void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
365 {
366 int nj = (nw + ns - 1) / ns;
367
368 std::vector<cl_ulong> safe_values;
369 if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_)
370 {
371 fill_and_shuffle_safe_values(safe_values, ns);
372 }
373
374 for (int k = 0; k < ng; ++k)
375 {
376 for (int j = 0; j < nj; ++j)
377 {
378 int ii = j * ns;
379 int n = ii + ns > nw ? nw - ii : ns;
380
381 for (int i = 0; i < n; ++i)
382 {
383 cl_ulong out_value;
384 if (operation == ArithmeticOp::mul_
385 || operation == ArithmeticOp::add_)
386 {
387 out_value = safe_values[i];
388 }
389 else
390 {
391 out_value = genrand_int64(gMTdata) % (32 * n);
392 if ((operation == ArithmeticOp::logical_and
393 || operation == ArithmeticOp::logical_or
394 || operation == ArithmeticOp::logical_xor)
395 && ((out_value >> 32) & 1) == 0)
396 out_value = 0; // increase probability of false
397 }
398 set_value(t[ii + i], out_value);
399 }
400 }
401
402 // Now map into work group using map from device
403 for (int j = 0; j < nw; ++j)
404 {
405 x[j] = t[j];
406 }
407
408 x += nw;
409 m += 4 * nw;
410 }
411 }
412
413 template <typename Ty, ShuffleOp operation> struct SHF
414 {
log_testSHF415 static void log_test(const WorkGroupParams &test_params,
416 const char *extra_text)
417 {
418 log_info(" sub_group_%s(%s)...%s\n", operation_names(operation),
419 TypeManager<Ty>::name(), extra_text);
420 }
421
genSHF422 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
423 {
424 int i, ii, j, k, n;
425 cl_uint l;
426 int nw = test_params.local_workgroup_size;
427 int ns = test_params.subgroup_size;
428 int ng = test_params.global_workgroup_size;
429 int nj = (nw + ns - 1) / ns;
430 ii = 0;
431 ng = ng / nw;
432 for (k = 0; k < ng; ++k)
433 { // for each work_group
434 for (j = 0; j < nj; ++j)
435 { // for each subgroup
436 ii = j * ns;
437 n = ii + ns > nw ? nw - ii : ns;
438 for (i = 0; i < n; ++i)
439 {
440 int midx = 4 * ii + 4 * i + 2;
441 l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1)
442 % (ns * 2 + 1))
443 - 1;
444 switch (operation)
445 {
446 case ShuffleOp::shuffle:
447 case ShuffleOp::shuffle_xor:
448 case ShuffleOp::shuffle_up:
449 case ShuffleOp::shuffle_down:
450 // storing information about shuffle index/delta
451 m[midx] = (cl_int)l;
452 break;
453 case ShuffleOp::rotate:
454 case ShuffleOp::clustered_rotate:
455 // Storing information about rotate delta.
456 // The delta must be the same for each thread in
457 // the subgroup.
458 if (i == 0)
459 {
460 m[midx] = (cl_int)l;
461 }
462 else
463 {
464 m[midx] = m[midx - 4];
465 }
466 break;
467 default: break;
468 }
469 cl_ulong number = genrand_int64(gMTdata);
470 set_value(t[ii + i], number);
471 }
472 }
473 // Now map into work group using map from device
474 for (j = 0; j < nw; ++j)
475 { // for each element in work_group
476 x[j] = t[j];
477 }
478 x += nw;
479 m += 4 * nw;
480 }
481 }
482
chkSHF483 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
484 const WorkGroupParams &test_params)
485 {
486 int ii, k;
487 size_t n;
488 cl_uint l;
489 size_t nw = test_params.local_workgroup_size;
490 size_t ns = test_params.subgroup_size;
491 int ng = test_params.global_workgroup_size;
492 size_t nj = (nw + ns - 1) / ns;
493 Ty tr, rr;
494 ng = ng / nw;
495
496 for (k = 0; k < ng; ++k)
497 { // for each work_group
498 for (size_t j = 0; j < nw; ++j)
499 { // inside the work_group
500 mx[j] = x[j]; // read host inputs for work_group
501 my[j] = y[j]; // read device outputs for work_group
502 }
503
504 for (size_t j = 0; j < nj; ++j)
505 { // for each subgroup
506 ii = j * ns;
507 n = ii + ns > nw ? nw - ii : ns;
508
509 for (size_t i = 0; i < n; ++i)
510 { // inside the subgroup
511 // shuffle index storage
512 int midx = 4 * ii + 4 * i + 2;
513 l = m[midx];
514 rr = my[ii + i];
515 cl_uint tr_idx;
516 bool skip = false;
517 switch (operation)
518 {
519 // shuffle basic - treat l as index
520 case ShuffleOp::shuffle: tr_idx = l; break;
521 // shuffle xor - treat l as mask
522 case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break;
523 // shuffle up - treat l as delta
524 case ShuffleOp::shuffle_up:
525 if (l >= ns) skip = true;
526 tr_idx = i - l;
527 break;
528 // shuffle down - treat l as delta
529 case ShuffleOp::shuffle_down:
530 if (l >= ns) skip = true;
531 tr_idx = i + l;
532 break;
533 // rotate - treat l as delta
534 case ShuffleOp::rotate:
535 tr_idx = (i + l) % test_params.subgroup_size;
536 break;
537 case ShuffleOp::clustered_rotate: {
538 tr_idx = ((i & ~(test_params.cluster_size - 1))
539 + ((i + l) % test_params.cluster_size));
540 break;
541 }
542 default: break;
543 }
544
545 if (!skip && tr_idx < n)
546 {
547 tr = mx[ii + tr_idx];
548
549 if (!compare(rr, tr))
550 {
551 log_error("ERROR: sub_group_%s(%s) mismatch for "
552 "local id %d in sub group %d in group "
553 "%d\n",
554 operation_names(operation),
555 TypeManager<Ty>::name(), i, j, k);
556 return TEST_FAIL;
557 }
558 }
559 }
560 }
561 x += nw;
562 y += nw;
563 m += 4 * nw;
564 }
565 return TEST_PASS;
566 }
567 };
568
569 template <typename Ty, ArithmeticOp operation> struct SCEX_NU
570 {
log_testSCEX_NU571 static void log_test(const WorkGroupParams &test_params,
572 const char *extra_text)
573 {
574 std::string func_name = (test_params.all_work_item_masks.size() > 0
575 ? "sub_group_non_uniform_scan_exclusive"
576 : "sub_group_scan_exclusive");
577 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
578 operation_names(operation), TypeManager<Ty>::name(),
579 extra_text);
580 }
581
genSCEX_NU582 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
583 {
584 int nw = test_params.local_workgroup_size;
585 int ns = test_params.subgroup_size;
586 int ng = test_params.global_workgroup_size;
587 ng = ng / nw;
588 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
589 }
590
chkSCEX_NU591 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
592 const WorkGroupParams &test_params)
593 {
594 int ii, i, j, k, n;
595 int nw = test_params.local_workgroup_size;
596 int ns = test_params.subgroup_size;
597 int ng = test_params.global_workgroup_size;
598 bs128 work_items_mask = test_params.work_items_mask;
599 int nj = (nw + ns - 1) / ns;
600 Ty tr, rr;
601 ng = ng / nw;
602
603 std::string func_name = (test_params.all_work_item_masks.size() > 0
604 ? "sub_group_non_uniform_scan_exclusive"
605 : "sub_group_scan_exclusive");
606
607 // for uniform case take into consideration all workitems
608 if (!work_items_mask.any())
609 {
610 work_items_mask.set();
611 }
612 for (k = 0; k < ng; ++k)
613 { // for each work_group
614 // Map to array indexed to array indexed by local ID and sub group
615 for (j = 0; j < nw; ++j)
616 { // inside the work_group
617 mx[j] = x[j]; // read host inputs for work_group
618 my[j] = y[j]; // read device outputs for work_group
619 }
620 for (j = 0; j < nj; ++j)
621 {
622 ii = j * ns;
623 n = ii + ns > nw ? nw - ii : ns;
624 std::set<int> active_work_items;
625 for (i = 0; i < n; ++i)
626 {
627 if (work_items_mask.test(i))
628 {
629 active_work_items.insert(i);
630 }
631 }
632 if (active_work_items.empty())
633 {
634 continue;
635 }
636 else
637 {
638 tr = TypeManager<Ty>::identify_limits(operation);
639 for (const int &active_work_item : active_work_items)
640 {
641 rr = my[ii + active_work_item];
642 if (!compare_ordered(rr, tr))
643 {
644 log_error(
645 "ERROR: %s_%s(%s) "
646 "mismatch for local id %d in sub group %d in "
647 "group %d %s\n",
648 func_name.c_str(), operation_names(operation),
649 TypeManager<Ty>::name(), i, j, k,
650 print_expected_obtained(tr, rr).c_str());
651 return TEST_FAIL;
652 }
653 tr = calculate<Ty>(tr, mx[ii + active_work_item],
654 operation);
655 }
656 }
657 }
658 x += nw;
659 y += nw;
660 m += 4 * nw;
661 }
662
663 return TEST_PASS;
664 }
665 };
666
667 // Test for scan inclusive non uniform functions
668 template <typename Ty, ArithmeticOp operation> struct SCIN_NU
669 {
log_testSCIN_NU670 static void log_test(const WorkGroupParams &test_params,
671 const char *extra_text)
672 {
673 std::string func_name = (test_params.all_work_item_masks.size() > 0
674 ? "sub_group_non_uniform_scan_inclusive"
675 : "sub_group_scan_inclusive");
676 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
677 operation_names(operation), TypeManager<Ty>::name(),
678 extra_text);
679 }
680
genSCIN_NU681 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
682 {
683 int nw = test_params.local_workgroup_size;
684 int ns = test_params.subgroup_size;
685 int ng = test_params.global_workgroup_size;
686 ng = ng / nw;
687 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
688 }
689
chkSCIN_NU690 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
691 const WorkGroupParams &test_params)
692 {
693 int ii, i, j, k, n;
694 int nw = test_params.local_workgroup_size;
695 int ns = test_params.subgroup_size;
696 int ng = test_params.global_workgroup_size;
697 bs128 work_items_mask = test_params.work_items_mask;
698
699 int nj = (nw + ns - 1) / ns;
700 Ty tr, rr;
701 ng = ng / nw;
702
703 std::string func_name = (test_params.all_work_item_masks.size() > 0
704 ? "sub_group_non_uniform_scan_inclusive"
705 : "sub_group_scan_inclusive");
706
707 // for uniform case take into consideration all workitems
708 if (!work_items_mask.any())
709 {
710 work_items_mask.set();
711 }
712 // std::bitset<32> mask32(use_work_items_mask);
713 // for (int k) mask32.count();
714 for (k = 0; k < ng; ++k)
715 { // for each work_group
716 // Map to array indexed to array indexed by local ID and sub group
717 for (j = 0; j < nw; ++j)
718 { // inside the work_group
719 mx[j] = x[j]; // read host inputs for work_group
720 my[j] = y[j]; // read device outputs for work_group
721 }
722 for (j = 0; j < nj; ++j)
723 {
724 ii = j * ns;
725 n = ii + ns > nw ? nw - ii : ns;
726 std::set<int> active_work_items;
727 int catch_frist_active = -1;
728
729 for (i = 0; i < n; ++i)
730 {
731 if (work_items_mask.test(i))
732 {
733 if (catch_frist_active == -1)
734 {
735 catch_frist_active = i;
736 }
737 active_work_items.insert(i);
738 }
739 }
740 if (active_work_items.empty())
741 {
742 continue;
743 }
744 else
745 {
746 tr = TypeManager<Ty>::identify_limits(operation);
747 for (const int &active_work_item : active_work_items)
748 {
749 rr = my[ii + active_work_item];
750 if (active_work_items.size() == 1)
751 {
752 tr = mx[ii + catch_frist_active];
753 }
754 else
755 {
756 tr = calculate<Ty>(tr, mx[ii + active_work_item],
757 operation);
758 }
759 if (!compare_ordered<Ty>(rr, tr))
760 {
761 log_error(
762 "ERROR: %s_%s(%s) "
763 "mismatch for local id %d in sub group %d "
764 "in "
765 "group %d %s\n",
766 func_name.c_str(), operation_names(operation),
767 TypeManager<Ty>::name(), active_work_item, j, k,
768 print_expected_obtained(tr, rr).c_str());
769 return TEST_FAIL;
770 }
771 }
772 }
773 }
774 x += nw;
775 y += nw;
776 m += 4 * nw;
777 }
778
779 return TEST_PASS;
780 }
781 };
782
783 // Test for reduce non uniform functions
784 template <typename Ty, ArithmeticOp operation> struct RED_NU
785 {
log_testRED_NU786 static void log_test(const WorkGroupParams &test_params,
787 const char *extra_text)
788 {
789 std::string func_name = (test_params.all_work_item_masks.size() > 0
790 ? "sub_group_non_uniform_reduce"
791 : "sub_group_reduce");
792 log_info(" %s_%s(%s)...%s\n", func_name.c_str(),
793 operation_names(operation), TypeManager<Ty>::name(),
794 extra_text);
795 }
796
genRED_NU797 static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
798 {
799 int nw = test_params.local_workgroup_size;
800 int ns = test_params.subgroup_size;
801 int ng = test_params.global_workgroup_size;
802 ng = ng / nw;
803 generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
804 }
805
chkRED_NU806 static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
807 const WorkGroupParams &test_params)
808 {
809 int ii, i, j, k, n;
810 int nw = test_params.local_workgroup_size;
811 int ns = test_params.subgroup_size;
812 int ng = test_params.global_workgroup_size;
813 bs128 work_items_mask = test_params.work_items_mask;
814 int nj = (nw + ns - 1) / ns;
815 ng = ng / nw;
816 Ty tr, rr;
817
818 std::string func_name = (test_params.all_work_item_masks.size() > 0
819 ? "sub_group_non_uniform_reduce"
820 : "sub_group_reduce");
821
822 for (k = 0; k < ng; ++k)
823 {
824 // Map to array indexed to array indexed by local ID and sub
825 // group
826 for (j = 0; j < nw; ++j)
827 {
828 mx[j] = x[j];
829 my[j] = y[j];
830 }
831
832 if (!work_items_mask.any())
833 {
834 work_items_mask.set();
835 }
836
837 for (j = 0; j < nj; ++j)
838 {
839 ii = j * ns;
840 n = ii + ns > nw ? nw - ii : ns;
841 std::set<int> active_work_items;
842 int catch_frist_active = -1;
843 for (i = 0; i < n; ++i)
844 {
845 if (work_items_mask.test(i))
846 {
847 if (catch_frist_active == -1)
848 {
849 catch_frist_active = i;
850 tr = mx[ii + i];
851 active_work_items.insert(i);
852 continue;
853 }
854 active_work_items.insert(i);
855 tr = calculate<Ty>(tr, mx[ii + i], operation);
856 }
857 }
858
859 if (active_work_items.empty())
860 {
861 continue;
862 }
863
864 for (const int &active_work_item : active_work_items)
865 {
866 rr = my[ii + active_work_item];
867 if (!compare_ordered<Ty>(rr, tr))
868 {
869 log_error("ERROR: %s_%s(%s) "
870 "mismatch for local id %d in sub group %d in "
871 "group %d %s\n",
872 func_name.c_str(), operation_names(operation),
873 TypeManager<Ty>::name(), active_work_item, j,
874 k, print_expected_obtained(tr, rr).c_str());
875 return TEST_FAIL;
876 }
877 }
878 }
879 x += nw;
880 y += nw;
881 m += 4 * nw;
882 }
883
884 return TEST_PASS;
885 }
886 };
887
888 #endif
889