xref: /aosp_15_r20/external/angle/src/libANGLE/renderer/vulkan/CLKernelVk.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // CLKernelVk.cpp: Implements the class methods for CLKernelVk.
7 
8 #include "common/PackedEnums.h"
9 
10 #include "libANGLE/renderer/vulkan/CLContextVk.h"
11 #include "libANGLE/renderer/vulkan/CLDeviceVk.h"
12 #include "libANGLE/renderer/vulkan/CLKernelVk.h"
13 #include "libANGLE/renderer/vulkan/CLProgramVk.h"
14 #include "libANGLE/renderer/vulkan/vk_wrapper.h"
15 
16 #include "libANGLE/CLContext.h"
17 #include "libANGLE/CLKernel.h"
18 #include "libANGLE/CLProgram.h"
19 #include "libANGLE/cl_utils.h"
20 #include "spirv/unified1/NonSemanticClspvReflection.h"
21 
22 namespace rx
23 {
24 
// Constructs the Vulkan backend object for a CL kernel. Caches pointers to the
// owning program/context implementations and immediately registers the program's
// compute shader module with the shader-program helper so pipeline creation can
// find it later.
CLKernelVk::CLKernelVk(const cl::Kernel &kernel,
                       std::string &name,
                       std::string &attributes,
                       CLKernelArguments &args)
    : CLKernelImpl(kernel),
      mProgram(&kernel.getProgram().getImpl<CLProgramVk>()),
      mContext(&kernel.getProgram().getContext().getImpl<CLContextVk>()),
      mName(name),
      mAttributes(attributes),
      mArgs(args)
{
    // CL kernels are always compute-stage shaders in the Vulkan backend.
    mShaderProgramHelper.setShader(gl::ShaderType::Compute,
                                   mKernel.getProgram().getImpl<CLProgramVk>().getShaderModule());
}
39 
// Releases all Vulkan objects owned by this kernel: descriptor set layouts,
// the pipeline layout, any cached compute pipelines, and the shader program
// helper. Layouts are released before the pipelines/program that reference them.
CLKernelVk::~CLKernelVk()
{
    for (auto &dsLayouts : mDescriptorSetLayouts)
    {
        dsLayouts.reset();
    }

    mPipelineLayout.reset();
    for (auto &pipelineHelper : mComputePipelineCache)
    {
        pipelineHelper.destroy(mContext->getDevice());
    }
    mShaderProgramHelper.destroy(mContext->getRenderer());
}
54 
// Builds the descriptor set layout descriptions, pipeline layout description,
// and push constant range for this kernel from the clspv reflection data of
// its arguments. Must be called before any pipeline is requested.
angle::Result CLKernelVk::init()
{
    vk::DescriptorSetLayoutDesc &descriptorSetLayoutDesc =
        mDescriptorSetLayoutDescs[DescriptorSetIndex::KernelArguments];
    // Start from the program-wide push constant range; POD push-constant args
    // found below may extend it.
    VkPushConstantRange pcRange = mProgram->getDeviceProgramData(mName.c_str())->pushConstRange;
    for (const auto &arg : getArgs())
    {
        // Map each clspv reflection argument kind to the Vulkan descriptor
        // type it binds as; args that don't consume a descriptor (push
        // constants, unknown kinds) are skipped via `continue`.
        VkDescriptorType descType = VK_DESCRIPTOR_TYPE_MAX_ENUM;
        switch (arg.type)
        {
            case NonSemanticClspvReflectionArgumentStorageBuffer:
            case NonSemanticClspvReflectionArgumentPodStorageBuffer:
                descType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
                break;
            case NonSemanticClspvReflectionArgumentUniform:
            case NonSemanticClspvReflectionArgumentPodUniform:
            case NonSemanticClspvReflectionArgumentPointerUniform:
                descType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
                break;
            case NonSemanticClspvReflectionArgumentPodPushConstant:
                // Get existing push constant range and see if we need to update
                if (arg.pushConstOffset + arg.pushConstantSize > pcRange.offset + pcRange.size)
                {
                    pcRange.size = arg.pushConstOffset + arg.pushConstantSize - pcRange.offset;
                }
                continue;
            case NonSemanticClspvReflectionArgumentSampledImage:
                descType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
                break;
            case NonSemanticClspvReflectionArgumentStorageImage:
                descType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
                break;
            case NonSemanticClspvReflectionArgumentSampler:
                descType = VK_DESCRIPTOR_TYPE_SAMPLER;
                break;
            case NonSemanticClspvReflectionArgumentStorageTexelBuffer:
                descType = VK_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER;
                break;
            case NonSemanticClspvReflectionArgumentUniformTexelBuffer:
                descType = VK_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER;
                break;
            default:
                continue;
        }
        if (descType != VK_DESCRIPTOR_TYPE_MAX_ENUM)
        {
            descriptorSetLayoutDesc.addBinding(arg.descriptorBinding, descType, 1,
                                               VK_SHADER_STAGE_COMPUTE_BIT, nullptr);
        }
    }

    // Printf output goes through a dedicated storage buffer in its own
    // descriptor set, at the binding reported by reflection.
    if (usesPrintf())
    {
        mDescriptorSetLayoutDescs[DescriptorSetIndex::Printf].addBinding(
            mProgram->getDeviceProgramData(mName.c_str())
                ->reflectionData.printfBufferStorage.binding,
            VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr);
    }

    // Get pipeline layout from cache (creates if missed)
    // A given kernel need not have resulted in use of all the descriptor sets. Unless the
    // graphicsPipelineLibrary extension is supported, the pipeline layout needs all the descriptor
    // set layouts to be valid. So set them up in the order of their occurrence.
    mPipelineLayoutDesc = {};
    for (DescriptorSetIndex index : angle::AllEnums<DescriptorSetIndex>())
    {
        if (!mDescriptorSetLayoutDescs[index].empty())
        {
            mPipelineLayoutDesc.updateDescriptorSetLayout(index, mDescriptorSetLayoutDescs[index]);
        }
    }

    // push constant setup
    // push constant size must be multiple of 4
    pcRange.size = roundUpPow2(pcRange.size, 4u);
    // set the pod arguments data to this size
    mPodArgumentsData.resize(pcRange.size);

    // push constant offset must be multiple of 4, round down to ensure this
    // NOTE(review): size is rounded up before the offset is rounded down; if the
    // offset is ever not 4-aligned the rounded range may no longer cover the
    // original [offset, offset+size) span — confirm reflection always emits
    // 4-aligned offsets.
    pcRange.offset = roundDownPow2(pcRange.offset, 4u);

    mPipelineLayoutDesc.updatePushConstantRange(pcRange.stageFlags, pcRange.offset, pcRange.size);

    return angle::Result::Continue;
}
140 
// Stores the value for kernel argument `argIndex` (clSetKernelArg semantics).
// POD push-constant data is copied immediately because the caller is free to
// release `argValue` after this call; __local (Workgroup) args instead record
// a specialization constant holding the element count.
angle::Result CLKernelVk::setArg(cl_uint argIndex, size_t argSize, const void *argValue)
{
    auto &arg = mArgs.at(argIndex);
    if (arg.used)
    {
        // Non-POD args are kept by handle only; lifetime is the caller's
        // responsibility (e.g. cl_mem objects are retained elsewhere).
        arg.handle     = const_cast<void *>(argValue);
        arg.handleSize = argSize;

        // For POD data, copy the contents as the app is free to delete the contents post this call.
        if (arg.type == NonSemanticClspvReflectionArgumentPodPushConstant && argSize > 0 &&
            argValue != nullptr)
        {
            ASSERT(mPodArgumentsData.size() >= arg.pushConstantSize + arg.pushConstOffset);
            memcpy(&mPodArgumentsData[arg.pushConstOffset], argValue, argSize);
        }

        // Local-memory argument: argSize is the buffer size in bytes; divide by
        // the element size from reflection to get the spec-constant value.
        // NOTE(review): assumes arg.workgroupSize is non-zero for Workgroup
        // args — confirm the reflection parser guarantees this.
        if (arg.type == NonSemanticClspvReflectionArgumentWorkgroup)
        {
            mSpecConstants.push_back(
                KernelSpecConstant{.ID   = arg.workgroupSpecId,
                                   .data = static_cast<uint32_t>(argSize / arg.workgroupSize)});
        }
    }

    return angle::Result::Continue;
}
167 
// Fills in the clGetKernelInfo / clGetKernelArgInfo / clGetKernelWorkGroupInfo
// data for this kernel: name, attributes, per-argument info, and per-device
// work-group properties.
angle::Result CLKernelVk::createInfo(CLKernelImpl::Info *info) const
{
    info->functionName = mName;
    info->attributes   = mAttributes;
    info->numArgs      = static_cast<cl_uint>(mArgs.size());
    for (const auto &arg : mArgs)
    {
        ArgInfo argInfo;
        argInfo.name             = arg.info.name;
        argInfo.typeName         = arg.info.typeName;
        argInfo.accessQualifier  = arg.info.accessQualifier;
        argInfo.addressQualifier = arg.info.addressQualifier;
        argInfo.typeQualifier    = arg.info.typeQualifier;
        info->args.push_back(std::move(argInfo));
    }

    // One workGroups entry per device in the context; devices the program was
    // not built for are skipped (their entry keeps default values).
    auto &ctx = mKernel.getProgram().getContext();
    info->workGroups.resize(ctx.getDevices().size());
    const CLProgramVk::DeviceProgramData *deviceProgramData = nullptr;
    for (auto i = 0u; i < ctx.getDevices().size(); ++i)
    {
        auto &workGroup     = info->workGroups[i];
        const auto deviceVk = &ctx.getDevices()[i]->getImpl<CLDeviceVk>();
        deviceProgramData   = mProgram->getDeviceProgramData(ctx.getDevices()[i]->getNative());
        if (deviceProgramData == nullptr)
        {
            continue;
        }

        // TODO: http://anglebug.com/42267005
        ANGLE_TRY(
            deviceVk->getInfoSizeT(cl::DeviceInfo::MaxWorkGroupSize, &workGroup.workGroupSize));

        // TODO: http://anglebug.com/42267004
        workGroup.privateMemSize = 0;
        workGroup.localMemSize   = 0;

        // NOTE(review): preferred multiple is hard-coded to 16 rather than
        // queried from the device — placeholder value, see TODOs above.
        workGroup.prefWorkGroupSizeMultiple = 16u;
        workGroup.globalWorkSize            = {0, 0, 0};
        // Report the reqd_work_group_size kernel attribute if the compiler
        // recorded one, else all-zeros per the CL spec.
        if (deviceProgramData->reflectionData.kernelCompileWorkgroupSize.contains(mName))
        {
            workGroup.compileWorkGroupSize = {
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[0],
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[1],
                deviceProgramData->reflectionData.kernelCompileWorkgroupSize.at(mName)[2]};
        }
        else
        {
            workGroup.compileWorkGroupSize = {0, 0, 0};
        }
    }

    return angle::Result::Continue;
}
222 
// Resolves the work-group size for this dispatch, computes the work-group
// count, packs all specialization constants (program-level reflection ones
// plus per-kernel __local sizes), and fetches or creates the matching compute
// pipeline from the cache.
angle::Result CLKernelVk::getOrCreateComputePipeline(vk::PipelineCacheAccess *pipelineCache,
                                                     const cl::NDRange &ndrange,
                                                     const cl::Device &device,
                                                     vk::PipelineHelper **pipelineOut,
                                                     cl::WorkgroupCount *workgroupCountOut)
{
    const CLProgramVk::DeviceProgramData *devProgramData =
        getProgram()->getDeviceProgramData(device.getNative());
    ASSERT(devProgramData != nullptr);

    // Start with Workgroup size (WGS) from kernel attribute (if available)
    cl::WorkgroupSize workgroupSize = devProgramData->getCompiledWorkgroupSize(getKernelName());

    if (workgroupSize == kEmptyWorkgroupSize)
    {
        if (ndrange.nullLocalWorkSize)
        {
            // NULL value was passed, in which case the OpenCL implementation will determine
            // how to break the global work-items into appropriate work-group instances.
            workgroupSize = device.getImpl<CLDeviceVk>().selectWorkGroupSize(ndrange);
        }
        else
        {
            // Local work size (LWS) was valid, use that as WGS
            workgroupSize = ndrange.localWorkSize;
        }
    }

    // Calculate the workgroup count
    // TODO: Add support for non-uniform WGS
    // http://angleproject:8631
    ASSERT(workgroupSize[0] != 0);
    ASSERT(workgroupSize[1] != 0);
    ASSERT(workgroupSize[2] != 0);
    (*workgroupCountOut)[0] = static_cast<uint32_t>((ndrange.globalWorkSize[0] / workgroupSize[0]));
    (*workgroupCountOut)[1] = static_cast<uint32_t>((ndrange.globalWorkSize[1] / workgroupSize[1]));
    (*workgroupCountOut)[2] = static_cast<uint32_t>((ndrange.globalWorkSize[2] / workgroupSize[2]));

    // Populate program specialization constants (if any)
    // Entries are packed back-to-back as uint32_t; constantDataOffset tracks
    // each entry's byte offset into specConstantData.
    uint32_t constantDataOffset = 0;
    std::vector<uint32_t> specConstantData;
    std::vector<VkSpecializationMapEntry> mapEntries;
    for (const auto specConstantUsed : devProgramData->reflectionData.specConstantsUsed)
    {
        switch (specConstantUsed)
        {
            case SpecConstantType::WorkDimension:
                specConstantData.push_back(ndrange.workDimensions);
                break;
            case SpecConstantType::WorkgroupSizeX:
                specConstantData.push_back(static_cast<uint32_t>(workgroupSize[0]));
                break;
            case SpecConstantType::WorkgroupSizeY:
                specConstantData.push_back(static_cast<uint32_t>(workgroupSize[1]));
                break;
            case SpecConstantType::WorkgroupSizeZ:
                specConstantData.push_back(static_cast<uint32_t>(workgroupSize[2]));
                break;
            case SpecConstantType::GlobalOffsetX:
                specConstantData.push_back(static_cast<uint32_t>(ndrange.globalWorkOffset[0]));
                break;
            case SpecConstantType::GlobalOffsetY:
                specConstantData.push_back(static_cast<uint32_t>(ndrange.globalWorkOffset[1]));
                break;
            case SpecConstantType::GlobalOffsetZ:
                specConstantData.push_back(static_cast<uint32_t>(ndrange.globalWorkOffset[2]));
                break;
            default:
                // Unhandled spec-constant type: skip it without adding a map
                // entry so offsets stay consistent.
                UNIMPLEMENTED();
                continue;
        }
        mapEntries.push_back(VkSpecializationMapEntry{
            .constantID = devProgramData->reflectionData.specConstantIDs[specConstantUsed],
            .offset     = constantDataOffset,
            .size       = sizeof(uint32_t)});
        constantDataOffset += sizeof(uint32_t);
    }
    // Populate kernel specialization constants (if any)
    // These were recorded by setArg() for __local (Workgroup) arguments.
    for (const auto &specConstant : mSpecConstants)
    {
        specConstantData.push_back(specConstant.data);
        mapEntries.push_back(VkSpecializationMapEntry{
            .constantID = specConstant.ID, .offset = constantDataOffset, .size = sizeof(uint32_t)});
        constantDataOffset += sizeof(uint32_t);
    }
    VkSpecializationInfo computeSpecializationInfo{
        .mapEntryCount = static_cast<uint32_t>(mapEntries.size()),
        .pMapEntries   = mapEntries.data(),
        .dataSize      = specConstantData.size() * sizeof(uint32_t),
        .pData         = specConstantData.data(),
    };

    // Now get or create (on compute pipeline cache miss) compute pipeline and return it
    return mShaderProgramHelper.getOrCreateComputePipeline(
        mContext, &mComputePipelineCache, pipelineCache, getPipelineLayout(),
        vk::ComputePipelineOptions{}, PipelineSource::Draw, pipelineOut, mName.c_str(),
        &computeSpecializationInfo);
}
321 
usesPrintf() const322 bool CLKernelVk::usesPrintf() const
323 {
324     return mProgram->getDeviceProgramData(mName.c_str())->getKernelFlags(mName) &
325            NonSemanticClspvReflectionMayUsePrintf;
326 }
327 
// Allocates the descriptor set for `index` out of the program's pool, using
// the layout at `layoutIndex`, and stores the handle in mDescriptorSets.
// Delegates to CLProgramVk so allocation is tied to the command buffer's
// lifetime tracking.
angle::Result CLKernelVk::allocateDescriptorSet(
    DescriptorSetIndex index,
    angle::EnumIterator<DescriptorSetIndex> layoutIndex,
    vk::OutsideRenderPassCommandBufferHelper *computePassCommands)
{
    return mProgram->allocateDescriptorSet(index, *mDescriptorSetLayouts[*layoutIndex],
                                           computePassCommands, &mDescriptorSets[index]);
}
336 }  // namespace rx
337