//
// Copyright (c) 2021 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "harness/typeWrappers.h"
19 #include <set>
20
21 namespace {
22
23 template <typename T, NonUniformVoteOp operation> struct VOTE
24 {
log_test__anon825482cd0111::VOTE25 static void log_test(const WorkGroupParams &test_params,
26 const char *extra_text)
27 {
28 log_info(" sub_group_%s%s(%s)...%s\n",
29 (operation == NonUniformVoteOp::elect) ? "" : "non_uniform_",
30 operation_names(operation), TypeManager<T>::name(),
31 extra_text);
32 }
33
gen__anon825482cd0111::VOTE34 static void gen(T *x, T *t, cl_int *m, const WorkGroupParams &test_params)
35 {
36 int i, ii, j, k, n;
37 int nw = test_params.local_workgroup_size;
38 int ns = test_params.subgroup_size;
39 int ng = test_params.global_workgroup_size;
40 int nj = (nw + ns - 1) / ns;
41 int non_uniform_size = ng % nw;
42 ng = ng / nw;
43 int last_subgroup_size = 0;
44 ii = 0;
45
46 if (operation == NonUniformVoteOp::elect) return;
47
48 for (k = 0; k < ng; ++k)
49 { // for each work_group
50 if (non_uniform_size && k == ng - 1)
51 {
52 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
53 last_subgroup_size);
54 }
55 for (j = 0; j < nj; ++j)
56 { // for each subgroup
57 ii = j * ns;
58 if (last_subgroup_size && j == nj - 1)
59 {
60 n = last_subgroup_size;
61 }
62 else
63 {
64 n = ii + ns > nw ? nw - ii : ns;
65 }
66 int e = genrand_int32(gMTdata) % 3;
67
68 for (i = 0; i < n; i++)
69 {
70 if (e == 2)
71 { // set once 0 and once 1 alternately
72 int value = i % 2;
73 set_value(t[ii + i], value);
74 }
75 else
76 { // set 0/1 for all work items in subgroup
77 set_value(t[ii + i], e);
78 }
79 }
80 }
81 // Now map into work group using map from device
82 for (j = 0; j < nw; ++j)
83 {
84 x[j] = t[j];
85 }
86 x += nw;
87 m += 4 * nw;
88 }
89 }
90
chk__anon825482cd0111::VOTE91 static test_status chk(T *x, T *y, T *mx, T *my, cl_int *m,
92 const WorkGroupParams &test_params)
93 {
94 int ii, i, j, k, n;
95 int nw = test_params.local_workgroup_size;
96 int ns = test_params.subgroup_size;
97 int ng = test_params.global_workgroup_size;
98 int nj = (nw + ns - 1) / ns;
99 cl_int tr, rr;
100 int non_uniform_size = ng % nw;
101 ng = ng / nw;
102 if (non_uniform_size) ng++;
103 int last_subgroup_size = 0;
104
105 for (k = 0; k < ng; ++k)
106 { // for each work_group
107 if (non_uniform_size && k == ng - 1)
108 {
109 set_last_workgroup_params(non_uniform_size, nj, ns, nw,
110 last_subgroup_size);
111 }
112 for (j = 0; j < nw; ++j)
113 { // inside the work_group
114 mx[j] = x[j]; // read host inputs for work_group
115 my[j] = y[j]; // read device outputs for work_group
116 }
117
118 for (j = 0; j < nj; ++j)
119 { // for each subgroup
120 ii = j * ns;
121 if (last_subgroup_size && j == nj - 1)
122 {
123 n = last_subgroup_size;
124 }
125 else
126 {
127 n = ii + ns > nw ? nw - ii : ns;
128 }
129
130 rr = 0;
131 if (operation == NonUniformVoteOp::all
132 || operation == NonUniformVoteOp::all_equal)
133 tr = 1;
134 if (operation == NonUniformVoteOp::any) tr = 0;
135
136 std::set<int> active_work_items;
137 for (i = 0; i < n; ++i)
138 {
139 if (test_params.work_items_mask.test(i))
140 {
141 active_work_items.insert(i);
142 switch (operation)
143 {
144 case NonUniformVoteOp::elect: break;
145
146 case NonUniformVoteOp::all:
147 tr &=
148 !compare_ordered<T>(mx[ii + i], 0) ? 1 : 0;
149 break;
150 case NonUniformVoteOp::any:
151 tr |=
152 !compare_ordered<T>(mx[ii + i], 0) ? 1 : 0;
153 break;
154 case NonUniformVoteOp::all_equal:
155 tr &= compare_ordered<T>(
156 mx[ii + i],
157 mx[ii + *active_work_items.begin()])
158 ? 1
159 : 0;
160 break;
161 default:
162 log_error("Unknown operation\n");
163 return TEST_FAIL;
164 }
165 }
166 }
167 if (active_work_items.empty())
168 {
169 continue;
170 }
171 auto lowest_active = active_work_items.begin();
172 for (const int &active_work_item : active_work_items)
173 {
174 i = active_work_item;
175 if (operation == NonUniformVoteOp::elect)
176 {
177 i == *lowest_active ? tr = 1 : tr = 0;
178 }
179
180 // normalize device values on host, non zero set 1.
181 rr = compare_ordered<T>(my[ii + i], 0) ? 0 : 1;
182
183 if (rr != tr)
184 {
185 log_error("ERROR: sub_group_%s() \n",
186 operation_names(operation));
187 log_error("mismatch for work item %d sub group %d in "
188 "work group %d. Expected: %d Obtained: %d\n",
189 i, j, k, tr, rr);
190 return TEST_FAIL;
191 }
192 }
193 }
194
195 x += nw;
196 y += nw;
197 m += 4 * nw;
198 }
199
200 return TEST_PASS;
201 }
202 };
203
// OpenCL C source of the kernel that exercises sub_group_elect().
//
// work_item_mask_vector carries one bit per sub-group-local work-item id
// (.x covers ids 0-31, .y 32-63, .z 64-95, .w 96-127). Only work-items whose
// bit is set call the built-in and store its result to out[gid].
// work_item_mask starts at 0 so that a work-item with a sub-group local id of
// 128 or more - which no component of the vector covers - is treated as
// inactive instead of reading an uninitialised variable.
std::string sub_group_elect_source = R"(
    __kernel void test_sub_group_elect(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
        int gid = get_global_id(0);
        XY(xy,gid);
        uint subgroup_local_id = get_sub_group_local_id();
        uint elect_work_item = 1 << (subgroup_local_id % 32);
        uint work_item_mask = 0;
        if(subgroup_local_id < 32) {
            work_item_mask = work_item_mask_vector.x;
        } else if(subgroup_local_id < 64) {
            work_item_mask = work_item_mask_vector.y;
        } else if(subgroup_local_id < 96) {
            work_item_mask = work_item_mask_vector.z;
        } else if(subgroup_local_id < 128) {
            work_item_mask = work_item_mask_vector.w;
        }
        if (elect_work_item & work_item_mask){
            out[gid] = sub_group_elect();
        }
    }
)";
225
// OpenCL C source template shared by the non-uniform any / all / all_equal
// test kernels. It contains two %s placeholders: the first completes the
// kernel name ("test_%s"), the second is the built-in that is called on
// in[gid].
//
// work_item_mask_vector carries one bit per sub-group-local work-item id
// (.x covers ids 0-31, .y 32-63, .z 64-95, .w 96-127). Only work-items whose
// bit is set call the built-in and store its result to out[gid].
// work_item_mask starts at 0 so that a work-item with a sub-group local id of
// 128 or more - which no component of the vector covers - is treated as
// inactive instead of reading an uninitialised variable.
std::string sub_group_non_uniform_any_all_all_equal_source = R"(
    __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out, uint4 work_item_mask_vector) {
        int gid = get_global_id(0);
        XY(xy,gid);
        uint subgroup_local_id = get_sub_group_local_id();
        uint elect_work_item = 1 << (subgroup_local_id % 32);
        uint work_item_mask = 0;
        if(subgroup_local_id < 32) {
            work_item_mask = work_item_mask_vector.x;
        } else if(subgroup_local_id < 64) {
            work_item_mask = work_item_mask_vector.y;
        } else if(subgroup_local_id < 96) {
            work_item_mask = work_item_mask_vector.z;
        } else if(subgroup_local_id < 128) {
            work_item_mask = work_item_mask_vector.w;
        }
        if (elect_work_item & work_item_mask){
            out[gid] = %s(in[gid]);
        }
    }
)";
247
run_vote_all_equal_for_type(RunTestForType rft)248 template <typename T> int run_vote_all_equal_for_type(RunTestForType rft)
249 {
250 int error = rft.run_impl<T, VOTE<T, NonUniformVoteOp::all_equal>>(
251 "sub_group_non_uniform_all_equal");
252 return error;
253 }
254 }
255
test_subgroup_functions_non_uniform_vote(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)256 int test_subgroup_functions_non_uniform_vote(cl_device_id device,
257 cl_context context,
258 cl_command_queue queue,
259 int num_elements)
260 {
261 if (!is_extension_available(device, "cl_khr_subgroup_non_uniform_vote"))
262 {
263 log_info("cl_khr_subgroup_non_uniform_vote is not supported on this "
264 "device, skipping test.\n");
265 return TEST_SKIPPED_ITSELF;
266 }
267
268 constexpr size_t global_work_size = 170;
269 constexpr size_t local_work_size = 64;
270 WorkGroupParams test_params(global_work_size, local_work_size, 3);
271 test_params.save_kernel_source(
272 sub_group_non_uniform_any_all_all_equal_source);
273 test_params.save_kernel_source(sub_group_elect_source, "sub_group_elect");
274 RunTestForType rft(device, context, queue, num_elements, test_params);
275
276 int error = run_vote_all_equal_for_type<cl_int>(rft);
277 error |= run_vote_all_equal_for_type<cl_uint>(rft);
278 error |= run_vote_all_equal_for_type<cl_long>(rft);
279 error |= run_vote_all_equal_for_type<cl_ulong>(rft);
280 error |= run_vote_all_equal_for_type<cl_float>(rft);
281 error |= run_vote_all_equal_for_type<cl_double>(rft);
282 error |= run_vote_all_equal_for_type<subgroups::cl_half>(rft);
283
284 error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::all>>(
285 "sub_group_non_uniform_all");
286 error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::elect>>(
287 "sub_group_elect");
288 error |= rft.run_impl<cl_int, VOTE<cl_int, NonUniformVoteOp::any>>(
289 "sub_group_non_uniform_any");
290 return error;
291 }
292