xref: /aosp_15_r20/external/OpenCL-CTS/test_conformance/subgroups/subgroup_common_templates.h (revision 6467f958c7de8070b317fc65bcb0f6472e388d82)
1 //
2 // Copyright (c) 2020 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef SUBGROUPCOMMONTEMPLATES_H
17 #define SUBGROUPCOMMONTEMPLATES_H
18 
19 #include "typeWrappers.h"
20 #include "CL/cl_half.h"
21 #include "subhelpers.h"
22 #include <set>
23 #include <algorithm>
24 
25 // DESCRIPTION :
26 // sub_group_broadcast - each work_item registers it's own value.
27 // All work_items in subgroup takes one value from only one (any) work_item
28 // sub_group_broadcast_first - same as type 0. All work_items in
29 // subgroup takes only one value from only one chosen (the smallest subgroup ID)
30 // work_item
31 // sub_group_non_uniform_broadcast - same as type 0 but
32 // only 4 work_items from subgroup enter the code (are active)
33 template <typename Ty, SubgroupsBroadcastOp operation> struct BC
34 {
log_testBC35     static void log_test(const WorkGroupParams &test_params,
36                          const char *extra_text)
37     {
38         log_info("  sub_group_%s(%s)...%s\n", operation_names(operation),
39                  TypeManager<Ty>::name(), extra_text);
40     }
41 
genBC42     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
43     {
44         int i, ii, j, k, n;
45         int ng = test_params.global_workgroup_size;
46         int nw = test_params.local_workgroup_size;
47         int ns = test_params.subgroup_size;
48         int nj = (nw + ns - 1) / ns;
49         int d = ns > 100 ? 100 : ns;
50         int non_uniform_size = ng % nw;
51         ng = ng / nw;
52         int last_subgroup_size = 0;
53         ii = 0;
54 
55         if (non_uniform_size)
56         {
57             ng++;
58         }
59         for (k = 0; k < ng; ++k)
60         { // for each work_group
61             if (non_uniform_size && k == ng - 1)
62             {
63                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
64                                           last_subgroup_size);
65             }
66             for (j = 0; j < nj; ++j)
67             { // for each subgroup
68                 ii = j * ns;
69                 if (last_subgroup_size && j == nj - 1)
70                 {
71                     n = last_subgroup_size;
72                 }
73                 else
74                 {
75                     n = ii + ns > nw ? nw - ii : ns;
76                 }
77                 int bcast_if = 0;
78                 int bcast_elseif = 0;
79                 int bcast_index = (int)(genrand_int32(gMTdata) & 0x7fffffff)
80                     % (d > n ? n : d);
81                 // l - calculate subgroup local id from which value will be
82                 // broadcasted (one the same value for whole subgroup)
83                 if (operation != SubgroupsBroadcastOp::broadcast)
84                 {
85                     // reduce brodcasting index in case of non_uniform and
86                     // last workgroup last subgroup
87                     if (last_subgroup_size && j == nj - 1
88                         && last_subgroup_size < NR_OF_ACTIVE_WORK_ITEMS)
89                     {
90                         bcast_if = bcast_index % last_subgroup_size;
91                         bcast_elseif = bcast_if;
92                     }
93                     else
94                     {
95                         bcast_if = bcast_index % NR_OF_ACTIVE_WORK_ITEMS;
96                         bcast_elseif = NR_OF_ACTIVE_WORK_ITEMS
97                             + bcast_index % (n - NR_OF_ACTIVE_WORK_ITEMS);
98                     }
99                 }
100 
101                 for (i = 0; i < n; ++i)
102                 {
103                     if (operation == SubgroupsBroadcastOp::broadcast)
104                     {
105                         int midx = 4 * ii + 4 * i + 2;
106                         m[midx] = (cl_int)bcast_index;
107                     }
108                     else
109                     {
110                         if (i < NR_OF_ACTIVE_WORK_ITEMS)
111                         {
112                             // index of the third
113                             // element int the vector.
114                             int midx = 4 * ii + 4 * i + 2;
115                             // storing information about
116                             // broadcasting index -
117                             // earlier calculated
118                             m[midx] = (cl_int)bcast_if;
119                         }
120                         else
121                         { // index of the third
122                           // element int the vector.
123                             int midx = 4 * ii + 4 * i + 3;
124                             m[midx] = (cl_int)bcast_elseif;
125                         }
126                     }
127 
128                     // calculate value for broadcasting
129                     cl_ulong number = genrand_int64(gMTdata);
130                     set_value(t[ii + i], number);
131                 }
132             }
133             // Now map into work group using map from device
134             for (j = 0; j < nw; ++j)
135             { // for each element in work_group
136                 // calculate index as number of subgroup
137                 // plus subgroup local id
138                 x[j] = t[j];
139             }
140             x += nw;
141             m += 4 * nw;
142         }
143     }
144 
chkBC145     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
146                            const WorkGroupParams &test_params)
147     {
148         int ii, i, j, k, l, n;
149         int ng = test_params.global_workgroup_size;
150         int nw = test_params.local_workgroup_size;
151         int ns = test_params.subgroup_size;
152         int nj = (nw + ns - 1) / ns;
153         Ty tr, rr;
154         int non_uniform_size = ng % nw;
155         ng = ng / nw;
156         int last_subgroup_size = 0;
157         if (non_uniform_size) ng++;
158 
159         for (k = 0; k < ng; ++k)
160         { // for each work_group
161             if (non_uniform_size && k == ng - 1)
162             {
163                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
164                                           last_subgroup_size);
165             }
166             for (j = 0; j < nw; ++j)
167             { // inside the work_group
168                 mx[j] = x[j]; // read host inputs for work_group
169                 my[j] = y[j]; // read device outputs for work_group
170             }
171 
172             for (j = 0; j < nj; ++j)
173             { // for each subgroup
174                 ii = j * ns;
175                 if (last_subgroup_size && j == nj - 1)
176                 {
177                     n = last_subgroup_size;
178                 }
179                 else
180                 {
181                     n = ii + ns > nw ? nw - ii : ns;
182                 }
183 
184                 // Check result
185                 if (operation == SubgroupsBroadcastOp::broadcast_first)
186                 {
187                     int lowest_active_id = -1;
188                     for (i = 0; i < n; ++i)
189                     {
190 
191                         lowest_active_id = i < NR_OF_ACTIVE_WORK_ITEMS
192                             ? 0
193                             : NR_OF_ACTIVE_WORK_ITEMS;
194                         //  findout if broadcasted
195                         //  value is the same
196                         tr = mx[ii + lowest_active_id];
197                         //  findout if broadcasted to all
198                         rr = my[ii + i];
199 
200                         if (!compare(rr, tr))
201                         {
202                             log_error(
203                                 "ERROR: sub_group_broadcast_first(%s) "
204                                 "mismatch "
205                                 "for local id %d in sub group %d in group "
206                                 "%d\n",
207                                 TypeManager<Ty>::name(), i, j, k);
208                             return TEST_FAIL;
209                         }
210                     }
211                 }
212                 else
213                 {
214                     for (i = 0; i < n; ++i)
215                     {
216                         if (operation == SubgroupsBroadcastOp::broadcast)
217                         {
218                             int midx = 4 * ii + 4 * i + 2;
219                             l = (int)m[midx];
220                             tr = mx[ii + l];
221                         }
222                         else
223                         {
224                             if (i < NR_OF_ACTIVE_WORK_ITEMS)
225                             { // take index of array where info
226                               // which work_item will be
227                               // broadcast its value is stored
228                                 int midx = 4 * ii + 4 * i + 2;
229                                 // take subgroup local id of
230                                 // this work_item
231                                 l = (int)m[midx];
232                                 // take value generated on host
233                                 // for this work_item
234                                 tr = mx[ii + l];
235                             }
236                             else
237                             {
238                                 int midx = 4 * ii + 4 * i + 3;
239                                 l = (int)m[midx];
240                                 tr = mx[ii + l];
241                             }
242                         }
243                         rr = my[ii + i]; // read device outputs for
244                                          // work_item in the subgroup
245 
246                         if (!compare(rr, tr))
247                         {
248                             log_error("ERROR: sub_group_%s(%s) "
249                                       "mismatch for local id %d in sub "
250                                       "group %d in group %d - %s\n",
251                                       operation_names(operation),
252                                       TypeManager<Ty>::name(), i, j, k,
253                                       print_expected_obtained(tr, rr).c_str());
254                             return TEST_FAIL;
255                         }
256                     }
257                 }
258             }
259             x += nw;
260             y += nw;
261             m += 4 * nw;
262         }
263         return TEST_PASS;
264     }
265 };
266 
to_float(subgroups::cl_half x)267 static float to_float(subgroups::cl_half x) { return cl_half_to_float(x.data); }
268 
to_half(float x)269 static subgroups::cl_half to_half(float x)
270 {
271     subgroups::cl_half value;
272     value.data = cl_half_from_float(x, g_rounding_mode);
273     return value;
274 }
275 
276 // for integer types
calculate(Ty a,Ty b,ArithmeticOp operation)277 template <typename Ty> inline Ty calculate(Ty a, Ty b, ArithmeticOp operation)
278 {
279     switch (operation)
280     {
281         case ArithmeticOp::add_: return a + b;
282         case ArithmeticOp::max_: return a > b ? a : b;
283         case ArithmeticOp::min_: return a < b ? a : b;
284         case ArithmeticOp::mul_: return a * b;
285         case ArithmeticOp::and_: return a & b;
286         case ArithmeticOp::or_: return a | b;
287         case ArithmeticOp::xor_: return a ^ b;
288         case ArithmeticOp::logical_and: return a && b;
289         case ArithmeticOp::logical_or: return a || b;
290         case ArithmeticOp::logical_xor: return !a ^ !b;
291         default: log_error("Unknown operation request\n"); break;
292     }
293     return 0;
294 }
295 // Specialize for floating points.
296 template <>
calculate(cl_double a,cl_double b,ArithmeticOp operation)297 inline cl_double calculate(cl_double a, cl_double b, ArithmeticOp operation)
298 {
299     switch (operation)
300     {
301         case ArithmeticOp::add_: {
302             return a + b;
303         }
304         case ArithmeticOp::max_: {
305             return a > b ? a : b;
306         }
307         case ArithmeticOp::min_: {
308             return a < b ? a : b;
309         }
310         case ArithmeticOp::mul_: {
311             return a * b;
312         }
313         default: log_error("Unknown operation request\n"); break;
314     }
315     return 0;
316 }
317 
318 template <>
calculate(cl_float a,cl_float b,ArithmeticOp operation)319 inline cl_float calculate(cl_float a, cl_float b, ArithmeticOp operation)
320 {
321     switch (operation)
322     {
323         case ArithmeticOp::add_: {
324             return a + b;
325         }
326         case ArithmeticOp::max_: {
327             return a > b ? a : b;
328         }
329         case ArithmeticOp::min_: {
330             return a < b ? a : b;
331         }
332         case ArithmeticOp::mul_: {
333             return a * b;
334         }
335         default: log_error("Unknown operation request\n"); break;
336     }
337     return 0;
338 }
339 
340 template <>
calculate(subgroups::cl_half a,subgroups::cl_half b,ArithmeticOp operation)341 inline subgroups::cl_half calculate(subgroups::cl_half a, subgroups::cl_half b,
342                                     ArithmeticOp operation)
343 {
344     switch (operation)
345     {
346         case ArithmeticOp::add_: return to_half(to_float(a) + to_float(b));
347         case ArithmeticOp::max_:
348             return to_float(a) > to_float(b) || is_half_nan(b.data) ? a : b;
349         case ArithmeticOp::min_:
350             return to_float(a) < to_float(b) || is_half_nan(b.data) ? a : b;
351         case ArithmeticOp::mul_: return to_half(to_float(a) * to_float(b));
352         default: log_error("Unknown operation request\n"); break;
353     }
354     return to_half(0);
355 }
356 
is_floating_point()357 template <typename Ty> bool is_floating_point()
358 {
359     return std::is_floating_point<Ty>::value
360         || std::is_same<Ty, subgroups::cl_half>::value;
361 }
362 
363 template <typename Ty, ArithmeticOp operation>
generate_inputs(Ty * x,Ty * t,cl_int * m,int ns,int nw,int ng)364 void generate_inputs(Ty *x, Ty *t, cl_int *m, int ns, int nw, int ng)
365 {
366     int nj = (nw + ns - 1) / ns;
367 
368     std::vector<cl_ulong> safe_values;
369     if (operation == ArithmeticOp::mul_ || operation == ArithmeticOp::add_)
370     {
371         fill_and_shuffle_safe_values(safe_values, ns);
372     }
373 
374     for (int k = 0; k < ng; ++k)
375     {
376         for (int j = 0; j < nj; ++j)
377         {
378             int ii = j * ns;
379             int n = ii + ns > nw ? nw - ii : ns;
380 
381             for (int i = 0; i < n; ++i)
382             {
383                 cl_ulong out_value;
384                 if (operation == ArithmeticOp::mul_
385                     || operation == ArithmeticOp::add_)
386                 {
387                     out_value = safe_values[i];
388                 }
389                 else
390                 {
391                     out_value = genrand_int64(gMTdata) % (32 * n);
392                     if ((operation == ArithmeticOp::logical_and
393                          || operation == ArithmeticOp::logical_or
394                          || operation == ArithmeticOp::logical_xor)
395                         && ((out_value >> 32) & 1) == 0)
396                         out_value = 0; // increase probability of false
397                 }
398                 set_value(t[ii + i], out_value);
399             }
400         }
401 
402         // Now map into work group using map from device
403         for (int j = 0; j < nw; ++j)
404         {
405             x[j] = t[j];
406         }
407 
408         x += nw;
409         m += 4 * nw;
410     }
411 }
412 
413 template <typename Ty, ShuffleOp operation> struct SHF
414 {
log_testSHF415     static void log_test(const WorkGroupParams &test_params,
416                          const char *extra_text)
417     {
418         log_info("  sub_group_%s(%s)...%s\n", operation_names(operation),
419                  TypeManager<Ty>::name(), extra_text);
420     }
421 
genSHF422     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
423     {
424         int i, ii, j, k, n;
425         cl_uint l;
426         int nw = test_params.local_workgroup_size;
427         int ns = test_params.subgroup_size;
428         int ng = test_params.global_workgroup_size;
429         int nj = (nw + ns - 1) / ns;
430         ii = 0;
431         ng = ng / nw;
432         for (k = 0; k < ng; ++k)
433         { // for each work_group
434             for (j = 0; j < nj; ++j)
435             { // for each subgroup
436                 ii = j * ns;
437                 n = ii + ns > nw ? nw - ii : ns;
438                 for (i = 0; i < n; ++i)
439                 {
440                     int midx = 4 * ii + 4 * i + 2;
441                     l = (((cl_uint)(genrand_int32(gMTdata) & 0x7fffffff) + 1)
442                          % (ns * 2 + 1))
443                         - 1;
444                     switch (operation)
445                     {
446                         case ShuffleOp::shuffle:
447                         case ShuffleOp::shuffle_xor:
448                         case ShuffleOp::shuffle_up:
449                         case ShuffleOp::shuffle_down:
450                             // storing information about shuffle index/delta
451                             m[midx] = (cl_int)l;
452                             break;
453                         case ShuffleOp::rotate:
454                         case ShuffleOp::clustered_rotate:
455                             // Storing information about rotate delta.
456                             // The delta must be the same for each thread in
457                             // the subgroup.
458                             if (i == 0)
459                             {
460                                 m[midx] = (cl_int)l;
461                             }
462                             else
463                             {
464                                 m[midx] = m[midx - 4];
465                             }
466                             break;
467                         default: break;
468                     }
469                     cl_ulong number = genrand_int64(gMTdata);
470                     set_value(t[ii + i], number);
471                 }
472             }
473             // Now map into work group using map from device
474             for (j = 0; j < nw; ++j)
475             { // for each element in work_group
476                 x[j] = t[j];
477             }
478             x += nw;
479             m += 4 * nw;
480         }
481     }
482 
chkSHF483     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
484                            const WorkGroupParams &test_params)
485     {
486         int ii, k;
487         size_t n;
488         cl_uint l;
489         size_t nw = test_params.local_workgroup_size;
490         size_t ns = test_params.subgroup_size;
491         int ng = test_params.global_workgroup_size;
492         size_t nj = (nw + ns - 1) / ns;
493         Ty tr, rr;
494         ng = ng / nw;
495 
496         for (k = 0; k < ng; ++k)
497         { // for each work_group
498             for (size_t j = 0; j < nw; ++j)
499             { // inside the work_group
500                 mx[j] = x[j]; // read host inputs for work_group
501                 my[j] = y[j]; // read device outputs for work_group
502             }
503 
504             for (size_t j = 0; j < nj; ++j)
505             { // for each subgroup
506                 ii = j * ns;
507                 n = ii + ns > nw ? nw - ii : ns;
508 
509                 for (size_t i = 0; i < n; ++i)
510                 { // inside the subgroup
511                   // shuffle index storage
512                     int midx = 4 * ii + 4 * i + 2;
513                     l = m[midx];
514                     rr = my[ii + i];
515                     cl_uint tr_idx;
516                     bool skip = false;
517                     switch (operation)
518                     {
519                         // shuffle basic - treat l as index
520                         case ShuffleOp::shuffle: tr_idx = l; break;
521                         // shuffle xor - treat l as mask
522                         case ShuffleOp::shuffle_xor: tr_idx = i ^ l; break;
523                         // shuffle up - treat l as delta
524                         case ShuffleOp::shuffle_up:
525                             if (l >= ns) skip = true;
526                             tr_idx = i - l;
527                             break;
528                         // shuffle down - treat l as delta
529                         case ShuffleOp::shuffle_down:
530                             if (l >= ns) skip = true;
531                             tr_idx = i + l;
532                             break;
533                         // rotate - treat l as delta
534                         case ShuffleOp::rotate:
535                             tr_idx = (i + l) % test_params.subgroup_size;
536                             break;
537                         case ShuffleOp::clustered_rotate: {
538                             tr_idx = ((i & ~(test_params.cluster_size - 1))
539                                       + ((i + l) % test_params.cluster_size));
540                             break;
541                         }
542                         default: break;
543                     }
544 
545                     if (!skip && tr_idx < n)
546                     {
547                         tr = mx[ii + tr_idx];
548 
549                         if (!compare(rr, tr))
550                         {
551                             log_error("ERROR: sub_group_%s(%s) mismatch for "
552                                       "local id %d in sub group %d in group "
553                                       "%d\n",
554                                       operation_names(operation),
555                                       TypeManager<Ty>::name(), i, j, k);
556                             return TEST_FAIL;
557                         }
558                     }
559                 }
560             }
561             x += nw;
562             y += nw;
563             m += 4 * nw;
564         }
565         return TEST_PASS;
566     }
567 };
568 
569 template <typename Ty, ArithmeticOp operation> struct SCEX_NU
570 {
log_testSCEX_NU571     static void log_test(const WorkGroupParams &test_params,
572                          const char *extra_text)
573     {
574         std::string func_name = (test_params.all_work_item_masks.size() > 0
575                                      ? "sub_group_non_uniform_scan_exclusive"
576                                      : "sub_group_scan_exclusive");
577         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
578                  operation_names(operation), TypeManager<Ty>::name(),
579                  extra_text);
580     }
581 
genSCEX_NU582     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
583     {
584         int nw = test_params.local_workgroup_size;
585         int ns = test_params.subgroup_size;
586         int ng = test_params.global_workgroup_size;
587         ng = ng / nw;
588         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
589     }
590 
chkSCEX_NU591     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
592                            const WorkGroupParams &test_params)
593     {
594         int ii, i, j, k, n;
595         int nw = test_params.local_workgroup_size;
596         int ns = test_params.subgroup_size;
597         int ng = test_params.global_workgroup_size;
598         bs128 work_items_mask = test_params.work_items_mask;
599         int nj = (nw + ns - 1) / ns;
600         Ty tr, rr;
601         ng = ng / nw;
602 
603         std::string func_name = (test_params.all_work_item_masks.size() > 0
604                                      ? "sub_group_non_uniform_scan_exclusive"
605                                      : "sub_group_scan_exclusive");
606 
607         // for uniform case take into consideration all workitems
608         if (!work_items_mask.any())
609         {
610             work_items_mask.set();
611         }
612         for (k = 0; k < ng; ++k)
613         { // for each work_group
614             // Map to array indexed to array indexed by local ID and sub group
615             for (j = 0; j < nw; ++j)
616             { // inside the work_group
617                 mx[j] = x[j]; // read host inputs for work_group
618                 my[j] = y[j]; // read device outputs for work_group
619             }
620             for (j = 0; j < nj; ++j)
621             {
622                 ii = j * ns;
623                 n = ii + ns > nw ? nw - ii : ns;
624                 std::set<int> active_work_items;
625                 for (i = 0; i < n; ++i)
626                 {
627                     if (work_items_mask.test(i))
628                     {
629                         active_work_items.insert(i);
630                     }
631                 }
632                 if (active_work_items.empty())
633                 {
634                     continue;
635                 }
636                 else
637                 {
638                     tr = TypeManager<Ty>::identify_limits(operation);
639                     for (const int &active_work_item : active_work_items)
640                     {
641                         rr = my[ii + active_work_item];
642                         if (!compare_ordered(rr, tr))
643                         {
644                             log_error(
645                                 "ERROR: %s_%s(%s) "
646                                 "mismatch for local id %d in sub group %d in "
647                                 "group %d %s\n",
648                                 func_name.c_str(), operation_names(operation),
649                                 TypeManager<Ty>::name(), i, j, k,
650                                 print_expected_obtained(tr, rr).c_str());
651                             return TEST_FAIL;
652                         }
653                         tr = calculate<Ty>(tr, mx[ii + active_work_item],
654                                            operation);
655                     }
656                 }
657             }
658             x += nw;
659             y += nw;
660             m += 4 * nw;
661         }
662 
663         return TEST_PASS;
664     }
665 };
666 
667 // Test for scan inclusive non uniform functions
668 template <typename Ty, ArithmeticOp operation> struct SCIN_NU
669 {
log_testSCIN_NU670     static void log_test(const WorkGroupParams &test_params,
671                          const char *extra_text)
672     {
673         std::string func_name = (test_params.all_work_item_masks.size() > 0
674                                      ? "sub_group_non_uniform_scan_inclusive"
675                                      : "sub_group_scan_inclusive");
676         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
677                  operation_names(operation), TypeManager<Ty>::name(),
678                  extra_text);
679     }
680 
genSCIN_NU681     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
682     {
683         int nw = test_params.local_workgroup_size;
684         int ns = test_params.subgroup_size;
685         int ng = test_params.global_workgroup_size;
686         ng = ng / nw;
687         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
688     }
689 
chkSCIN_NU690     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
691                            const WorkGroupParams &test_params)
692     {
693         int ii, i, j, k, n;
694         int nw = test_params.local_workgroup_size;
695         int ns = test_params.subgroup_size;
696         int ng = test_params.global_workgroup_size;
697         bs128 work_items_mask = test_params.work_items_mask;
698 
699         int nj = (nw + ns - 1) / ns;
700         Ty tr, rr;
701         ng = ng / nw;
702 
703         std::string func_name = (test_params.all_work_item_masks.size() > 0
704                                      ? "sub_group_non_uniform_scan_inclusive"
705                                      : "sub_group_scan_inclusive");
706 
707         // for uniform case take into consideration all workitems
708         if (!work_items_mask.any())
709         {
710             work_items_mask.set();
711         }
712         // std::bitset<32> mask32(use_work_items_mask);
713         // for (int k) mask32.count();
714         for (k = 0; k < ng; ++k)
715         { // for each work_group
716             // Map to array indexed to array indexed by local ID and sub group
717             for (j = 0; j < nw; ++j)
718             { // inside the work_group
719                 mx[j] = x[j]; // read host inputs for work_group
720                 my[j] = y[j]; // read device outputs for work_group
721             }
722             for (j = 0; j < nj; ++j)
723             {
724                 ii = j * ns;
725                 n = ii + ns > nw ? nw - ii : ns;
726                 std::set<int> active_work_items;
727                 int catch_frist_active = -1;
728 
729                 for (i = 0; i < n; ++i)
730                 {
731                     if (work_items_mask.test(i))
732                     {
733                         if (catch_frist_active == -1)
734                         {
735                             catch_frist_active = i;
736                         }
737                         active_work_items.insert(i);
738                     }
739                 }
740                 if (active_work_items.empty())
741                 {
742                     continue;
743                 }
744                 else
745                 {
746                     tr = TypeManager<Ty>::identify_limits(operation);
747                     for (const int &active_work_item : active_work_items)
748                     {
749                         rr = my[ii + active_work_item];
750                         if (active_work_items.size() == 1)
751                         {
752                             tr = mx[ii + catch_frist_active];
753                         }
754                         else
755                         {
756                             tr = calculate<Ty>(tr, mx[ii + active_work_item],
757                                                operation);
758                         }
759                         if (!compare_ordered<Ty>(rr, tr))
760                         {
761                             log_error(
762                                 "ERROR: %s_%s(%s) "
763                                 "mismatch for local id %d in sub group %d "
764                                 "in "
765                                 "group %d %s\n",
766                                 func_name.c_str(), operation_names(operation),
767                                 TypeManager<Ty>::name(), active_work_item, j, k,
768                                 print_expected_obtained(tr, rr).c_str());
769                             return TEST_FAIL;
770                         }
771                     }
772                 }
773             }
774             x += nw;
775             y += nw;
776             m += 4 * nw;
777         }
778 
779         return TEST_PASS;
780     }
781 };
782 
783 // Test for reduce non uniform functions
784 template <typename Ty, ArithmeticOp operation> struct RED_NU
785 {
log_testRED_NU786     static void log_test(const WorkGroupParams &test_params,
787                          const char *extra_text)
788     {
789         std::string func_name = (test_params.all_work_item_masks.size() > 0
790                                      ? "sub_group_non_uniform_reduce"
791                                      : "sub_group_reduce");
792         log_info("  %s_%s(%s)...%s\n", func_name.c_str(),
793                  operation_names(operation), TypeManager<Ty>::name(),
794                  extra_text);
795     }
796 
genRED_NU797     static void gen(Ty *x, Ty *t, cl_int *m, const WorkGroupParams &test_params)
798     {
799         int nw = test_params.local_workgroup_size;
800         int ns = test_params.subgroup_size;
801         int ng = test_params.global_workgroup_size;
802         ng = ng / nw;
803         generate_inputs<Ty, operation>(x, t, m, ns, nw, ng);
804     }
805 
chkRED_NU806     static test_status chk(Ty *x, Ty *y, Ty *mx, Ty *my, cl_int *m,
807                            const WorkGroupParams &test_params)
808     {
809         int ii, i, j, k, n;
810         int nw = test_params.local_workgroup_size;
811         int ns = test_params.subgroup_size;
812         int ng = test_params.global_workgroup_size;
813         bs128 work_items_mask = test_params.work_items_mask;
814         int nj = (nw + ns - 1) / ns;
815         ng = ng / nw;
816         Ty tr, rr;
817 
818         std::string func_name = (test_params.all_work_item_masks.size() > 0
819                                      ? "sub_group_non_uniform_reduce"
820                                      : "sub_group_reduce");
821 
822         for (k = 0; k < ng; ++k)
823         {
824             // Map to array indexed to array indexed by local ID and sub
825             // group
826             for (j = 0; j < nw; ++j)
827             {
828                 mx[j] = x[j];
829                 my[j] = y[j];
830             }
831 
832             if (!work_items_mask.any())
833             {
834                 work_items_mask.set();
835             }
836 
837             for (j = 0; j < nj; ++j)
838             {
839                 ii = j * ns;
840                 n = ii + ns > nw ? nw - ii : ns;
841                 std::set<int> active_work_items;
842                 int catch_frist_active = -1;
843                 for (i = 0; i < n; ++i)
844                 {
845                     if (work_items_mask.test(i))
846                     {
847                         if (catch_frist_active == -1)
848                         {
849                             catch_frist_active = i;
850                             tr = mx[ii + i];
851                             active_work_items.insert(i);
852                             continue;
853                         }
854                         active_work_items.insert(i);
855                         tr = calculate<Ty>(tr, mx[ii + i], operation);
856                     }
857                 }
858 
859                 if (active_work_items.empty())
860                 {
861                     continue;
862                 }
863 
864                 for (const int &active_work_item : active_work_items)
865                 {
866                     rr = my[ii + active_work_item];
867                     if (!compare_ordered<Ty>(rr, tr))
868                     {
869                         log_error("ERROR: %s_%s(%s) "
870                                   "mismatch for local id %d in sub group %d in "
871                                   "group %d %s\n",
872                                   func_name.c_str(), operation_names(operation),
873                                   TypeManager<Ty>::name(), active_work_item, j,
874                                   k, print_expected_obtained(tr, rr).c_str());
875                         return TEST_FAIL;
876                     }
877                 }
878             }
879             x += nw;
880             y += nw;
881             m += 4 * nw;
882         }
883 
884         return TEST_PASS;
885     }
886 };
887 
888 #endif
889