xref: /aosp_15_r20/external/mesa3d/src/gallium/frontends/clover/nir/invocation.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 //
2 // Copyright 2019 Karol Herbst
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a
5 // copy of this software and associated documentation files (the "Software"),
6 // to deal in the Software without restriction, including without limitation
7 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 // and/or sell copies of the Software, and to permit persons to whom the
9 // Software is furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
17 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
18 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
19 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
20 // OTHER DEALINGS IN THE SOFTWARE.
21 //
22 
23 #include "invocation.hpp"
24 
25 #include <tuple>
26 
27 #include "core/device.hpp"
28 #include "core/error.hpp"
29 #include "core/binary.hpp"
30 #include "pipe/p_state.h"
31 #include "util/algorithm.hpp"
32 #include "util/functional.hpp"
33 
34 #include <compiler/glsl_types.h>
35 #include <compiler/clc/nir_clc_helpers.h>
36 #include <compiler/nir/nir_builder.h>
37 #include <compiler/nir/nir_serialize.h>
38 #include <compiler/spirv/nir_spirv.h>
39 #include <compiler/spirv/spirv_info.h>
40 #include <util/u_math.h>
41 #include <util/hex.h>
42 
43 using namespace clover;
44 
45 #ifdef HAVE_CLOVER_SPIRV
46 
47 // Refs and unrefs the glsl_type_singleton.
48 static class glsl_type_ref {
49 public:
glsl_type_ref()50    glsl_type_ref() {
51       glsl_type_singleton_init_or_ref();
52    }
53 
~glsl_type_ref()54    ~glsl_type_ref() {
55       glsl_type_singleton_decref();
56    }
57 } glsl_type_ref;
58 
59 static const nir_shader_compiler_options *
dev_get_nir_compiler_options(const device & dev)60 dev_get_nir_compiler_options(const device &dev)
61 {
62    const void *co = dev.get_compiler_options(PIPE_SHADER_IR_NIR);
63    return static_cast<const nir_shader_compiler_options*>(co);
64 }
65 
// Debug callback handed to spirv_to_nir: appends each message to the
// std::string passed through private_data (the build log).  The level and
// SPIR-V offset arguments are not used.
static void debug_function(void *private_data,
                           enum nir_spirv_debug_level level, size_t spirv_offset,
                           const char *message)
{
   assert(private_data);
   std::string &log = *static_cast<std::string *>(private_data);
   log += message;
}
74 
75 static void
clover_arg_size_align(const glsl_type * type,unsigned * size,unsigned * align)76 clover_arg_size_align(const glsl_type *type, unsigned *size, unsigned *align)
77 {
78    if (glsl_type_is_sampler(type) || glsl_type_is_image(type)) {
79       *size = 0;
80       *align = 1;
81    } else {
82       *size = glsl_get_cl_size(type);
83       *align = glsl_get_cl_alignment(type);
84    }
85 }
86 
87 static void
clover_nir_add_image_uniforms(nir_shader * shader)88 clover_nir_add_image_uniforms(nir_shader *shader)
89 {
90    /* Clover expects each image variable to take up a cl_mem worth of space in
91     * the arguments data.  Add uniforms as needed to match this expectation.
92     */
93    nir_foreach_image_variable_safe(var, shader) {
94       nir_variable *uniform = rzalloc(shader, nir_variable);
95       uniform->name = ralloc_strdup(uniform, var->name);
96       uniform->type = glsl_uintN_t_type(sizeof(cl_mem) * 8);
97       uniform->data.mode = nir_var_uniform;
98       uniform->data.read_only = true;
99       uniform->data.location = var->data.location;
100 
101       exec_node_insert_node_before(&var->node, &uniform->node);
102    }
103 }
104 
105 struct clover_lower_nir_state {
106    std::vector<binary::argument> &args;
107    uint32_t global_dims;
108    nir_variable *constant_var;
109    nir_variable *printf_buffer;
110    nir_variable *offset_vars[3];
111 };
112 
113 static bool
clover_lower_nir_filter(const nir_instr * instr,const void *)114 clover_lower_nir_filter(const nir_instr *instr, const void *)
115 {
116    return instr->type == nir_instr_type_intrinsic;
117 }
118 
119 static nir_def *
clover_lower_nir_instr(nir_builder * b,nir_instr * instr,void * _state)120 clover_lower_nir_instr(nir_builder *b, nir_instr *instr, void *_state)
121 {
122    clover_lower_nir_state *state = reinterpret_cast<clover_lower_nir_state*>(_state);
123    nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
124 
125    switch (intrinsic->intrinsic) {
126    case nir_intrinsic_load_printf_buffer_address: {
127       if (!state->printf_buffer) {
128          unsigned location = state->args.size();
129          state->args.emplace_back(binary::argument::global, sizeof(size_t),
130                                   8, 8, binary::argument::zero_ext,
131                                   binary::argument::printf_buffer);
132 
133          const glsl_type *type = glsl_uint64_t_type();
134          state->printf_buffer = nir_variable_create(b->shader, nir_var_uniform,
135                                                     type, "global_printf_buffer");
136          state->printf_buffer->data.location = location;
137       }
138       return nir_load_var(b, state->printf_buffer);
139    }
140    case nir_intrinsic_load_base_global_invocation_id: {
141       nir_def *loads[3];
142 
143       /* create variables if we didn't do so alrady */
144       if (!state->offset_vars[0]) {
145          /* TODO: fix for 64 bit */
146          /* Even though we only place one scalar argument, clover will bind up to
147           * three 32 bit values
148          */
149          unsigned location = state->args.size();
150          state->args.emplace_back(binary::argument::scalar, 4, 4, 4,
151                                   binary::argument::zero_ext,
152                                   binary::argument::grid_offset);
153 
154          const glsl_type *type = glsl_uint_type();
155          for (uint32_t i = 0; i < 3; i++) {
156             state->offset_vars[i] =
157                nir_variable_create(b->shader, nir_var_uniform, type,
158                                    "global_invocation_id_offsets");
159             state->offset_vars[i]->data.location = location + i;
160          }
161       }
162 
163       for (int i = 0; i < 3; i++) {
164          nir_variable *var = state->offset_vars[i];
165          loads[i] = var ? nir_load_var(b, var) : nir_imm_int(b, 0);
166       }
167 
168       return nir_u2uN(b, nir_vec(b, loads, state->global_dims),
169                      intrinsic->def.bit_size);
170    }
171    case nir_intrinsic_load_constant_base_ptr: {
172       return nir_load_var(b, state->constant_var);
173    }
174 
175    default:
176       return NULL;
177    }
178 }
179 
180 static bool
clover_lower_nir(nir_shader * nir,std::vector<binary::argument> & args,uint32_t dims,uint32_t pointer_bit_size)181 clover_lower_nir(nir_shader *nir, std::vector<binary::argument> &args,
182                  uint32_t dims, uint32_t pointer_bit_size)
183 {
184    nir_variable *constant_var = NULL;
185    if (nir->constant_data_size) {
186       const glsl_type *type = pointer_bit_size == 64 ? glsl_uint64_t_type() : glsl_uint_type();
187 
188       constant_var = nir_variable_create(nir, nir_var_uniform, type,
189                                          "constant_buffer_addr");
190       constant_var->data.location = args.size();
191 
192       args.emplace_back(binary::argument::global, sizeof(cl_mem),
193                         pointer_bit_size / 8, pointer_bit_size / 8,
194                         binary::argument::zero_ext,
195                         binary::argument::constant_buffer);
196    }
197 
198    clover_lower_nir_state state = { args, dims, constant_var };
199    return nir_shader_lower_instructions(nir,
200       clover_lower_nir_filter, clover_lower_nir_instr, &state);
201 }
202 
203 static spirv_capabilities
create_spirv_caps(const device & dev)204 create_spirv_caps(const device &dev)
205 {
206    struct spirv_capabilities caps = {};
207    caps.Addresses = true;
208    caps.Float64 = true;
209    caps.Int8 = true;
210    caps.Int16 = true;
211    caps.Int64 = true;
212    caps.Kernel = true;
213    caps.ImageBasic = dev.image_support();
214    caps.Int64Atomics = dev.has_int64_atomics();
215    return caps;
216 }
217 
218 static spirv_to_nir_options
create_spirv_options(const device & dev,spirv_capabilities & caps,std::string & r_log)219 create_spirv_options(const device &dev,
220                      spirv_capabilities &caps,
221                      std::string &r_log)
222 {
223    struct spirv_to_nir_options spirv_options = {};
224    spirv_options.environment = NIR_SPIRV_OPENCL;
225    if (dev.address_bits() == 32u) {
226       spirv_options.shared_addr_format = nir_address_format_32bit_offset;
227       spirv_options.global_addr_format = nir_address_format_32bit_global;
228       spirv_options.temp_addr_format = nir_address_format_32bit_offset;
229       spirv_options.constant_addr_format = nir_address_format_32bit_global;
230    } else {
231       spirv_options.shared_addr_format = nir_address_format_32bit_offset_as_64bit;
232       spirv_options.global_addr_format = nir_address_format_64bit_global;
233       spirv_options.temp_addr_format = nir_address_format_32bit_offset_as_64bit;
234       spirv_options.constant_addr_format = nir_address_format_64bit_global;
235    }
236    spirv_options.capabilities = &caps;
237    spirv_options.debug.func = &debug_function;
238    spirv_options.debug.private_data = &r_log;
239    spirv_options.printf = true;
240    return spirv_options;
241 }
242 
create_clc_disk_cache(void)243 struct disk_cache *clover::nir::create_clc_disk_cache(void)
244 {
245    struct mesa_sha1 ctx;
246    unsigned char sha1[20];
247    char cache_id[20 * 2 + 1];
248    _mesa_sha1_init(&ctx);
249 
250    if (!disk_cache_get_function_identifier((void *)clover::nir::create_clc_disk_cache, &ctx))
251       return NULL;
252 
253    _mesa_sha1_final(&ctx, sha1);
254 
255    mesa_bytes_to_hex(cache_id, sha1, 20);
256    return disk_cache_create("clover-clc", cache_id, 0);
257 }
258 
check_for_libclc(const device & dev)259 void clover::nir::check_for_libclc(const device &dev)
260 {
261    if (!nir_can_find_libclc(dev.address_bits()))
262       throw error(CL_COMPILER_NOT_AVAILABLE);
263 }
264 
load_libclc_nir(const device & dev,std::string & r_log)265 nir_shader *clover::nir::load_libclc_nir(const device &dev, std::string &r_log)
266 {
267    spirv_capabilities caps = create_spirv_caps(dev);
268    spirv_to_nir_options spirv_options = create_spirv_options(dev, caps, r_log);
269    auto *compiler_options = dev_get_nir_compiler_options(dev);
270 
271    return nir_load_libclc_shader(dev.address_bits(), dev.clc_cache,
272                                  &spirv_options, compiler_options,
273                                  dev.clc_cache != nullptr);
274 }
275 
276 static bool
can_remove_var(nir_variable * var,void * data)277 can_remove_var(nir_variable *var, void *data)
278 {
279    return !(glsl_type_is_sampler(var->type) ||
280             glsl_type_is_texture(var->type) ||
281             glsl_type_is_image(var->type));
282 }
283 
spirv_to_nir(const binary & mod,const device & dev,std::string & r_log)284 binary clover::nir::spirv_to_nir(const binary &mod, const device &dev,
285                                  std::string &r_log)
286 {
287    spirv_capabilities caps = create_spirv_caps(dev);
288    spirv_to_nir_options spirv_options = create_spirv_options(dev, caps, r_log);
289    std::shared_ptr<nir_shader> nir = dev.clc_nir;
290    spirv_options.clc_shader = nir.get();
291 
292    binary b;
293    // We only insert one section.
294    assert(mod.secs.size() == 1);
295    auto &section = mod.secs[0];
296 
297    binary::resource_id section_id = 0;
298    for (const auto &sym : mod.syms) {
299       assert(sym.section == 0);
300 
301       const auto *binary =
302          reinterpret_cast<const pipe_binary_program_header *>(section.data.data());
303       const uint32_t *data = reinterpret_cast<const uint32_t *>(binary->blob);
304       const size_t num_words = binary->num_bytes / 4;
305       const char *name = sym.name.c_str();
306       auto *compiler_options = dev_get_nir_compiler_options(dev);
307 
308       nir_shader *nir = spirv_to_nir(data, num_words, nullptr, 0,
309                                      MESA_SHADER_KERNEL, name,
310                                      &spirv_options, compiler_options);
311       if (!nir) {
312          r_log += "Translation from SPIR-V to NIR for kernel \"" + sym.name +
313                   "\" failed.\n";
314          throw build_error();
315       }
316 
317       nir->info.workgroup_size_variable = sym.reqd_work_group_size[0] == 0;
318       nir->info.workgroup_size[0] = sym.reqd_work_group_size[0];
319       nir->info.workgroup_size[1] = sym.reqd_work_group_size[1];
320       nir->info.workgroup_size[2] = sym.reqd_work_group_size[2];
321       nir_validate_shader(nir, "clover");
322 
323       // Inline all functions first.
324       // according to the comment on nir_inline_functions
325       NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
326       NIR_PASS_V(nir, nir_lower_returns);
327       NIR_PASS_V(nir, nir_link_shader_functions, spirv_options.clc_shader);
328 
329       NIR_PASS_V(nir, nir_inline_functions);
330       NIR_PASS_V(nir, nir_copy_prop);
331       NIR_PASS_V(nir, nir_opt_deref);
332 
333       // Pick off the single entrypoint that we want.
334       nir_remove_non_entrypoints(nir);
335 
336       nir_validate_shader(nir, "clover after function inlining");
337 
338       NIR_PASS_V(nir, nir_lower_variable_initializers, ~nir_var_function_temp);
339 
340       struct nir_lower_printf_options printf_options;
341       printf_options.max_buffer_size = dev.max_printf_buffer_size();
342 
343       NIR_PASS_V(nir, nir_lower_printf, &printf_options);
344 
345       NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL);
346 
347       // copy propagate to prepare for lower_explicit_io
348       NIR_PASS_V(nir, nir_split_var_copies);
349       NIR_PASS_V(nir, nir_opt_copy_prop_vars);
350       NIR_PASS_V(nir, nir_lower_var_copies);
351       NIR_PASS_V(nir, nir_lower_vars_to_ssa);
352       NIR_PASS_V(nir, nir_opt_dce);
353       NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
354 
355       if (compiler_options->lower_to_scalar) {
356          NIR_PASS_V(nir, nir_lower_alu_to_scalar,
357                     compiler_options->lower_to_scalar_filter, NULL);
358       }
359       NIR_PASS_V(nir, nir_lower_system_values);
360       nir_lower_compute_system_values_options sysval_options = { 0 };
361       sysval_options.has_base_global_invocation_id = true;
362       NIR_PASS_V(nir, nir_lower_compute_system_values, &sysval_options);
363 
364       // constant fold before lowering mem constants
365       NIR_PASS_V(nir, nir_opt_constant_folding);
366 
367       NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_mem_constant, NULL);
368       NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_constant,
369                  glsl_get_cl_type_size_align);
370       if (nir->constant_data_size > 0) {
371          assert(nir->constant_data == NULL);
372          nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
373          nir_gather_explicit_io_initializers(nir, nir->constant_data,
374                                              nir->constant_data_size,
375                                              nir_var_mem_constant);
376       }
377       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
378                  spirv_options.constant_addr_format);
379 
380       auto args = sym.args;
381       NIR_PASS_V(nir, clover_lower_nir, args, dev.max_block_size().size(),
382                  dev.address_bits());
383 
384       NIR_PASS_V(nir, clover_nir_add_image_uniforms);
385       NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
386                  nir_var_uniform, clover_arg_size_align);
387       NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
388                  nir_var_mem_shared | nir_var_mem_global |
389                  nir_var_function_temp,
390                  glsl_get_cl_type_size_align);
391 
392       NIR_PASS_V(nir, nir_opt_deref);
393       NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, false);
394       NIR_PASS_V(nir, nir_lower_cl_images, true, true);
395       NIR_PASS_V(nir, nir_lower_memcpy);
396 
397       /* use offsets for kernel inputs (uniform) */
398       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
399                  nir->info.cs.ptr_size == 64 ?
400                  nir_address_format_32bit_offset_as_64bit :
401                  nir_address_format_32bit_offset);
402 
403       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
404                  spirv_options.constant_addr_format);
405       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared,
406                  spirv_options.shared_addr_format);
407 
408       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
409                  spirv_options.temp_addr_format);
410 
411       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_global,
412                  spirv_options.global_addr_format);
413 
414       struct nir_remove_dead_variables_options remove_dead_variables_options = {};
415       remove_dead_variables_options.can_remove_var = can_remove_var;
416       NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, &remove_dead_variables_options);
417 
418       if (compiler_options->lower_int64_options)
419          NIR_PASS_V(nir, nir_lower_int64);
420 
421       NIR_PASS_V(nir, nir_opt_dce);
422 
423       if (nir->constant_data_size) {
424          const char *ptr = reinterpret_cast<const char *>(nir->constant_data);
425          const binary::section constants {
426             section_id,
427             binary::section::data_constant,
428             nir->constant_data_size,
429             { ptr, ptr + nir->constant_data_size }
430          };
431          nir->constant_data = NULL;
432          nir->constant_data_size = 0;
433          b.secs.push_back(constants);
434       }
435 
436       void *mem_ctx = ralloc_context(NULL);
437       unsigned printf_info_count = nir->printf_info_count;
438       u_printf_info *printf_infos = nir->printf_info;
439 
440       ralloc_steal(mem_ctx, printf_infos);
441 
442       struct blob blob;
443       blob_init(&blob);
444       nir_serialize(&blob, nir, false);
445 
446       ralloc_free(nir);
447 
448       const pipe_binary_program_header header { uint32_t(blob.size) };
449       binary::section text { section_id, binary::section::text_executable, header.num_bytes, {} };
450       text.data.insert(text.data.end(), reinterpret_cast<const char *>(&header),
451                        reinterpret_cast<const char *>(&header) + sizeof(header));
452       text.data.insert(text.data.end(), blob.data, blob.data + blob.size);
453 
454       free(blob.data);
455 
456       b.printf_strings_in_buffer = false;
457       b.printf_infos.reserve(printf_info_count);
458       for (unsigned i = 0; i < printf_info_count; i++) {
459          binary::printf_info info;
460 
461          info.arg_sizes.reserve(printf_infos[i].num_args);
462          for (unsigned j = 0; j < printf_infos[i].num_args; j++)
463             info.arg_sizes.push_back(printf_infos[i].arg_sizes[j]);
464 
465          info.strings.resize(printf_infos[i].string_size);
466          memcpy(info.strings.data(), printf_infos[i].strings, printf_infos[i].string_size);
467          b.printf_infos.push_back(info);
468       }
469 
470       ralloc_free(mem_ctx);
471 
472       b.syms.emplace_back(sym.name, sym.attributes,
473                           sym.reqd_work_group_size, section_id, 0, args);
474       b.secs.push_back(text);
475       section_id++;
476    }
477    return b;
478 }
479 #else
// Fallback for builds without SPIR-V support: records the reason in the build
// log and throws CL_LINKER_NOT_AVAILABLE.
binary clover::nir::spirv_to_nir(const binary &mod, const device &dev,
                                 std::string &r_log)
{
   (void)mod;
   (void)dev;

   r_log += "SPIR-V support in clover is not enabled.\n";
   throw error(CL_LINKER_NOT_AVAILABLE);
}
485 #endif
486