xref: /aosp_15_r20/external/OpenCL-CTS/test_conformance/subgroups/test_subgroup_non_uniform_vote.cpp (revision 6467f958c7de8070b317fc65bcb0f6472e388d82)
1 //
2 // Copyright (c) 2021 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //    http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "harness/typeWrappers.h"
19 #include <set>
20 
21 namespace {
22 
23 template <typename T, NonUniformVoteOp operation> struct VOTE
24 {
log_test__anon825482cd0111::VOTE25     static void log_test(const WorkGroupParams &test_params,
26                          const char *extra_text)
27     {
28         log_info("  sub_group_%s%s(%s)...%s\n",
29                  (operation == NonUniformVoteOp::elect) ? "" : "non_uniform_",
30                  operation_names(operation), TypeManager<T>::name(),
31                  extra_text);
32     }
33 
gen__anon825482cd0111::VOTE34     static void gen(T *x, T *t, cl_int *m, const WorkGroupParams &test_params)
35     {
36         int i, ii, j, k, n;
37         int nw = test_params.local_workgroup_size;
38         int ns = test_params.subgroup_size;
39         int ng = test_params.global_workgroup_size;
40         int nj = (nw + ns - 1) / ns;
41         int non_uniform_size = ng % nw;
42         ng = ng / nw;
43         int last_subgroup_size = 0;
44         ii = 0;
45 
46         if (operation == NonUniformVoteOp::elect) return;
47 
48         for (k = 0; k < ng; ++k)
49         { // for each work_group
50             if (non_uniform_size && k == ng - 1)
51             {
52                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
53                                           last_subgroup_size);
54             }
55             for (j = 0; j < nj; ++j)
56             { // for each subgroup
57                 ii = j * ns;
58                 if (last_subgroup_size && j == nj - 1)
59                 {
60                     n = last_subgroup_size;
61                 }
62                 else
63                 {
64                     n = ii + ns > nw ? nw - ii : ns;
65                 }
66                 int e = genrand_int32(gMTdata) % 3;
67 
68                 for (i = 0; i < n; i++)
69                 {
70                     if (e == 2)
71                     { // set once 0 and once 1 alternately
72                         int value = i % 2;
73                         set_value(t[ii + i], value);
74                     }
75                     else
76                     { // set 0/1 for all work items in subgroup
77                         set_value(t[ii + i], e);
78                     }
79                 }
80             }
81             // Now map into work group using map from device
82             for (j = 0; j < nw; ++j)
83             {
84                 x[j] = t[j];
85             }
86             x += nw;
87             m += 4 * nw;
88         }
89     }
90 
chk__anon825482cd0111::VOTE91     static test_status chk(T *x, T *y, T *mx, T *my, cl_int *m,
92                            const WorkGroupParams &test_params)
93     {
94         int ii, i, j, k, n;
95         int nw = test_params.local_workgroup_size;
96         int ns = test_params.subgroup_size;
97         int ng = test_params.global_workgroup_size;
98         int nj = (nw + ns - 1) / ns;
99         cl_int tr, rr;
100         int non_uniform_size = ng % nw;
101         ng = ng / nw;
102         if (non_uniform_size) ng++;
103         int last_subgroup_size = 0;
104 
105         for (k = 0; k < ng; ++k)
106         { // for each work_group
107             if (non_uniform_size && k == ng - 1)
108             {
109                 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
110                                           last_subgroup_size);
111             }
112             for (j = 0; j < nw; ++j)
113             { // inside the work_group
114                 mx[j] = x[j]; // read host inputs for work_group
115                 my[j] = y[j]; // read device outputs for work_group
116             }
117 
118             for (j = 0; j < nj; ++j)
119             { // for each subgroup
120                 ii = j * ns;
121                 if (last_subgroup_size && j == nj - 1)
122                 {
123                     n = last_subgroup_size;
124                 }
125                 else
126                 {
127                     n = ii + ns > nw ? nw - ii : ns;
128                 }
129 
130                 rr = 0;
131                 if (operation == NonUniformVoteOp::all
132                     || operation == NonUniformVoteOp::all_equal)
133                     tr = 1;
134                 if (operation == NonUniformVoteOp::any) tr = 0;
135 
136                 std::set<int> active_work_items;
137                 for (i = 0; i < n; ++i)
138                 {
139                     if (test_params.work_items_mask.test(i))
140                     {
141                         active_work_items.insert(i);
142                         switch (operation)
143                         {
144                             case NonUniformVoteOp::elect: break;
145 
146                             case NonUniformVoteOp::all:
147                                 tr &=
148                                     !compare_ordered<T>(mx[ii + i], 0) ? 1 : 0;
149                                 break;
150                             case NonUniformVoteOp::any:
151                                 tr |=
152                                     !compare_ordered<T>(mx[ii + i], 0) ? 1 : 0;
153                                 break;
154                             case NonUniformVoteOp::all_equal:
155                                 tr &= compare_ordered<T>(
156                                           mx[ii + i],
157                                           mx[ii + *active_work_items.begin()])
158                                     ? 1
159                                     : 0;
160                                 break;
161                             default:
162                                 log_error("Unknown operation\n");
163                                 return TEST_FAIL;
164                         }
165                     }
166                 }
167                 if (active_work_items.empty())
168                 {
169                     continue;
170                 }
171                 auto lowest_active = active_work_items.begin();
172                 for (const int &active_work_item : active_work_items)
173                 {
174                     i = active_work_item;
175                     if (operation == NonUniformVoteOp::elect)
176                     {
177                         i == *lowest_active ? tr = 1 : tr = 0;
178                     }
179 
180                     // normalize device values on host, non zero set 1.
181                     rr = compare_ordered<T>(my[ii + i], 0) ? 0 : 1;
182 
183                     if (rr != tr)
184                     {
185                         log_error("ERROR: sub_group_%s() \n",
186                                   operation_names(operation));
187                         log_error("mismatch for work item %d sub group %d in "
188                                   "work group %d. Expected: %d Obtained: %d\n",
189                                   i, j, k, tr, rr);
190                         return TEST_FAIL;
191                     }
192                 }
193             }
194 
195             x += nw;
196             y += nw;
197             m += 4 * nw;
198         }
199 
200         return TEST_PASS;
201     }
202 };
203 
// OpenCL C source for the sub_group_elect test. Each work item whose
// sub-group-local bit is set in work_item_mask_vector (x: ids 0-31,
// y: 32-63, z: 64-95, w: 96-127) stores sub_group_elect() into out[gid].
// Note: `1u` avoids the signed-overflow UB of `1 << 31`. The mask starts at 0
// so ids >= 128 never read an uninitialized value (they are simply inactive).
std::string sub_group_elect_source = R"(
    __kernel void test_sub_group_elect(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
        int gid = get_global_id(0);
        XY(xy,gid);
        uint subgroup_local_id = get_sub_group_local_id();
        uint elect_work_item = 1u << (subgroup_local_id % 32);
        uint work_item_mask = 0;
        if(subgroup_local_id < 32) {
            work_item_mask = work_item_mask_vector.x;
        } else if(subgroup_local_id < 64) {
            work_item_mask = work_item_mask_vector.y;
        } else if(subgroup_local_id < 96) {
            work_item_mask = work_item_mask_vector.z;
        } else if(subgroup_local_id < 128) {
            work_item_mask = work_item_mask_vector.w;
        }
        if (elect_work_item & work_item_mask){
            out[gid] = sub_group_elect();
        }
    }
)";
225 
// OpenCL C source template shared by the non_uniform all/any/all_equal
// tests. The two %s placeholders are filled in by the harness (kernel-name
// suffix and builtin name). Each work item whose sub-group-local bit is set in
// work_item_mask_vector (x: ids 0-31, y: 32-63, z: 64-95, w: 96-127) stores
// the builtin's result for in[gid] into out[gid].
// Note: `1u` avoids the signed-overflow UB of `1 << 31`. The mask starts at 0
// so ids >= 128 never read an uninitialized value (they are simply inactive).
std::string sub_group_non_uniform_any_all_all_equal_source = R"(
    __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
        int gid = get_global_id(0);
        XY(xy,gid);
        uint subgroup_local_id = get_sub_group_local_id();
        uint elect_work_item = 1u << (subgroup_local_id % 32);
        uint work_item_mask = 0;
        if(subgroup_local_id < 32) {
            work_item_mask = work_item_mask_vector.x;
        } else if(subgroup_local_id < 64) {
            work_item_mask = work_item_mask_vector.y;
        } else if(subgroup_local_id < 96) {
            work_item_mask = work_item_mask_vector.z;
        } else if(subgroup_local_id < 128) {
            work_item_mask = work_item_mask_vector.w;
        }
        if (elect_work_item & work_item_mask){
                out[gid] = %s(in[gid]);
            }
    }
)";
247 
run_vote_all_equal_for_type(RunTestForType rft)248 template <typename T> int run_vote_all_equal_for_type(RunTestForType rft)
249 {
250     int error = rft.run_impl<T, VOTE<T, NonUniformVoteOp::all_equal>>(
251         "sub_group_non_uniform_all_equal");
252     return error;
253 }
254 }
255 
test_subgroup_functions_non_uniform_vote(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)256 int test_subgroup_functions_non_uniform_vote(cl_device_id device,
257                                              cl_context context,
258                                              cl_command_queue queue,
259                                              int num_elements)
260 {
261     if (!is_extension_available(device, "cl_khr_subgroup_non_uniform_vote"))
262     {
263         log_info("cl_khr_subgroup_non_uniform_vote is not supported on this "
264                  "device, skipping test.\n");
265         return TEST_SKIPPED_ITSELF;
266     }
267 
268     constexpr size_t global_work_size = 170;
269     constexpr size_t local_work_size = 64;
270     WorkGroupParams test_params(global_work_size, local_work_size, 3);
271     test_params.save_kernel_source(
272         sub_group_non_uniform_any_all_all_equal_source);
273     test_params.save_kernel_source(sub_group_elect_source, "sub_group_elect");
274     RunTestForType rft(device, context, queue, num_elements, test_params);
275 
276     int error = run_vote_all_equal_for_type<cl_int>(rft);
277     error |= run_vote_all_equal_for_type<cl_uint>(rft);
278     error |= run_vote_all_equal_for_type<cl_long>(rft);
279     error |= run_vote_all_equal_for_type<cl_ulong>(rft);
280     error |= run_vote_all_equal_for_type<cl_float>(rft);
281     error |= run_vote_all_equal_for_type<cl_double>(rft);
282     error |= run_vote_all_equal_for_type<subgroups::cl_half>(rft);
283 
284     error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::all>>(
285         "sub_group_non_uniform_all");
286     error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::elect>>(
287         "sub_group_elect");
288     error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::any>>(
289         "sub_group_non_uniform_any");
290     return error;
291 }
292