xref: /aosp_15_r20/external/angle/src/libANGLE/renderer/vulkan/CLProgramVk.h (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2021 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // CLProgramVk.h: Defines the class interface for CLProgramVk, implementing CLProgramImpl.
7 
8 #ifndef LIBANGLE_RENDERER_VULKAN_CLPROGRAMVK_H_
9 #define LIBANGLE_RENDERER_VULKAN_CLPROGRAMVK_H_
10 
#include <cstdint>
#include <utility>
12 
13 #include "common/SimpleMutex.h"
14 #include "common/hash_containers.h"
15 
16 #include "libANGLE/renderer/vulkan/CLContextVk.h"
17 #include "libANGLE/renderer/vulkan/CLKernelVk.h"
18 #include "libANGLE/renderer/vulkan/cl_types.h"
19 #include "libANGLE/renderer/vulkan/clspv_utils.h"
20 #include "libANGLE/renderer/vulkan/vk_cache_utils.h"
21 #include "libANGLE/renderer/vulkan/vk_helpers.h"
22 
23 #include "libANGLE/renderer/CLProgramImpl.h"
24 
25 #include "libANGLE/CLProgram.h"
26 
27 #include "clspv/Compiler.h"
28 
29 #include "vulkan/vulkan_core.h"
30 
31 #include "spirv-tools/libspirv.h"
32 
33 #include "spirv/unified1/NonSemanticClspvReflection.h"
34 
35 namespace rx
36 {
37 
38 class CLProgramVk : public CLProgramImpl
39 {
40   public:
41     using Ptr = std::unique_ptr<CLProgramVk>;
42     // TODO: Look into moving this information in CLKernelArgument
43     // https://anglebug.com/378514267
44     struct ImagePushConstant
45     {
46         VkPushConstantRange pcRange;
47         uint32_t ordinal;
48     };
49     struct SpvReflectionData
50     {
51         angle::HashMap<uint32_t, uint32_t> spvIntLookup;
52         angle::HashMap<uint32_t, std::string> spvStrLookup;
53         angle::HashMap<uint32_t, CLKernelVk::ArgInfo> kernelArgInfos;
54         angle::HashMap<std::string, uint32_t> kernelFlags;
55         angle::HashMap<std::string, std::string> kernelAttributes;
56         angle::HashMap<std::string, std::array<uint32_t, 3>> kernelCompileWorkgroupSize;
57         angle::HashMap<uint32_t, VkPushConstantRange> pushConstants;
58         angle::PackedEnumMap<SpecConstantType, uint32_t> specConstantIDs;
59         angle::PackedEnumBitSet<SpecConstantType, uint32_t> specConstantsUsed;
60         angle::HashMap<uint32_t, std::vector<ImagePushConstant>> imagePushConstants;
61         CLKernelArgsMap kernelArgsMap;
62         angle::HashMap<std::string, CLKernelArgument> kernelArgMap;
63         angle::HashSet<uint32_t> kernelIDs;
64         ClspvPrintfBufferStorage printfBufferStorage;
65         angle::HashMap<uint32_t, ClspvPrintfInfo> printfInfoMap;
66     };
67 
68     // Output binary structure (for CL_PROGRAM_BINARIES query)
69     static constexpr uint32_t kBinaryVersion = 2;
70     struct ProgramBinaryOutputHeader
71     {
72         uint32_t headerVersion{kBinaryVersion};
73         cl_program_binary_type binaryType{CL_PROGRAM_BINARY_TYPE_NONE};
74         cl_build_status buildStatus{CL_BUILD_NONE};
75     };
76 
77     struct ScopedClspvContext : angle::NonCopyable
78     {
79         ScopedClspvContext() = default;
~ScopedClspvContextScopedClspvContext80         ~ScopedClspvContext() { clspvFreeOutputBuildObjs(mOutputBin, mOutputBuildLog); }
81 
82         size_t mOutputBinSize{0};
83         char *mOutputBin{nullptr};
84         char *mOutputBuildLog{nullptr};
85     };
86 
87     struct ScopedProgramCallback : angle::NonCopyable
88     {
89         ScopedProgramCallback() = delete;
ScopedProgramCallbackScopedProgramCallback90         ScopedProgramCallback(cl::Program *notify) : mNotify(notify) {}
~ScopedProgramCallbackScopedProgramCallback91         ~ScopedProgramCallback()
92         {
93             if (mNotify)
94             {
95                 mNotify->callback();
96             }
97         }
98 
99         cl::Program *mNotify{nullptr};
100     };
101 
102     enum class BuildType
103     {
104         BUILD = 0,
105         COMPILE,
106         LINK,
107         BINARY
108     };
109 
110     struct DeviceProgramData
111     {
112         std::vector<char> IR;
113         std::string buildLog;
114         angle::spirv::Blob binary;
115         SpvReflectionData reflectionData;
116         VkPushConstantRange pushConstRange{};
117         cl_build_status buildStatus{CL_BUILD_NONE};
118         cl_program_binary_type binaryType{CL_PROGRAM_BINARY_TYPE_NONE};
119 
numKernelsDeviceProgramData120         size_t numKernels() const { return reflectionData.kernelArgsMap.size(); }
121 
numKernelArgsDeviceProgramData122         size_t numKernelArgs(const std::string &kernelName) const
123         {
124             return containsKernel(kernelName) ? getKernelArgsMap().at(kernelName).size() : 0;
125         }
126 
getKernelArgsMapDeviceProgramData127         const CLKernelArgsMap &getKernelArgsMap() const { return reflectionData.kernelArgsMap; }
128 
containsKernelDeviceProgramData129         bool containsKernel(const std::string &name) const
130         {
131             return reflectionData.kernelArgsMap.contains(name);
132         }
133 
getKernelNamesDeviceProgramData134         std::string getKernelNames() const
135         {
136             std::string names;
137             for (auto name = getKernelArgsMap().begin(); name != getKernelArgsMap().end(); ++name)
138             {
139                 names += name->first + (std::next(name) != getKernelArgsMap().end() ? ";" : "\0");
140             }
141             return names;
142         }
143 
getKernelFlagsDeviceProgramData144         uint32_t getKernelFlags(const std::string &kernelName) const
145         {
146             if (containsKernel(kernelName))
147             {
148                 return reflectionData.kernelFlags.at(kernelName);
149             }
150             return 0;
151         }
152 
getKernelArgumentsDeviceProgramData153         CLKernelArguments getKernelArguments(const std::string &kernelName) const
154         {
155             CLKernelArguments kargsCopy;
156             if (containsKernel(kernelName))
157             {
158                 const CLKernelArguments &kargs = getKernelArgsMap().at(kernelName);
159                 for (const CLKernelArgument &karg : kargs)
160                 {
161                     kargsCopy.push_back(karg);
162                 }
163             }
164             return kargsCopy;
165         }
166 
getCompiledWorkgroupSizeDeviceProgramData167         cl::WorkgroupSize getCompiledWorkgroupSize(const std::string &kernelName) const
168         {
169             cl::WorkgroupSize compiledWorkgroupSize{0, 0, 0};
170             if (reflectionData.kernelCompileWorkgroupSize.contains(kernelName))
171             {
172                 for (size_t i = 0; i < compiledWorkgroupSize.size(); ++i)
173                 {
174                     compiledWorkgroupSize[i] =
175                         reflectionData.kernelCompileWorkgroupSize.at(kernelName)[i];
176                 }
177             }
178             return compiledWorkgroupSize;
179         }
180 
getKernelAttributesDeviceProgramData181         std::string getKernelAttributes(const std::string &kernelName) const
182         {
183             if (containsKernel(kernelName))
184             {
185                 return reflectionData.kernelAttributes.at(kernelName.c_str());
186             }
187             return std::string{};
188         }
189 
getPushConstantRangeFromClspvReflectionTypeDeviceProgramData190         const VkPushConstantRange *getPushConstantRangeFromClspvReflectionType(
191             NonSemanticClspvReflectionInstructions type) const
192         {
193             const VkPushConstantRange *pushConstantRangePtr = nullptr;
194             if (reflectionData.pushConstants.contains(type))
195             {
196                 pushConstantRangePtr = &reflectionData.pushConstants.at(type);
197             }
198             return pushConstantRangePtr;
199         }
200 
getGlobalOffsetRangeDeviceProgramData201         inline const VkPushConstantRange *getGlobalOffsetRange() const
202         {
203             return getPushConstantRangeFromClspvReflectionType(
204                 NonSemanticClspvReflectionPushConstantGlobalOffset);
205         }
206 
getGlobalSizeRangeDeviceProgramData207         inline const VkPushConstantRange *getGlobalSizeRange() const
208         {
209             return getPushConstantRangeFromClspvReflectionType(
210                 NonSemanticClspvReflectionPushConstantGlobalSize);
211         }
212 
getEnqueuedLocalSizeRangeDeviceProgramData213         inline const VkPushConstantRange *getEnqueuedLocalSizeRange() const
214         {
215             return getPushConstantRangeFromClspvReflectionType(
216                 NonSemanticClspvReflectionPushConstantEnqueuedLocalSize);
217         }
218 
getNumWorkgroupsRangeDeviceProgramData219         inline const VkPushConstantRange *getNumWorkgroupsRange() const
220         {
221             return getPushConstantRangeFromClspvReflectionType(
222                 NonSemanticClspvReflectionPushConstantNumWorkgroups);
223         }
224 
getRegionOffsetRangeDeviceProgramData225         inline const VkPushConstantRange *getRegionOffsetRange() const
226         {
227             return getPushConstantRangeFromClspvReflectionType(
228                 NonSemanticClspvReflectionPushConstantRegionOffset);
229         }
230 
getRegionGroupOffsetRangeDeviceProgramData231         inline const VkPushConstantRange *getRegionGroupOffsetRange() const
232         {
233             return getPushConstantRangeFromClspvReflectionType(
234                 NonSemanticClspvReflectionPushConstantRegionGroupOffset);
235         }
236 
getImageDataChannelOrderRangeDeviceProgramData237         const VkPushConstantRange *getImageDataChannelOrderRange(size_t ordinal) const
238         {
239             const VkPushConstantRange *pushConstantRangePtr = nullptr;
240             if (reflectionData.imagePushConstants.contains(
241                     NonSemanticClspvReflectionImageArgumentInfoChannelOrderPushConstant))
242             {
243                 for (const auto &imageConstant : reflectionData.imagePushConstants.at(
244                          NonSemanticClspvReflectionImageArgumentInfoChannelOrderPushConstant))
245                 {
246                     if (static_cast<size_t>(imageConstant.ordinal) == ordinal)
247                     {
248                         pushConstantRangePtr = &imageConstant.pcRange;
249                     }
250                 }
251             }
252             return pushConstantRangePtr;
253         }
254 
getImageDataChannelDataTypeRangeDeviceProgramData255         const VkPushConstantRange *getImageDataChannelDataTypeRange(size_t ordinal) const
256         {
257             const VkPushConstantRange *pushConstantRangePtr = nullptr;
258             if (reflectionData.imagePushConstants.contains(
259                     NonSemanticClspvReflectionImageArgumentInfoChannelDataTypePushConstant))
260             {
261                 for (const auto &imageConstant : reflectionData.imagePushConstants.at(
262                          NonSemanticClspvReflectionImageArgumentInfoChannelDataTypePushConstant))
263                 {
264                     if (static_cast<size_t>(imageConstant.ordinal) == ordinal)
265                     {
266                         pushConstantRangePtr = &imageConstant.pcRange;
267                     }
268                 }
269             }
270             return pushConstantRangePtr;
271         }
272 
getNormalizedSamplerMaskRangeDeviceProgramData273         const VkPushConstantRange *getNormalizedSamplerMaskRange(size_t ordinal) const
274         {
275             const VkPushConstantRange *pushConstantRangePtr = nullptr;
276             if (reflectionData.imagePushConstants.contains(
277                     NonSemanticClspvReflectionNormalizedSamplerMaskPushConstant))
278             {
279                 for (const auto &imageConstant : reflectionData.imagePushConstants.at(
280                          NonSemanticClspvReflectionNormalizedSamplerMaskPushConstant))
281                 {
282                     if (static_cast<size_t>(imageConstant.ordinal) == ordinal)
283                     {
284                         pushConstantRangePtr = &imageConstant.pcRange;
285                     }
286                 }
287             }
288             return pushConstantRangePtr;
289         }
290     };
291     using DevicePrograms   = angle::HashMap<const _cl_device_id *, DeviceProgramData>;
292     using LinkPrograms     = std::vector<const DeviceProgramData *>;
293     using LinkProgramsList = std::vector<LinkPrograms>;
294 
295     CLProgramVk(const cl::Program &program);
296 
297     ~CLProgramVk() override;
298 
299     angle::Result init();
300     angle::Result init(const size_t *lengths, const unsigned char **binaries, cl_int *binaryStatus);
301 
302     angle::Result build(const cl::DevicePtrs &devices,
303                         const char *options,
304                         cl::Program *notify) override;
305 
306     angle::Result compile(const cl::DevicePtrs &devices,
307                           const char *options,
308                           const cl::ProgramPtrs &inputHeaders,
309                           const char **headerIncludeNames,
310                           cl::Program *notify) override;
311 
312     angle::Result getInfo(cl::ProgramInfo name,
313                           size_t valueSize,
314                           void *value,
315                           size_t *valueSizeRet) const override;
316 
317     angle::Result getBuildInfo(const cl::Device &device,
318                                cl::ProgramBuildInfo name,
319                                size_t valueSize,
320                                void *value,
321                                size_t *valueSizeRet) const override;
322 
323     angle::Result createKernel(const cl::Kernel &kernel,
324                                const char *name,
325                                CLKernelImpl::Ptr *kernelOut) override;
326 
327     angle::Result createKernels(cl_uint numKernels,
328                                 CLKernelImpl::CreateFuncs &createFuncs,
329                                 cl_uint *numKernelsRet) override;
330 
331     const DeviceProgramData *getDeviceProgramData(const char *kernelName) const;
332     const DeviceProgramData *getDeviceProgramData(const _cl_device_id *device) const;
getPlatform()333     CLPlatformVk *getPlatform() { return mContext->getPlatform(); }
getShaderModule()334     const vk::ShaderModulePtr &getShaderModule() const { return mShader; }
335 
336     bool buildInternal(const cl::DevicePtrs &devices,
337                        std::string options,
338                        std::string internalOptions,
339                        BuildType buildType,
340                        const LinkProgramsList &LinkProgramsList);
341     angle::spirv::Blob stripReflection(const DeviceProgramData *deviceProgramData);
342 
343     angle::Result allocateDescriptorSet(const DescriptorSetIndex setIndex,
344                                         const vk::DescriptorSetLayout &descriptorSetLayout,
345                                         vk::CommandBufferHelperCommon *commandBuffer,
346                                         vk::DescriptorSetPointer *descriptorSetOut);
347 
348     // Sets the status for given associated device programs
349     void setBuildStatus(const cl::DevicePtrs &devices, cl_build_status status);
350 
getMetaDescriptorPool(DescriptorSetIndex index)351     vk::MetaDescriptorPool &getMetaDescriptorPool(DescriptorSetIndex index)
352     {
353         return mMetaDescriptorPools[index];
354     }
355 
getDynamicDescriptorPoolPointer(DescriptorSetIndex index)356     vk::DynamicDescriptorPoolPointer &getDynamicDescriptorPoolPointer(DescriptorSetIndex index)
357     {
358         return mDynamicDescriptorPools[index];
359     }
360 
361     const angle::HashMap<uint32_t, ClspvPrintfInfo> *getPrintfDescriptors(
362         const std::string &kernelName) const;
363 
364   private:
365     CLContextVk *mContext;
366     std::string mProgramOpts;
367     vk::ShaderModulePtr mShader;
368     DevicePrograms mAssociatedDevicePrograms;
369     vk::DescriptorSetArray<vk::MetaDescriptorPool> mMetaDescriptorPools;
370     vk::DescriptorSetArray<vk::DynamicDescriptorPoolPointer> mDynamicDescriptorPools;
371     angle::SimpleMutex mProgramMutex;
372 
373     std::shared_ptr<angle::WaitableEvent> mAsyncBuildEvent;
374 };
375 
376 class CLAsyncBuildTask : public angle::Closure
377 {
378   public:
CLAsyncBuildTask(CLProgramVk * programVk,const cl::DevicePtrs & devices,std::string options,std::string internalOptions,CLProgramVk::BuildType buildType,const CLProgramVk::LinkProgramsList & LinkProgramsList,cl::Program * notify)379     CLAsyncBuildTask(CLProgramVk *programVk,
380                      const cl::DevicePtrs &devices,
381                      std::string options,
382                      std::string internalOptions,
383                      CLProgramVk::BuildType buildType,
384                      const CLProgramVk::LinkProgramsList &LinkProgramsList,
385                      cl::Program *notify)
386         : mProgramVk(programVk),
387           mDevices(devices),
388           mOptions(options),
389           mInternalOptions(internalOptions),
390           mBuildType(buildType),
391           mLinkProgramsList(LinkProgramsList),
392           mNotify(notify)
393     {}
394 
395     void operator()() override;
396 
397   private:
398     CLProgramVk *mProgramVk;
399     const cl::DevicePtrs mDevices;
400     std::string mOptions;
401     std::string mInternalOptions;
402     CLProgramVk::BuildType mBuildType;
403     const CLProgramVk::LinkProgramsList mLinkProgramsList;
404     cl::Program *mNotify;
405 };
406 
407 }  // namespace rx
408 
409 #endif  // LIBANGLE_RENDERER_VULKAN_CLPROGRAMVK_H_
410