1 //
2 // Copyright (c) 2022 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #include "procs.h"
17 #include "subhelpers.h"
18 #include "subgroup_common_kernels.h"
19 #include "subgroup_common_templates.h"
20 #include "harness/conversions.h"
21 #include "harness/typeWrappers.h"
22
23 namespace {
24
run_rotate_for_type(RunTestForType rft)25 template <typename T> int run_rotate_for_type(RunTestForType rft)
26 {
27 int error = rft.run_impl<T, SHF<T, ShuffleOp::rotate>>("sub_group_rotate");
28 return error;
29 }
30
31 std::string sub_group_clustered_rotate_source = R"(
32 __kernel void test_%s(const __global Type *in, __global int4 *xy, __global Type *out,
33 uint cluster_size) {
34 Type r;
35 int gid = get_global_id(0);
36 XY(xy,gid);
37 Type x = in[gid];
38 int delta = xy[gid].z;
39 switch (cluster_size) {
40 case 1: r = %s(x, delta, 1); break;
41 case 2: r = %s(x, delta, 2); break;
42 case 4: r = %s(x, delta, 4); break;
43 case 8: r = %s(x, delta, 8); break;
44 case 16: r = %s(x, delta, 16); break;
45 case 32: r = %s(x, delta, 32); break;
46 case 64: r = %s(x, delta, 64); break;
47 case 128: r = %s(x, delta, 128); break;
48 }
49 out[gid] = r;
50 }
51 )";
52
run_clustered_rotate_for_type(RunTestForType rft)53 template <typename T> int run_clustered_rotate_for_type(RunTestForType rft)
54 {
55 int error = rft.run_impl<T, SHF<T, ShuffleOp::clustered_rotate>>(
56 "sub_group_clustered_rotate");
57 return error;
58 }
59
60 }
61
test_subgroup_functions_rotate(cl_device_id device,cl_context context,cl_command_queue queue,int num_elements)62 int test_subgroup_functions_rotate(cl_device_id device, cl_context context,
63 cl_command_queue queue, int num_elements)
64 {
65 if (!is_extension_available(device, "cl_khr_subgroup_rotate"))
66 {
67 log_info("cl_khr_subgroup_rotate is not supported on this device, "
68 "skipping test.\n");
69 return TEST_SKIPPED_ITSELF;
70 }
71
72 constexpr size_t global_work_size = 2000;
73 constexpr size_t local_work_size = 200;
74 WorkGroupParams test_params(global_work_size, local_work_size);
75 test_params.save_kernel_source(sub_group_generic_source);
76 RunTestForType rft(device, context, queue, num_elements, test_params);
77
78 int error = run_rotate_for_type<cl_int>(rft);
79 error |= run_rotate_for_type<cl_uint>(rft);
80 error |= run_rotate_for_type<cl_long>(rft);
81 error |= run_rotate_for_type<cl_ulong>(rft);
82 error |= run_rotate_for_type<cl_short>(rft);
83 error |= run_rotate_for_type<cl_ushort>(rft);
84 error |= run_rotate_for_type<cl_char>(rft);
85 error |= run_rotate_for_type<cl_uchar>(rft);
86 error |= run_rotate_for_type<cl_float>(rft);
87 error |= run_rotate_for_type<cl_double>(rft);
88 error |= run_rotate_for_type<subgroups::cl_half>(rft);
89
90 WorkGroupParams test_params_clustered(global_work_size, local_work_size, -1,
91 3);
92 test_params_clustered.save_kernel_source(sub_group_clustered_rotate_source);
93 RunTestForType rft_clustered(device, context, queue, num_elements,
94 test_params_clustered);
95
96 error |= run_clustered_rotate_for_type<cl_int>(rft_clustered);
97 error |= run_clustered_rotate_for_type<cl_uint>(rft_clustered);
98 error |= run_clustered_rotate_for_type<cl_long>(rft_clustered);
99 error |= run_clustered_rotate_for_type<cl_ulong>(rft_clustered);
100 error |= run_clustered_rotate_for_type<cl_short>(rft_clustered);
101 error |= run_clustered_rotate_for_type<cl_ushort>(rft_clustered);
102 error |= run_clustered_rotate_for_type<cl_char>(rft_clustered);
103 error |= run_clustered_rotate_for_type<cl_uchar>(rft_clustered);
104 error |= run_clustered_rotate_for_type<cl_float>(rft_clustered);
105 error |= run_clustered_rotate_for_type<cl_double>(rft_clustered);
106 error |= run_clustered_rotate_for_type<subgroups::cl_half>(rft_clustered);
107
108 return error;
109 }
110