xref: /aosp_15_r20/external/angle/src/libANGLE/renderer/vulkan/clspv_utils.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Utilities to map clspv interface variables to OpenCL and Vulkan mappings.
7 //
8 
#include "libANGLE/renderer/vulkan/clspv_utils.h"
#include "libANGLE/renderer/vulkan/CLDeviceVk.h"

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <string_view>
#include <type_traits>
#include <vector>

#include "common/string_utils.h"

#include "CL/cl_half.h"
17 
18 namespace rx
19 {
// Characters that end a printf conversion specification (the conversion specifier itself).
constexpr std::string_view kPrintfConversionSpecifiers = "diouxXfFeEgGaAcsp";
// Flag characters that may directly follow the '%'.
constexpr std::string_view kPrintfFlagsSpecifiers      = "-+ #0";
// Characters that make up the field width and precision. '0' has to be part of the set, otherwise
// a width or precision such as "10" or ".0" stops the scan before the vector specifier is reached.
constexpr std::string_view kPrintfPrecisionSpecifiers  = "0123456789.";
// Digits that can appear in an OpenCL vector length (2, 3, 4, 8 or 16); '1' is needed for 16.
constexpr std::string_view kPrintfVectorSizeSpecifiers = "123468";
24 
25 namespace
26 {
27 
// Reads a value of type T from the bytes starting at data.
// The bytes are copied out with memcpy rather than read through a casted pointer, so data does
// not have to be aligned for T.
template <typename T>
T ReadPtrAs(const unsigned char *data)
{
    static_assert(std::is_trivially_copyable<T>::value,
                  "ReadPtrAs can only be used with trivially copyable types");
    T out;
    std::memcpy(&out, data, sizeof(T));
    return out;
}
33 
// Reads a value of type T from the bytes starting at data and advances data by sizeof(T).
// The bytes are copied out with memcpy rather than read through a casted pointer, so data does
// not have to be aligned for T.
template <typename T>
T ReadPtrAsAndIncrement(unsigned char *&data)
{
    static_assert(std::is_trivially_copyable<T>::value,
                  "ReadPtrAsAndIncrement can only be used with trivially copyable types");
    T out;
    std::memcpy(&out, data, sizeof(T));
    data += sizeof(T);
    return out;
}
41 
getPrintfConversionSpecifier(std::string_view formatString)42 char getPrintfConversionSpecifier(std::string_view formatString)
43 {
44     return formatString.at(formatString.find_first_of(kPrintfConversionSpecifiers));
45 }
46 
IsVectorFormat(std::string_view formatString)47 bool IsVectorFormat(std::string_view formatString)
48 {
49     ASSERT(formatString.at(0) == '%');
50 
51     // go past the flags, field width and precision
52     size_t pos = formatString.find_first_not_of(kPrintfFlagsSpecifiers, 1ul);
53     pos        = formatString.find_first_not_of(kPrintfPrecisionSpecifiers, pos);
54 
55     return (formatString.at(pos) == 'v');
56 }
57 
58 // Printing an individual formatted string into a std::string
59 // snprintf is used for parsing as OpenCL C printf is similar to printf
PrintFormattedString(const std::string & formatString,const unsigned char * data,size_t size)60 std::string PrintFormattedString(const std::string &formatString,
61                                  const unsigned char *data,
62                                  size_t size)
63 {
64     ASSERT(std::count(formatString.begin(), formatString.end(), '%') == 1);
65 
66     size_t outSize = 1024;
67     std::vector<char> out(outSize);
68     out[0] = '\0';
69 
70     char conversion = std::tolower(getPrintfConversionSpecifier(formatString));
71     bool finished   = false;
72     while (!finished)
73     {
74         int bytesWritten = 0;
75         switch (conversion)
76         {
77             case 's':
78             {
79                 bytesWritten = snprintf(out.data(), outSize, formatString.c_str(), data);
80                 break;
81             }
82             case 'f':
83             case 'e':
84             case 'g':
85             case 'a':
86             {
87                 // all floats with same convention as snprintf
88                 if (size == 2)
89                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
90                                             cl_half_to_float(ReadPtrAs<cl_half>(data)));
91                 else if (size == 4)
92                     bytesWritten =
93                         snprintf(out.data(), outSize, formatString.c_str(), ReadPtrAs<float>(data));
94                 else
95                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
96                                             ReadPtrAs<double>(data));
97                 break;
98             }
99             default:
100             {
101                 if (size == 1)
102                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
103                                             ReadPtrAs<uint8_t>(data));
104                 else if (size == 2)
105                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
106                                             ReadPtrAs<uint16_t>(data));
107                 else if (size == 4)
108                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
109                                             ReadPtrAs<uint32_t>(data));
110                 else
111                     bytesWritten = snprintf(out.data(), outSize, formatString.c_str(),
112                                             ReadPtrAs<uint64_t>(data));
113                 break;
114             }
115         }
116         if (bytesWritten < 0)
117         {
118             out[0]   = '\0';
119             finished = true;
120         }
121         else if (bytesWritten < static_cast<long>(outSize))
122         {
123             finished = true;
124         }
125         else
126         {
127             // insufficient size redo above post increment of size
128             outSize *= 2;
129             out.resize(outSize);
130         }
131     }
132 
133     return std::string(out.data());
134 }
135 
136 // Spec mention vn modifier to be printed in the form v1,v2...vn
PrintVectorFormatIntoString(std::string formatString,const unsigned char * data,const uint32_t size)137 std::string PrintVectorFormatIntoString(std::string formatString,
138                                         const unsigned char *data,
139                                         const uint32_t size)
140 {
141     ASSERT(IsVectorFormat(formatString));
142 
143     size_t conversionPos = formatString.find_first_of(kPrintfConversionSpecifiers);
144     // keep everything after conversion specifier in remainingFormat
145     std::string remainingFormat = formatString.substr(conversionPos + 1);
146     formatString                = formatString.substr(0, conversionPos + 1);
147 
148     size_t vectorPos       = formatString.find_first_of('v');
149     size_t vectorLengthPos = ++vectorPos;
150     size_t vectorLengthPosEnd =
151         formatString.find_first_not_of(kPrintfVectorSizeSpecifiers, vectorLengthPos);
152 
153     std::string preVectorString  = formatString.substr(0, vectorPos - 1);
154     std::string postVectorString = formatString.substr(vectorLengthPosEnd, formatString.size());
155     std::string vectorLengthStr  = formatString.substr(vectorLengthPos, vectorLengthPosEnd);
156     int vectorLength             = std::atoi(vectorLengthStr.c_str());
157 
158     // skip the vector specifier
159     formatString = preVectorString + postVectorString;
160 
161     // Get the length modifier
162     int elementSize = 0;
163     if (postVectorString.find("hh") != std::string::npos)
164     {
165         elementSize = 1;
166     }
167     else if (postVectorString.find("hl") != std::string::npos)
168     {
169         elementSize = 4;
170         // snprintf doesn't recognize the hl modifier so strip it
171         size_t hl = formatString.find("hl");
172         formatString.erase(hl, 2);
173     }
174     else if (postVectorString.find("h") != std::string::npos)
175     {
176         elementSize = 2;
177     }
178     else if (postVectorString.find("l") != std::string::npos)
179     {
180         elementSize = 8;
181     }
182     else
183     {
184         WARN() << "Vector specifier is used without a length modifier. Guessing it from "
185                   "vector length and argument sizes in PrintInfo. Kernel modification is "
186                   "recommended.";
187         elementSize = size / vectorLength;
188     }
189 
190     std::string out{""};
191     for (int i = 0; i < vectorLength - 1; i++)
192     {
193         out += PrintFormattedString(formatString, data, size / vectorLength) + ",";
194         data += elementSize;
195     }
196     out += PrintFormattedString(formatString, data, size / vectorLength) + remainingFormat;
197 
198     return out;
199 }
200 
201 // Process the printf stream by breaking them down into individual format specifier and processing
202 // them.
ProcessPrintfStatement(unsigned char * & data,const angle::HashMap<uint32_t,ClspvPrintfInfo> * descs,const unsigned char * dataEnd)203 void ProcessPrintfStatement(unsigned char *&data,
204                             const angle::HashMap<uint32_t, ClspvPrintfInfo> *descs,
205                             const unsigned char *dataEnd)
206 {
207     // printf storage buffer contents - | id | formatString | argSizes... |
208     uint32_t printfID               = ReadPtrAsAndIncrement<uint32_t>(data);
209     const std::string &formatString = descs->at(printfID).formatSpecifier;
210 
211     std::string printfOutput = "";
212 
213     // formatString could be "<string literal> <% format specifiers ...> <string literal>"
214     // print the literal part if any first
215     size_t nextFormatSpecPos = formatString.find_first_of('%');
216     printfOutput += formatString.substr(0, nextFormatSpecPos);
217 
218     // print each <% format specifier> + any string literal separately using snprintf
219     size_t idx = 0;
220     while (nextFormatSpecPos < formatString.size() - 1)
221     {
222         // Get the part of the format string before the next format specifier
223         size_t partStart             = nextFormatSpecPos;
224         size_t partEnd               = formatString.find_first_of('%', partStart + 1);
225         std::string partFormatString = formatString.substr(partStart, partEnd - partStart);
226 
227         // Handle special cases
228         if (partEnd == partStart + 1)
229         {
230             printfOutput += "%";
231             nextFormatSpecPos = partEnd + 1;
232             continue;
233         }
234         else if (partEnd == std::string::npos && idx >= descs->at(printfID).argSizes.size())
235         {
236             // If there are no remaining arguments, the rest of the format
237             // should be printed verbatim
238             printfOutput += partFormatString;
239             break;
240         }
241 
242         // The size of the argument that this format part will consume
243         const uint32_t &size = descs->at(printfID).argSizes[idx];
244 
245         if (data + size > dataEnd)
246         {
247             data += size;
248             return;
249         }
250 
251         // vector format need special care for snprintf
252         if (!IsVectorFormat(partFormatString))
253         {
254             // not a vector format can be printed through snprintf
255             // except for %s
256             if (getPrintfConversionSpecifier(partFormatString) == 's')
257             {
258                 uint32_t stringID = ReadPtrAs<uint32_t>(data);
259                 printfOutput +=
260                     PrintFormattedString(partFormatString,
261                                          reinterpret_cast<const unsigned char *>(
262                                              descs->at(stringID).formatSpecifier.c_str()),
263                                          size);
264             }
265             else
266             {
267                 printfOutput += PrintFormattedString(partFormatString, data, size);
268             }
269             data += size;
270         }
271         else
272         {
273             printfOutput += PrintVectorFormatIntoString(partFormatString, data, size);
274             data += size;
275         }
276 
277         // Move to the next format part and prepare to handle the next arg
278         nextFormatSpecPos = partEnd;
279         idx++;
280     }
281 
282     std::printf("%s", printfOutput.c_str());
283 }
284 
285 }  // namespace
286 
287 // Process the data recorded into printf storage buffer along with the info in printfino descriptor
288 // and write it to stdout.
ClspvProcessPrintfBuffer(unsigned char * buffer,const size_t bufferSize,const angle::HashMap<uint32_t,ClspvPrintfInfo> * infoMap)289 angle::Result ClspvProcessPrintfBuffer(unsigned char *buffer,
290                                        const size_t bufferSize,
291                                        const angle::HashMap<uint32_t, ClspvPrintfInfo> *infoMap)
292 {
293     // printf storage buffer contains a series of uint32_t values
294     // the first integer is offset from second to next available free memory -- this is the amount
295     // of data written by kernel.
296     const size_t bytesWritten = ReadPtrAsAndIncrement<uint32_t>(buffer) * sizeof(uint32_t);
297     const size_t dataSize     = bufferSize - sizeof(uint32_t);
298     const size_t limit        = std::min(bytesWritten, dataSize);
299 
300     const unsigned char *dataEnd = buffer + limit;
301     while (buffer < dataEnd)
302     {
303         ProcessPrintfStatement(buffer, infoMap, dataEnd);
304     }
305 
306     if (bufferSize < bytesWritten)
307     {
308         WARN() << "Printf storage buffer was not sufficient for all printfs. Around "
309                << 100.0 * (float)(bytesWritten - bufferSize) / bytesWritten
310                << "% of them have been skipped.";
311     }
312 
313     return angle::Result::Continue;
314 }
315 
ClspvGetCompilerOptions(const CLDeviceVk * device)316 std::string ClspvGetCompilerOptions(const CLDeviceVk *device)
317 {
318     ASSERT(device && device->getRenderer());
319     const vk::Renderer *rendererVk = device->getRenderer();
320     std::string options{""};
321 
322     cl_uint addressBits;
323     if (IsError(device->getInfoUInt(cl::DeviceInfo::AddressBits, &addressBits)))
324     {
325         // This should'nt fail here
326         ASSERT(false);
327     }
328     options += addressBits == 64 ? " -arch=spir64" : " -arch=spir";
329 
330     // Other internal Clspv compiler flags that are needed/required
331     options += " --long-vector";
332     options += " --global-offset";
333     options += " --enable-printf";
334 
335     // 8 bit storage buffer support
336     if (!rendererVk->getFeatures().supports8BitStorageBuffer.enabled)
337     {
338         options += " --no-8bit-storage=ssbo";
339     }
340     if (!rendererVk->getFeatures().supports8BitUniformAndStorageBuffer.enabled)
341     {
342         options += " --no-8bit-storage=ubo";
343     }
344     if (!rendererVk->getFeatures().supports8BitPushConstant.enabled)
345     {
346         options += " --no-8bit-storage=pushconstant";
347     }
348 
349     // 16 bit storage options
350     if (!rendererVk->getFeatures().supports16BitStorageBuffer.enabled)
351     {
352         options += " --no-16bit-storage=ssbo";
353     }
354     if (!rendererVk->getFeatures().supports16BitUniformAndStorageBuffer.enabled)
355     {
356         options += " --no-16bit-storage=ubo";
357     }
358     if (!rendererVk->getFeatures().supports16BitPushConstant.enabled)
359     {
360         options += " --no-16bit-storage=pushconstant";
361     }
362 
363     return options;
364 }
365 
366 }  // namespace rx
367