xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/clover/spirv/invocation.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 //
2 // Copyright 2018 Pierre Moreau
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a
5 // copy of this software and associated documentation files (the "Software"),
6 // to deal in the Software without restriction, including without limitation
7 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 // and/or sell copies of the Software, and to permit persons to whom the
9 // Software is furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
17 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
18 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
19 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20 // OTHER DEALINGS IN THE SOFTWARE.
21 //
22 
23 #include "invocation.hpp"
24 
#include <algorithm>
#include <array>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
29 
30 #ifdef HAVE_CLOVER_SPIRV
31 #include <spirv-tools/libspirv.hpp>
32 #include <spirv-tools/linker.hpp>
33 #endif
34 
35 #include "core/error.hpp"
36 #include "core/platform.hpp"
37 #include "invocation.hpp"
38 #include "llvm/util.hpp"
39 #include "pipe/p_state.h"
40 #include "util/algorithm.hpp"
41 #include "util/functional.hpp"
42 #include "util/u_math.h"
43 
44 #include "compiler/spirv/spirv.h"
45 
46 #define SPIRV_HEADER_WORD_SIZE 5
47 
48 using namespace clover;
49 
50 using clover::detokenize;
51 
52 #ifdef HAVE_CLOVER_SPIRV
53 namespace {
54 
55    static const std::array<std::string,7> type_strs = {
56       "uchar", "ushort", "uint", "ulong", "half", "float", "double"
57    };
58 
59    template<typename T>
get(const char * source,size_t index)60    T get(const char *source, size_t index) {
61       const uint32_t *word_ptr = reinterpret_cast<const uint32_t *>(source);
62       return static_cast<T>(word_ptr[index]);
63    }
64 
65    enum binary::argument::type
convert_storage_class(SpvStorageClass storage_class,std::string & err)66    convert_storage_class(SpvStorageClass storage_class, std::string &err) {
67       switch (storage_class) {
68       case SpvStorageClassFunction:
69          return binary::argument::scalar;
70       case SpvStorageClassUniformConstant:
71          return binary::argument::global;
72       case SpvStorageClassWorkgroup:
73          return binary::argument::local;
74       case SpvStorageClassCrossWorkgroup:
75          return binary::argument::global;
76       default:
77          err += "Invalid storage type " + std::to_string(storage_class) + "\n";
78          throw build_error();
79       }
80    }
81 
82    cl_kernel_arg_address_qualifier
convert_storage_class_to_cl(SpvStorageClass storage_class)83    convert_storage_class_to_cl(SpvStorageClass storage_class) {
84       switch (storage_class) {
85       case SpvStorageClassUniformConstant:
86          return CL_KERNEL_ARG_ADDRESS_CONSTANT;
87       case SpvStorageClassWorkgroup:
88          return CL_KERNEL_ARG_ADDRESS_LOCAL;
89       case SpvStorageClassCrossWorkgroup:
90          return CL_KERNEL_ARG_ADDRESS_GLOBAL;
91       case SpvStorageClassFunction:
92       default:
93          return CL_KERNEL_ARG_ADDRESS_PRIVATE;
94       }
95    }
96 
97    enum binary::argument::type
convert_image_type(SpvId id,SpvDim dim,SpvAccessQualifier access,std::string & err)98    convert_image_type(SpvId id, SpvDim dim, SpvAccessQualifier access,
99                       std::string &err) {
100       switch (dim) {
101       case SpvDim1D:
102       case SpvDim2D:
103       case SpvDim3D:
104       case SpvDimBuffer:
105          switch (access) {
106          case SpvAccessQualifierReadOnly:
107             return binary::argument::image_rd;
108          case SpvAccessQualifierWriteOnly:
109             return binary::argument::image_wr;
110          default:
111             err += "Unknown access qualifier " + std::to_string(access) + " for image "
112                 +  std::to_string(id) + ".\n";
113             throw build_error();
114          }
115       default:
116          err += "Unknown dimension " + std::to_string(dim) + " for image "
117              +  std::to_string(id) + ".\n";
118          throw build_error();
119       }
120    }
121 
122    binary::section
make_text_section(const std::string & code,enum binary::section::type section_type)123    make_text_section(const std::string &code,
124                      enum binary::section::type section_type) {
125       const pipe_binary_program_header header { uint32_t(code.size()) };
126       binary::section text { 0, section_type, header.num_bytes, {} };
127 
128       text.data.reserve(sizeof(header) + header.num_bytes);
129       text.data.insert(text.data.end(), reinterpret_cast<const char *>(&header),
130                        reinterpret_cast<const char *>(&header) + sizeof(header));
131       text.data.insert(text.data.end(), code.begin(), code.end());
132 
133       return text;
134    }
135 
136    binary
create_binary_from_spirv(const std::string & source,size_t pointer_byte_size,std::string & err)137    create_binary_from_spirv(const std::string &source,
138                             size_t pointer_byte_size,
139                             std::string &err) {
140       const size_t length = source.size() / sizeof(uint32_t);
141       size_t i = SPIRV_HEADER_WORD_SIZE; // Skip header
142 
143       std::string kernel_name;
144       size_t kernel_nb = 0u;
145       std::vector<binary::argument> args;
146       std::vector<size_t> req_local_size;
147 
148       binary b;
149 
150       std::vector<std::string> attributes;
151       std::unordered_map<SpvId, std::vector<size_t> > req_local_sizes;
152       std::unordered_map<SpvId, std::string> kernels;
153       std::unordered_map<SpvId, binary::argument> types;
154       std::unordered_map<SpvId, SpvId> pointer_types;
155       std::unordered_map<SpvId, unsigned int> constants;
156       std::unordered_set<SpvId> packed_structures;
157       std::unordered_map<SpvId, std::vector<SpvFunctionParameterAttribute>>
158          func_param_attr_map;
159       std::unordered_map<SpvId, std::string> names;
160       std::unordered_map<SpvId, cl_kernel_arg_type_qualifier> qualifiers;
161       std::unordered_map<std::string, std::vector<std::string> > param_type_names;
162       std::unordered_map<std::string, std::vector<std::string> > param_qual_names;
163 
164       while (i < length) {
165          const auto inst = &source[i * sizeof(uint32_t)];
166          const auto desc_word = get<uint32_t>(inst, 0);
167          const auto opcode = static_cast<SpvOp>(desc_word & SpvOpCodeMask);
168          const unsigned int num_operands = desc_word >> SpvWordCountShift;
169 
170          switch (opcode) {
171          case SpvOpName: {
172             names.emplace(get<SpvId>(inst, 1),
173                           source.data() + (i + 2u) * sizeof(uint32_t));
174             break;
175          }
176 
177          case SpvOpString: {
178             // SPIRV-LLVM-Translator stores param type names as OpStrings
179             std::string str(source.data() + (i + 2u) * sizeof(uint32_t));
180             if (str.find("kernel_arg_type.") == 0) {
181                std::string line;
182                std::istringstream istream(str.substr(16));
183 
184                std::getline(istream, line, '.');
185 
186                std::string k = line;
187                while (std::getline(istream, line, ','))
188                   param_type_names[k].push_back(line);
189             } else if (str.find("kernel_arg_type_qual.") == 0) {
190                std::string line;
191                std::istringstream istream(str.substr(21));
192 
193                std::getline(istream, line, '.');
194                std::string k = line;
195                while (std::getline(istream, line, ','))
196                   param_qual_names[k].push_back(line);
197             } else
198                continue;
199             break;
200          }
201 
202          case SpvOpEntryPoint:
203             if (get<SpvExecutionModel>(inst, 1) == SpvExecutionModelKernel)
204                kernels.emplace(get<SpvId>(inst, 2),
205                                source.data() + (i + 3u) * sizeof(uint32_t));
206             break;
207 
208          case SpvOpExecutionMode:
209             switch (get<SpvExecutionMode>(inst, 2)) {
210             case SpvExecutionModeLocalSize: {
211                req_local_sizes[get<SpvId>(inst, 1)] = {
212                   get<uint32_t>(inst, 3),
213                   get<uint32_t>(inst, 4),
214                   get<uint32_t>(inst, 5)
215                };
216                std::string s = "reqd_work_group_size(";
217                s += std::to_string(get<uint32_t>(inst, 3));
218                s += ",";
219                s += std::to_string(get<uint32_t>(inst, 4));
220                s += ",";
221                s += std::to_string(get<uint32_t>(inst, 5));
222                s += ")";
223                attributes.emplace_back(s);
224                break;
225             }
226             case SpvExecutionModeLocalSizeHint: {
227                std::string s = "work_group_size_hint(";
228                s += std::to_string(get<uint32_t>(inst, 3));
229                s += ",";
230                s += std::to_string(get<uint32_t>(inst, 4));
231                s += ",";
232                s += std::to_string(get<uint32_t>(inst, 5));
233                s += ")";
234                attributes.emplace_back(s);
235                break;
236             }
237 	    case SpvExecutionModeVecTypeHint: {
238                uint32_t val = get<uint32_t>(inst, 3);
239                uint32_t size = val >> 16;
240 
241                val &= 0xf;
242                if (val > 6)
243                   val = 0;
244                std::string s = "vec_type_hint(";
245                s += type_strs[val];
246                s += std::to_string(size);
247                s += ")";
248                attributes.emplace_back(s);
249 	       break;
250             }
251             default:
252                break;
253             }
254             break;
255 
256          case SpvOpDecorate: {
257             const auto id = get<SpvId>(inst, 1);
258             const auto decoration = get<SpvDecoration>(inst, 2);
259             switch (decoration) {
260             case SpvDecorationCPacked:
261                packed_structures.emplace(id);
262                break;
263             case SpvDecorationFuncParamAttr: {
264                const auto attribute =
265                   get<SpvFunctionParameterAttribute>(inst, 3u);
266                func_param_attr_map[id].push_back(attribute);
267                break;
268             }
269             case SpvDecorationVolatile:
270                qualifiers[id] |= CL_KERNEL_ARG_TYPE_VOLATILE;
271                break;
272             default:
273                break;
274             }
275             break;
276          }
277 
278          case SpvOpGroupDecorate: {
279             const auto group_id = get<SpvId>(inst, 1);
280             if (packed_structures.count(group_id)) {
281                for (unsigned int i = 2u; i < num_operands; ++i)
282                   packed_structures.emplace(get<SpvId>(inst, i));
283             }
284             const auto func_param_attr_iter =
285                func_param_attr_map.find(group_id);
286             if (func_param_attr_iter != func_param_attr_map.end()) {
287                for (unsigned int i = 2u; i < num_operands; ++i) {
288                   auto &attrs = func_param_attr_map[get<SpvId>(inst, i)];
289                   attrs.insert(attrs.begin(),
290                                func_param_attr_iter->second.begin(),
291                                func_param_attr_iter->second.end());
292                }
293             }
294             if (qualifiers.count(group_id)) {
295                for (unsigned int i = 2u; i < num_operands; ++i)
296                   qualifiers[get<SpvId>(inst, i)] |= qualifiers[group_id];
297             }
298             break;
299          }
300 
301          case SpvOpConstant:
302             // We only care about constants that represent the size of arrays.
303             // If they are passed as argument, they will never be more than
304             // 4GB-wide, and even if they did, a clover::binary::argument size
305             // is represented by an int.
306             constants[get<SpvId>(inst, 2)] = get<unsigned int>(inst, 3u);
307             break;
308 
309          case SpvOpTypeInt:
310          case SpvOpTypeFloat: {
311             const auto size = get<uint32_t>(inst, 2) / 8u;
312             const auto id = get<SpvId>(inst, 1);
313             types[id] = { binary::argument::scalar, size, size, size,
314                           binary::argument::zero_ext };
315             types[id].info.address_qualifier = CL_KERNEL_ARG_ADDRESS_PRIVATE;
316             break;
317          }
318 
319          case SpvOpTypeArray: {
320             const auto id = get<SpvId>(inst, 1);
321             const auto type_id = get<SpvId>(inst, 2);
322             const auto types_iter = types.find(type_id);
323             if (types_iter == types.end())
324                break;
325 
326             const auto constant_id = get<SpvId>(inst, 3);
327             const auto constants_iter = constants.find(constant_id);
328             if (constants_iter == constants.end()) {
329                err += "Constant " + std::to_string(constant_id) +
330                   " is missing\n";
331                throw build_error();
332             }
333             const auto elem_size = types_iter->second.size;
334             const auto elem_nbs = constants_iter->second;
335             const auto size = elem_size * elem_nbs;
336             types[id] = { binary::argument::scalar, size, size,
337                           types_iter->second.target_align,
338                           binary::argument::zero_ext };
339             break;
340          }
341 
342          case SpvOpTypeStruct: {
343             const auto id = get<SpvId>(inst, 1);
344             const bool is_packed = packed_structures.count(id);
345 
346             unsigned struct_size = 0u;
347             unsigned struct_align = 1u;
348             for (unsigned j = 2u; j < num_operands; ++j) {
349                const auto type_id = get<SpvId>(inst, j);
350                const auto types_iter = types.find(type_id);
351 
352                // If a type was not found, that means it is not one of the
353                // types allowed as kernel arguments. And since the binary has
354                // been validated, this means this type is not used for kernel
355                // arguments, and therefore can be ignored.
356                if (types_iter == types.end())
357                   break;
358 
359                const auto alignment = is_packed ? 1u
360                                                 : types_iter->second.target_align;
361                const auto padding = (-struct_size) & (alignment - 1u);
362                struct_size += padding + types_iter->second.target_size;
363                struct_align = std::max(struct_align, alignment);
364             }
365             struct_size += (-struct_size) & (struct_align - 1u);
366             types[id] = { binary::argument::scalar, struct_size, struct_size,
367                           struct_align, binary::argument::zero_ext };
368             break;
369          }
370 
371          case SpvOpTypeVector: {
372             const auto id = get<SpvId>(inst, 1);
373             const auto type_id = get<SpvId>(inst, 2);
374             const auto types_iter = types.find(type_id);
375 
376             // If a type was not found, that means it is not one of the
377             // types allowed as kernel arguments. And since the binary has
378             // been validated, this means this type is not used for kernel
379             // arguments, and therefore can be ignored.
380             if (types_iter == types.end())
381                break;
382 
383             const auto elem_size = types_iter->second.size;
384             const auto elem_nbs = get<uint32_t>(inst, 3);
385             const auto size = elem_size * (elem_nbs != 3 ? elem_nbs : 4);
386             types[id] = { binary::argument::scalar, size, size, size,
387                           binary::argument::zero_ext };
388             types[id].info.address_qualifier = CL_KERNEL_ARG_ADDRESS_PRIVATE;
389             break;
390          }
391 
392          case SpvOpTypeForwardPointer: // FALLTHROUGH
393          case SpvOpTypePointer: {
394             const auto id = get<SpvId>(inst, 1);
395             const auto storage_class = get<SpvStorageClass>(inst, 2);
396             // Input means this is for a builtin variable, which can not be
397             // passed as an argument to a kernel.
398             if (storage_class == SpvStorageClassInput)
399                break;
400 
401             if (opcode == SpvOpTypePointer)
402                pointer_types[id] = get<SpvId>(inst, 3);
403 
404             binary::size_t alignment;
405             if (storage_class == SpvStorageClassWorkgroup)
406                alignment = opcode == SpvOpTypePointer ? types[pointer_types[id]].target_align : 0;
407             else
408                alignment = pointer_byte_size;
409 
410             types[id] = { convert_storage_class(storage_class, err),
411                           sizeof(cl_mem),
412                           static_cast<binary::size_t>(pointer_byte_size),
413                           alignment,
414                           binary::argument::zero_ext };
415             types[id].info.address_qualifier = convert_storage_class_to_cl(storage_class);
416             break;
417          }
418 
419          case SpvOpTypeSampler:
420             types[get<SpvId>(inst, 1)] = { binary::argument::sampler,
421                                              sizeof(cl_sampler) };
422             break;
423 
424          case SpvOpTypeImage: {
425             const auto id = get<SpvId>(inst, 1);
426             const auto dim = get<SpvDim>(inst, 3);
427             const auto access = get<SpvAccessQualifier>(inst, 9);
428             types[id] = { convert_image_type(id, dim, access, err),
429                           sizeof(cl_mem), sizeof(cl_mem), sizeof(cl_mem),
430                           binary::argument::zero_ext };
431             break;
432          }
433 
434          case SpvOpTypePipe: // FALLTHROUGH
435          case SpvOpTypeQueue: {
436             err += "TypePipe and TypeQueue are valid SPIR-V 1.0 types, but are "
437                    "not available in the currently supported OpenCL C version."
438                    "\n";
439             throw build_error();
440          }
441 
442          case SpvOpFunction: {
443             auto id = get<SpvId>(inst, 2);
444             const auto kernels_iter = kernels.find(id);
445             if (kernels_iter != kernels.end())
446                kernel_name = kernels_iter->second;
447 
448             const auto req_local_size_iter = req_local_sizes.find(id);
449             if (req_local_size_iter != req_local_sizes.end())
450                req_local_size =  (*req_local_size_iter).second;
451             else
452                req_local_size = { 0, 0, 0 };
453 
454             break;
455          }
456 
457          case SpvOpFunctionParameter: {
458             if (kernel_name.empty())
459                break;
460 
461             const auto id = get<SpvId>(inst, 2);
462             const auto type_id = get<SpvId>(inst, 1);
463             auto arg = types.find(type_id)->second;
464             const auto &func_param_attr_iter =
465                func_param_attr_map.find(get<SpvId>(inst, 2));
466             if (func_param_attr_iter != func_param_attr_map.end()) {
467                for (auto &i : func_param_attr_iter->second) {
468                   switch (i) {
469                   case SpvFunctionParameterAttributeSext:
470                      arg.ext_type = binary::argument::sign_ext;
471                      break;
472                   case SpvFunctionParameterAttributeZext:
473                      arg.ext_type = binary::argument::zero_ext;
474                      break;
475                   case SpvFunctionParameterAttributeByVal: {
476                      const SpvId ptr_type_id =
477                         pointer_types.find(type_id)->second;
478                      arg = types.find(ptr_type_id)->second;
479                      break;
480                   }
481                   case SpvFunctionParameterAttributeNoAlias:
482                      arg.info.type_qualifier |= CL_KERNEL_ARG_TYPE_RESTRICT;
483                      break;
484                   case SpvFunctionParameterAttributeNoWrite:
485                      arg.info.type_qualifier |= CL_KERNEL_ARG_TYPE_CONST;
486                      break;
487                   default:
488                      break;
489                   }
490                }
491             }
492 
493             auto name_it = names.find(id);
494             if (name_it != names.end())
495                arg.info.arg_name = (*name_it).second;
496 
497             arg.info.type_qualifier |= qualifiers[id];
498             arg.info.address_qualifier = types[type_id].info.address_qualifier;
499             arg.info.access_qualifier = CL_KERNEL_ARG_ACCESS_NONE;
500             args.emplace_back(arg);
501             break;
502          }
503 
504          case SpvOpFunctionEnd: {
505             if (kernel_name.empty())
506                break;
507 
508             for (size_t i = 0; i < param_type_names[kernel_name].size(); i++)
509                args[i].info.type_name = param_type_names[kernel_name][i];
510 
511             for (size_t i = 0; i < param_qual_names[kernel_name].size(); i++)
512                if (param_qual_names[kernel_name][i].find("const") != std::string::npos)
513                   args[i].info.type_qualifier |= CL_KERNEL_ARG_TYPE_CONST;
514             b.syms.emplace_back(kernel_name, detokenize(attributes, " "),
515                                 req_local_size, 0, kernel_nb, args);
516             ++kernel_nb;
517             kernel_name.clear();
518             args.clear();
519             attributes.clear();
520             break;
521          }
522          default:
523             break;
524          }
525 
526          i += num_operands;
527       }
528 
529       b.secs.push_back(make_text_section(source,
530                                          binary::section::text_intermediate));
531       return b;
532    }
533 
534    bool
check_spirv_version(const device & dev,const char * binary,std::string & r_log)535    check_spirv_version(const device &dev, const char *binary,
536                        std::string &r_log) {
537       const auto spirv_version = get<uint32_t>(binary, 1u);
538       const auto supported_spirv_versions = clover::spirv::supported_versions();
539       const auto compare_versions =
540          [module_version =
541             clover::spirv::to_opencl_version_encoding(spirv_version)](const cl_name_version &supported){
542          return supported.version == module_version;
543       };
544 
545       if (std::find_if(supported_spirv_versions.cbegin(),
546                        supported_spirv_versions.cend(),
547                        compare_versions) != supported_spirv_versions.cend())
548          return true;
549 
550       r_log += "SPIR-V version " +
551                clover::spirv::version_to_string(spirv_version) +
552                " is not supported; supported versions:";
553       for (const auto &version : supported_spirv_versions) {
554          r_log += " " + clover::spirv::version_to_string(version.version);
555       }
556       r_log += "\n";
557       return false;
558    }
559 
560    bool
check_capabilities(const device & dev,const std::string & source,std::string & r_log)561    check_capabilities(const device &dev, const std::string &source,
562                       std::string &r_log) {
563       const size_t length = source.size() / sizeof(uint32_t);
564       size_t i = SPIRV_HEADER_WORD_SIZE; // Skip header
565 
566       while (i < length) {
567          const auto desc_word = get<uint32_t>(source.data(), i);
568          const auto opcode = static_cast<SpvOp>(desc_word & SpvOpCodeMask);
569          const unsigned int num_operands = desc_word >> SpvWordCountShift;
570 
571          if (opcode != SpvOpCapability)
572             break;
573 
574          const auto capability = get<SpvCapability>(source.data(), i + 1u);
575          switch (capability) {
576          // Mandatory capabilities
577          case SpvCapabilityAddresses:
578          case SpvCapabilityFloat16Buffer:
579          case SpvCapabilityGroups:
580          case SpvCapabilityInt64:
581          case SpvCapabilityInt16:
582          case SpvCapabilityInt8:
583          case SpvCapabilityKernel:
584          case SpvCapabilityLinkage:
585          case SpvCapabilityVector16:
586             break;
587          // Optional capabilities
588          case SpvCapabilityImageBasic:
589          case SpvCapabilityLiteralSampler:
590          case SpvCapabilitySampled1D:
591          case SpvCapabilityImage1D:
592          case SpvCapabilitySampledBuffer:
593          case SpvCapabilityImageBuffer:
594             if (!dev.image_support()) {
595                r_log += "Capability 'ImageBasic' is not supported.\n";
596                return false;
597             }
598             break;
599          case SpvCapabilityFloat64:
600             if (!dev.has_doubles()) {
601                r_log += "Capability 'Float64' is not supported.\n";
602                return false;
603             }
604             break;
605          // Enabled through extensions
606          case SpvCapabilityFloat16:
607             if (!dev.has_halves()) {
608                r_log += "Capability 'Float16' is not supported.\n";
609                return false;
610             }
611             break;
612          case SpvCapabilityInt64Atomics:
613             if (!dev.has_int64_atomics()) {
614                r_log += "Capability 'Int64Atomics' is not supported.\n";
615                return false;
616             }
617             break;
618          default:
619             r_log += "Capability '" + std::to_string(capability) +
620                      "' is not supported.\n";
621             return false;
622          }
623 
624          i += num_operands;
625       }
626 
627       return true;
628    }
629 
630    bool
check_extensions(const device & dev,const std::string & source,std::string & r_log)631    check_extensions(const device &dev, const std::string &source,
632                     std::string &r_log) {
633       const size_t length = source.size() / sizeof(uint32_t);
634       size_t i = SPIRV_HEADER_WORD_SIZE; // Skip header
635       const auto spirv_extensions = spirv::supported_extensions();
636 
637       while (i < length) {
638          const auto desc_word = get<uint32_t>(source.data(), i);
639          const auto opcode = static_cast<SpvOp>(desc_word & SpvOpCodeMask);
640          const unsigned int num_operands = desc_word >> SpvWordCountShift;
641 
642          if (opcode == SpvOpCapability) {
643             i += num_operands;
644             continue;
645          }
646          if (opcode != SpvOpExtension)
647             break;
648 
649          const std::string extension = source.data() + (i + 1u) * sizeof(uint32_t);
650          if (spirv_extensions.count(extension) == 0) {
651             r_log += "Extension '" + extension + "' is not supported.\n";
652             return false;
653          }
654 
655          i += num_operands;
656       }
657 
658       return true;
659    }
660 
661    bool
check_memory_model(const device & dev,const std::string & source,std::string & r_log)662    check_memory_model(const device &dev, const std::string &source,
663                       std::string &r_log) {
664       const size_t length = source.size() / sizeof(uint32_t);
665       size_t i = SPIRV_HEADER_WORD_SIZE; // Skip header
666 
667       while (i < length) {
668          const auto desc_word = get<uint32_t>(source.data(), i);
669          const auto opcode = static_cast<SpvOp>(desc_word & SpvOpCodeMask);
670          const unsigned int num_operands = desc_word >> SpvWordCountShift;
671 
672          switch (opcode) {
673          case SpvOpMemoryModel:
674             switch (get<SpvAddressingModel>(source.data(), i + 1u)) {
675             case SpvAddressingModelPhysical32:
676                return dev.address_bits() == 32;
677             case SpvAddressingModelPhysical64:
678                return dev.address_bits() == 64;
679             default:
680                unreachable("Only Physical32 and Physical64 are valid for OpenCL, and the binary was already validated");
681                return false;
682             }
683             break;
684          default:
685             break;
686          }
687 
688          i += num_operands;
689       }
690 
691       return false;
692    }
693 
694    // Copies the input binary and convert it to the endianness of the host CPU.
695    std::string
spirv_to_cpu(const std::string & binary)696    spirv_to_cpu(const std::string &binary)
697    {
698       const uint32_t first_word = get<uint32_t>(binary.data(), 0u);
699       if (first_word == SpvMagicNumber)
700          return binary;
701 
702       std::vector<char> cpu_endianness_binary(binary.size());
703       for (size_t i = 0; i < (binary.size() / 4u); ++i) {
704          const uint32_t word = get<uint32_t>(binary.data(), i);
705          reinterpret_cast<uint32_t *>(cpu_endianness_binary.data())[i] =
706             util_bswap32(word);
707       }
708 
709       return std::string(cpu_endianness_binary.begin(),
710                          cpu_endianness_binary.end());
711    }
712 
713 #ifdef HAVE_CLOVER_SPIRV
714    std::string
format_validator_msg(spv_message_level_t level,const char *,const spv_position_t & position,const char * message)715    format_validator_msg(spv_message_level_t level, const char * /* source */,
716                         const spv_position_t &position, const char *message) {
717       std::string level_str;
718       switch (level) {
719       case SPV_MSG_FATAL:
720          level_str = "Fatal";
721          break;
722       case SPV_MSG_INTERNAL_ERROR:
723          level_str = "Internal error";
724          break;
725       case SPV_MSG_ERROR:
726          level_str = "Error";
727          break;
728       case SPV_MSG_WARNING:
729          level_str = "Warning";
730          break;
731       case SPV_MSG_INFO:
732          level_str = "Info";
733          break;
734       case SPV_MSG_DEBUG:
735          level_str = "Debug";
736          break;
737       }
738       return "[" + level_str + "] At word No." +
739              std::to_string(position.index) + ": \"" + message + "\"\n";
740    }
741 
742    spv_target_env
convert_opencl_version_to_target_env(const cl_version opencl_version)743    convert_opencl_version_to_target_env(const cl_version opencl_version) {
744       // Pick 1.2 for 3.0 for now
745       if (opencl_version == CL_MAKE_VERSION(3, 0, 0)) {
746          return SPV_ENV_OPENCL_1_2;
747       } else if (opencl_version == CL_MAKE_VERSION(2, 2, 0)) {
748          return SPV_ENV_OPENCL_2_2;
749       } else if (opencl_version == CL_MAKE_VERSION(2, 1, 0)) {
750          return SPV_ENV_OPENCL_2_1;
751       } else if (opencl_version == CL_MAKE_VERSION(2, 0, 0)) {
752          return SPV_ENV_OPENCL_2_0;
753       } else if (opencl_version == CL_MAKE_VERSION(1, 2, 0) ||
754                  opencl_version == CL_MAKE_VERSION(1, 1, 0) ||
755                  opencl_version == CL_MAKE_VERSION(1, 0, 0)) {
756          // SPIR-V is only defined for OpenCL >= 1.2, however some drivers
757          // might use it with OpenCL 1.0 and 1.1.
758          return SPV_ENV_OPENCL_1_2;
759       } else {
760          throw build_error("Invalid OpenCL version");
761       }
762    }
763 #endif
764 
765 }
766 
767 bool
is_binary_spirv(const std::string & binary)768 clover::spirv::is_binary_spirv(const std::string &binary)
769 {
770    // A SPIR-V binary is at the very least 5 32-bit words, which represent the
771    // SPIR-V header.
772    if (binary.size() < 20u)
773       return false;
774 
775    const uint32_t first_word =
776       reinterpret_cast<const uint32_t *>(binary.data())[0u];
777    return (first_word == SpvMagicNumber) ||
778           (util_bswap32(first_word) == SpvMagicNumber);
779 }
780 
781 std::string
version_to_string(uint32_t version)782 clover::spirv::version_to_string(uint32_t version) {
783    const uint32_t major_version = (version >> 16) & 0xff;
784    const uint32_t minor_version = (version >> 8) & 0xff;
785    return std::to_string(major_version) + '.' +
786       std::to_string(minor_version);
787 }
788 
789 binary
compile_program(const std::string & binary,const device & dev,std::string & r_log,bool validate)790 clover::spirv::compile_program(const std::string &binary,
791                                const device &dev, std::string &r_log,
792                                bool validate) {
793    std::string source = spirv_to_cpu(binary);
794 
795    if (validate && !is_valid_spirv(source, dev.device_version(), r_log))
796       throw build_error();
797 
798    if (!check_spirv_version(dev, source.data(), r_log))
799       throw build_error();
800    if (!check_capabilities(dev, source, r_log))
801       throw build_error();
802    if (!check_extensions(dev, source, r_log))
803       throw build_error();
804    if (!check_memory_model(dev, source, r_log))
805       throw build_error();
806 
807    return create_binary_from_spirv(source,
808                                    dev.address_bits() == 32 ? 4u : 8u, r_log);
809 }
810 
811 binary
link_program(const std::vector<binary> & binaries,const device & dev,const std::string & opts,std::string & r_log)812 clover::spirv::link_program(const std::vector<binary> &binaries,
813                             const device &dev, const std::string &opts,
814                             std::string &r_log) {
815    std::vector<std::string> options = tokenize(opts);
816 
817    bool create_library = false;
818 
819    std::string ignored_options;
820    for (const std::string &option : options) {
821       if (option == "-create-library") {
822          create_library = true;
823       } else {
824          ignored_options += "'" + option + "' ";
825       }
826    }
827    if (!ignored_options.empty()) {
828       r_log += "Ignoring the following link options: " + ignored_options
829             + "\n";
830    }
831 
832    spvtools::LinkerOptions linker_options;
833    linker_options.SetCreateLibrary(create_library);
834 
835    binary b;
836 
837    const auto section_type = create_library ? binary::section::text_library :
838                                               binary::section::text_executable;
839 
840    std::vector<const uint32_t *> sections;
841    sections.reserve(binaries.size());
842    std::vector<size_t> lengths;
843    lengths.reserve(binaries.size());
844 
845    auto const validator_consumer = [&r_log](spv_message_level_t level,
846                                             const char *source,
847                                             const spv_position_t &position,
848                                             const char *message) {
849       r_log += format_validator_msg(level, source, position, message);
850    };
851 
852    for (const auto &bin : binaries) {
853       const auto &bsec = find([](const binary::section &sec) {
854                   return sec.type == binary::section::text_intermediate ||
855                          sec.type == binary::section::text_library;
856                }, bin.secs);
857 
858       const auto c_il = ((struct pipe_binary_program_header*)bsec.data.data())->blob;
859       const auto length = bsec.size;
860 
861       if (!check_spirv_version(dev, c_il, r_log))
862          throw error(CL_LINK_PROGRAM_FAILURE);
863 
864       sections.push_back(reinterpret_cast<const uint32_t *>(c_il));
865       lengths.push_back(length / sizeof(uint32_t));
866    }
867 
868    std::vector<uint32_t> linked_binary;
869 
870    const cl_version opencl_version = dev.device_version();
871    const spv_target_env target_env =
872       convert_opencl_version_to_target_env(opencl_version);
873 
874    const spvtools::MessageConsumer consumer = validator_consumer;
875    spvtools::Context context(target_env);
876    context.SetMessageConsumer(std::move(consumer));
877 
878    if (Link(context, sections.data(), lengths.data(), sections.size(),
879             &linked_binary, linker_options) != SPV_SUCCESS)
880       throw error(CL_LINK_PROGRAM_FAILURE);
881 
882    std::string final_binary{
883          reinterpret_cast<char *>(linked_binary.data()),
884          reinterpret_cast<char *>(linked_binary.data() +
885                linked_binary.size()) };
886    if (!is_valid_spirv(final_binary, opencl_version, r_log))
887       throw error(CL_LINK_PROGRAM_FAILURE);
888 
889    if (has_flag(llvm::debug::spirv))
890       llvm::debug::log(".spvasm", spirv::print_module(final_binary, dev.device_version()));
891 
892    for (const auto &bin : binaries)
893       b.syms.insert(b.syms.end(), bin.syms.begin(), bin.syms.end());
894 
895    b.secs.emplace_back(make_text_section(final_binary, section_type));
896 
897    return b;
898 }
899 
900 bool
is_valid_spirv(const std::string & binary,const cl_version opencl_version,std::string & r_log)901 clover::spirv::is_valid_spirv(const std::string &binary,
902                               const cl_version opencl_version,
903                               std::string &r_log) {
904    auto const validator_consumer =
905       [&r_log](spv_message_level_t level, const char *source,
906                const spv_position_t &position, const char *message) {
907       r_log += format_validator_msg(level, source, position, message);
908    };
909 
910    const spv_target_env target_env =
911       convert_opencl_version_to_target_env(opencl_version);
912    spvtools::SpirvTools spvTool(target_env);
913    spvTool.SetMessageConsumer(validator_consumer);
914 
915    spvtools::ValidatorOptions validator_options;
916    validator_options.SetUniversalLimit(spv_validator_limit_max_function_args,
917                                        std::numeric_limits<uint32_t>::max());
918 
919    return spvTool.Validate(reinterpret_cast<const uint32_t *>(binary.data()),
920                            binary.size() / 4u, validator_options);
921 }
922 
923 std::string
print_module(const std::string & binary,const cl_version opencl_version)924 clover::spirv::print_module(const std::string &binary,
925                             const cl_version opencl_version) {
926    const spv_target_env target_env =
927       convert_opencl_version_to_target_env(opencl_version);
928    spvtools::SpirvTools spvTool(target_env);
929    spv_context spvContext = spvContextCreate(target_env);
930    if (!spvContext)
931       return "Failed to create an spv_context for disassembling the binary.";
932 
933    spv_text disassembly;
934    spvBinaryToText(spvContext,
935                    reinterpret_cast<const uint32_t *>(binary.data()),
936                    binary.size() / 4u, SPV_BINARY_TO_TEXT_OPTION_NONE,
937                    &disassembly, nullptr);
938    spvContextDestroy(spvContext);
939 
940    const std::string disassemblyStr = disassembly->str;
941    spvTextDestroy(disassembly);
942 
943    return disassemblyStr;
944 }
945 
946 std::unordered_set<std::string>
supported_extensions()947 clover::spirv::supported_extensions() {
948    return {
949       /* this is only a hint so all devices support that */
950       "SPV_KHR_no_integer_wrap_decoration"
951    };
952 }
953 
954 std::vector<cl_name_version>
supported_versions()955 clover::spirv::supported_versions() {
956    return { cl_name_version { CL_MAKE_VERSION(1u, 0u, 0u), "SPIR-V" } };
957 }
958 
959 cl_version
to_opencl_version_encoding(uint32_t version)960 clover::spirv::to_opencl_version_encoding(uint32_t version) {
961       return CL_MAKE_VERSION((version >> 16u) & 0xff,
962                              (version >> 8u) & 0xff, 0u);
963 }
964 
965 uint32_t
to_spirv_version_encoding(cl_version version)966 clover::spirv::to_spirv_version_encoding(cl_version version) {
967    return ((CL_VERSION_MAJOR(version) & 0xff) << 16u) |
968           ((CL_VERSION_MINOR(version) & 0xff) << 8u);
969 }
970 
971 #else
972 bool
is_binary_spirv(const std::string & binary)973 clover::spirv::is_binary_spirv(const std::string &binary)
974 {
975    return false;
976 }
977 
978 bool
is_valid_spirv(const std::string &,const cl_version opencl_version,std::string &)979 clover::spirv::is_valid_spirv(const std::string &/*binary*/,
980                               const cl_version opencl_version,
981                               std::string &/*r_log*/) {
982    return false;
983 }
984 
985 std::string
version_to_string(uint32_t version)986 clover::spirv::version_to_string(uint32_t version) {
987    return "";
988 }
989 
990 binary
compile_program(const std::string & binary,const device & dev,std::string & r_log,bool validate)991 clover::spirv::compile_program(const std::string &binary,
992                                const device &dev, std::string &r_log,
993                                bool validate) {
994    r_log += "SPIR-V support in clover is not enabled.\n";
995    throw build_error();
996 }
997 
998 binary
link_program(const std::vector<binary> &,const device &,const std::string &,std::string & r_log)999 clover::spirv::link_program(const std::vector<binary> &/*binaries*/,
1000                             const device &/*dev*/, const std::string &/*opts*/,
1001                             std::string &r_log) {
1002    r_log += "SPIR-V support in clover is not enabled.\n";
1003    throw error(CL_LINKER_NOT_AVAILABLE);
1004 }
1005 
1006 std::string
print_module(const std::string & binary,const cl_version opencl_version)1007 clover::spirv::print_module(const std::string &binary,
1008                             const cl_version opencl_version) {
1009    return std::string();
1010 }
1011 
1012 std::unordered_set<std::string>
supported_extensions()1013 clover::spirv::supported_extensions() {
1014    return {};
1015 }
1016 
1017 std::vector<cl_name_version>
supported_versions()1018 clover::spirv::supported_versions() {
1019    return {};
1020 }
1021 
1022 cl_version
to_opencl_version_encoding(uint32_t version)1023 clover::spirv::to_opencl_version_encoding(uint32_t version) {
1024    return CL_MAKE_VERSION(0u, 0u, 0u);
1025 }
1026 
1027 uint32_t
to_spirv_version_encoding(cl_version version)1028 clover::spirv::to_spirv_version_encoding(cl_version version) {
1029    return 0u;
1030 }
1031 #endif
1032