xref: /aosp_15_r20/external/mesa3d/src/asahi/clc/asahi_clc.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Alyssa Rosenzweig
3  * Copyright 2020 Intel Corporation
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "asahi/compiler/agx_compile.h"
8 #include "compiler/clc/clc.h"
9 #include "compiler/glsl_types.h"
10 #include "compiler/spirv/nir_spirv.h"
11 #include "compiler/spirv/spirv_info.h"
12 #include "util/build_id.h"
13 #include "util/disk_cache.h"
14 #include "util/macros.h"
15 #include "util/mesa-sha1.h"
16 #include "util/u_dynarray.h"
17 #include "nir.h"
18 #include "nir_builder.h"
19 #include "nir_serialize.h"
20 
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <inttypes.h>
#include <stdio.h>
#include <string.h>
#include <sys/mman.h>
#include <unistd.h>
#include "util/u_math.h"
29 
30 static const struct spirv_to_nir_options spirv_options = {
31    .environment = NIR_SPIRV_OPENCL,
32    .shared_addr_format = nir_address_format_62bit_generic,
33    .global_addr_format = nir_address_format_62bit_generic,
34    .temp_addr_format = nir_address_format_62bit_generic,
35    .constant_addr_format = nir_address_format_64bit_global,
36    .create_library = true,
37 };
38 
39 static bool
lower_builtins(nir_builder * b,nir_instr * instr,void * data)40 lower_builtins(nir_builder *b, nir_instr *instr, void *data)
41 {
42    if (instr->type != nir_instr_type_call)
43       return false;
44 
45    nir_call_instr *call = nir_instr_as_call(instr);
46    nir_function *func = call->callee;
47 
48    if (strcmp(func->name, "nir_interleave_agx") == 0) {
49       b->cursor = nir_instr_remove(&call->instr);
50       nir_store_deref(
51          b, nir_src_as_deref(call->params[0]),
52          nir_interleave_agx(b, call->params[1].ssa, call->params[2].ssa), 1);
53 
54       return true;
55    } else if (strcmp(func->name, "nir_doorbell_agx") == 0) {
56       b->cursor = nir_instr_remove(&call->instr);
57       nir_doorbell_agx(b, call->params[0].ssa);
58       return true;
59    } else if (strcmp(func->name, "nir_stack_map_agx") == 0) {
60       b->cursor = nir_instr_remove(&call->instr);
61       nir_stack_map_agx(b, call->params[0].ssa, call->params[1].ssa);
62       return true;
63    } else if (strcmp(func->name, "nir_stack_unmap_agx") == 0) {
64       b->cursor = nir_instr_remove(&call->instr);
65       nir_store_deref(b, nir_src_as_deref(call->params[0]),
66                       nir_stack_unmap_agx(b, call->params[1].ssa), 1);
67       return true;
68    } else if (strcmp(func->name, "nir_load_core_id_agx") == 0) {
69       b->cursor = nir_instr_remove(&call->instr);
70       nir_store_deref(b, nir_src_as_deref(call->params[0]),
71                       nir_load_core_id_agx(b), 1);
72       return true;
73    } else if (strcmp(func->name, "nir_load_helper_op_id_agx") == 0) {
74       b->cursor = nir_instr_remove(&call->instr);
75       nir_store_deref(b, nir_src_as_deref(call->params[0]),
76                       nir_load_helper_op_id_agx(b, 1, 32), 1);
77       return true;
78    } else if (strcmp(func->name, "nir_load_helper_arg_lo_agx") == 0) {
79       b->cursor = nir_instr_remove(&call->instr);
80       nir_store_deref(b, nir_src_as_deref(call->params[0]),
81                       nir_load_helper_arg_lo_agx(b, 1, 32), 1);
82       return true;
83    } else if (strcmp(func->name, "nir_load_helper_arg_hi_agx") == 0) {
84       b->cursor = nir_instr_remove(&call->instr);
85       nir_store_deref(b, nir_src_as_deref(call->params[0]),
86                       nir_load_helper_arg_hi_agx(b, 1, 32), 1);
87       return true;
88    } else if (strcmp(func->name, "ballot") == 0) {
89       b->cursor = nir_instr_remove(&call->instr);
90       nir_store_deref(b, nir_src_as_deref(call->params[0]),
91                       nir_ballot(b, 1, 32, call->params[1].ssa), 1);
92       return true;
93    } else if (strcmp(func->name, "nir_fence_helper_exit_agx") == 0) {
94       b->cursor = nir_instr_remove(&call->instr);
95       nir_fence_helper_exit_agx(b);
96       return true;
97    } else if (strcmp(func->name, "nir_bindless_image_load_array") == 0) {
98       b->cursor = nir_instr_remove(&call->instr);
99 
100       nir_def *texel = nir_bindless_image_load(
101          b, 4, 32, call->params[1].ssa, call->params[2].ssa, nir_imm_int(b, 0),
102          nir_imm_int(b, 0), .image_array = true,
103          .image_dim = GLSL_SAMPLER_DIM_2D, .dest_type = nir_type_uint32,
104          .access = ACCESS_IN_BOUNDS_AGX);
105 
106       nir_store_deref(b, nir_src_as_deref(call->params[0]), texel, 0xf);
107       return true;
108    } else if (strcmp(func->name, "nir_bindless_image_store_array") == 0) {
109       b->cursor = nir_instr_remove(&call->instr);
110 
111       nir_bindless_image_store(
112          b, call->params[0].ssa, call->params[1].ssa, nir_imm_int(b, 0),
113          call->params[2].ssa, nir_imm_int(b, 0), .image_array = true,
114          .image_dim = GLSL_SAMPLER_DIM_2D, .src_type = nir_type_uint32,
115          .access = ACCESS_NON_READABLE);
116       return true;
117    } else if (strcmp(func->name, "nir_bindless_image_load_ms_array") == 0) {
118       b->cursor = nir_instr_remove(&call->instr);
119 
120       nir_def *texel = nir_bindless_image_load(
121          b, 4, 32, call->params[1].ssa, call->params[2].ssa,
122          call->params[3].ssa, nir_imm_int(b, 0), .image_array = true,
123          .image_dim = GLSL_SAMPLER_DIM_MS, .dest_type = nir_type_uint32,
124          .access = ACCESS_IN_BOUNDS_AGX);
125 
126       nir_store_deref(b, nir_src_as_deref(call->params[0]), texel, 0xf);
127       return true;
128    } else if (strcmp(func->name, "nir_bindless_image_store_ms_array") == 0) {
129       b->cursor = nir_instr_remove(&call->instr);
130 
131       nir_bindless_image_store(
132          b, call->params[0].ssa, call->params[1].ssa, call->params[2].ssa,
133          call->params[3].ssa, nir_imm_int(b, 0), .image_array = true,
134          .image_dim = GLSL_SAMPLER_DIM_MS, .src_type = nir_type_uint32,
135          .access = ACCESS_NON_READABLE);
136       return true;
137    }
138 
139    return false;
140 }
141 
142 /* Standard optimization loop */
143 static void
optimize(nir_shader * nir)144 optimize(nir_shader *nir)
145 {
146    bool progress;
147    do {
148       progress = false;
149 
150       NIR_PASS(progress, nir, nir_lower_var_copies);
151       NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
152 
153       NIR_PASS(progress, nir, nir_copy_prop);
154       NIR_PASS(progress, nir, nir_opt_remove_phis);
155       NIR_PASS(progress, nir, nir_lower_phis_to_scalar, true);
156       NIR_PASS(progress, nir, nir_opt_dce);
157       NIR_PASS(progress, nir, nir_opt_dead_cf);
158       NIR_PASS(progress, nir, nir_opt_cse);
159       NIR_PASS(progress, nir, nir_opt_peephole_select, 64, false, true);
160       NIR_PASS(progress, nir, nir_opt_phi_precision);
161       NIR_PASS(progress, nir, nir_opt_algebraic);
162       NIR_PASS(progress, nir, nir_opt_constant_folding);
163 
164       NIR_PASS(progress, nir, nir_opt_deref);
165       NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
166       NIR_PASS(progress, nir, nir_opt_undef);
167       NIR_PASS(progress, nir, nir_lower_undef_to_zero);
168 
169       NIR_PASS(progress, nir, nir_opt_shrink_vectors, true);
170       NIR_PASS(progress, nir, nir_opt_loop_unroll);
171 
172       NIR_PASS(progress, nir, nir_split_var_copies);
173       NIR_PASS(progress, nir, nir_split_struct_vars, nir_var_function_temp);
174    } while (progress);
175 }
176 
177 static nir_shader *
compile(void * memctx,const uint32_t * spirv,size_t spirv_size)178 compile(void *memctx, const uint32_t *spirv, size_t spirv_size)
179 {
180    const nir_shader_compiler_options *nir_options = &agx_nir_options;
181 
182    assert(spirv_size % 4 == 0);
183    nir_shader *nir =
184       spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
185                    "library", &spirv_options, nir_options);
186    nir_validate_shader(nir, "after spirv_to_nir");
187    nir_validate_ssa_dominance(nir, "after spirv_to_nir");
188    ralloc_steal(memctx, nir);
189 
190    NIR_PASS(_, nir, nir_lower_system_values);
191    nir_shader_instructions_pass(nir, lower_builtins, nir_metadata_none, NULL);
192 
193    /* We have to lower away local constant initializers right before we
194     * inline functions.  That way they get properly initialized at the top
195     * of the function and not at the top of its caller.
196     */
197    NIR_PASS(_, nir, nir_lower_variable_initializers, nir_var_function_temp);
198    NIR_PASS(_, nir, nir_lower_returns);
199    NIR_PASS(_, nir, nir_inline_functions);
200    nir_remove_non_exported(nir);
201    NIR_PASS(_, nir, nir_copy_prop);
202    NIR_PASS(_, nir, nir_opt_deref);
203 
204    /* We can go ahead and lower the rest of the constant initializers.  We do
205     * this here so that nir_remove_dead_variables and split_per_member_structs
206     * below see the corresponding stores.
207     */
208    NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
209 
210    /* LLVM loves take advantage of the fact that vec3s in OpenCL are 16B
211     * aligned and so it can just read/write them as vec4s.  This results in a
212     * LOT of vec4->vec3 casts on loads and stores.  One solution to this
213     * problem is to get rid of all vec3 variables.
214     */
215    NIR_PASS(_, nir, nir_lower_vec3_to_vec4,
216             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
217                nir_var_mem_global | nir_var_mem_constant);
218 
219    /* We assign explicit types early so that the optimizer can take advantage
220     * of that information and hopefully get rid of some of our memcpys.
221     */
222    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
223             nir_var_uniform | nir_var_shader_temp | nir_var_function_temp |
224                nir_var_mem_shared | nir_var_mem_global,
225             glsl_get_cl_type_size_align);
226 
227    optimize(nir);
228 
229    NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_all, NULL);
230 
231    /* Lower again, this time after dead-variables to get more compact variable
232     * layouts.
233     */
234    NIR_PASS(_, nir, nir_lower_vars_to_explicit_types,
235             nir_var_shader_temp | nir_var_function_temp | nir_var_mem_shared |
236                nir_var_mem_global | nir_var_mem_constant,
237             glsl_get_cl_type_size_align);
238    if (nir->constant_data_size > 0) {
239       assert(nir->constant_data == NULL);
240       nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
241       nir_gather_explicit_io_initializers(nir, nir->constant_data,
242                                           nir->constant_data_size,
243                                           nir_var_mem_constant);
244    }
245 
246    NIR_PASS(_, nir, nir_lower_memcpy);
247 
248    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_mem_constant,
249             nir_address_format_64bit_global);
250 
251    NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_uniform,
252             nir_address_format_32bit_offset_as_64bit);
253 
254    /* Note: we cannot lower explicit I/O here, because we need derefs in tact
255     * for function calls into the library to work.
256     */
257 
258    NIR_PASS(_, nir, nir_lower_convert_alu_types, NULL);
259    NIR_PASS(_, nir, nir_opt_if, 0);
260    NIR_PASS(_, nir, nir_opt_idiv_const, 16);
261 
262    optimize(nir);
263 
264    return nir;
265 }
266 
/* Value of word 0 of every SPIR-V module header. */
enum { SPIR_V_MAGIC_NUMBER = 0x07230203 };
269 
/* clc_logger sink: copy each compiler message verbatim to stderr.
 * The private pointer is not used.
 */
static void
msg_callback(void *priv, const char *msg)
{
   fputs(msg, stderr);
   (void)priv;
}
276 
/*
 * Emit the first `len` bytes of `data` to `fp` as a C definition
 * `static const uint32_t <prefix>_<arr_name>[]`, four words per line.
 *
 * Complete 4-byte words are printed as read through `data`. If `len` is not a
 * multiple of 4, the trailing bytes are copied into the leading bytes of a
 * zero-filled word (never reading past `len`). The emitted array therefore
 * reproduces the input bytes, followed by zero padding, when reinterpreted as
 * bytes on a host of the same endianness as this one.
 *
 * Note: for len == 0 this emits an empty initializer list, which C accepts
 * only from C23 on (or as a compiler extension).
 */
static void
print_u32_data(FILE *fp, const char *prefix, const char *arr_name,
               const uint32_t *data, size_t len)
{
   const size_t full_words = len / 4;
   const size_t tail_bytes = len % 4;
   const size_t total_words = full_words + (tail_bytes != 0);

   fprintf(fp, "static const uint32_t %s_%s[] = {", prefix, arr_name);

   for (size_t i = 0; i < total_words; i++) {
      uint32_t word = 0;

      if (i < full_words) {
         word = data[i];
      } else {
         /* memcpy keeps the byte layout identical to the full words,
          * regardless of host endianness.
          */
         memcpy(&word, (const uint8_t *)data + i * 4, tail_bytes);
      }

      if (i % 4 == 0)
         fprintf(fp, "\n   ");

      fprintf(fp, " 0x%08" PRIx32 ",", word);
   }

   fprintf(fp, "\n};\n");
}
302 
/* Print the command-line synopsis and the option list for `exec_name` to `f`. */
static void
print_usage(const char *exec_name, FILE *f)
{
   fprintf(
      f,
      "Usage: %s [options] -- [clang args]\n"
      "Options:\n"
      "  -h, --help              Print this help.\n"
      "      --prefix <prefix>   Prefix for variable names in generated C code.\n"
      "  -o, --out <filename>    Specify the output filename.\n"
      "  -i, --in <filename>     Specify one input filename. Accepted multiple times.\n"
      "  -s, --spv <filename>    Specify the output filename for spirv.\n"
      "  -v, --verbose           Print more information during compilation.\n",
      exec_name);
}
318 
/* getopt_long value for --prefix, which has no short form. */
enum { OPT_PREFIX = 1000 };
320 
321 static uint32_t
get_module_spirv_version(const uint32_t * spirv,size_t size)322 get_module_spirv_version(const uint32_t *spirv, size_t size)
323 {
324    assert(size >= 8);
325    assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
326    return spirv[1];
327 }
328 
329 static void
set_module_spirv_version(uint32_t * spirv,size_t size,uint32_t version)330 set_module_spirv_version(uint32_t *spirv, size_t size, uint32_t version)
331 {
332    assert(size >= 8);
333    assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
334    spirv[1] = version;
335 }
336 
337 int
main(int argc,char ** argv)338 main(int argc, char **argv)
339 {
340    static struct option long_options[] = {
341       {"help", no_argument, 0, 'h'},
342       {"prefix", required_argument, 0, OPT_PREFIX},
343       {"in", required_argument, 0, 'i'},
344       {"out", required_argument, 0, 'o'},
345       {"spv", required_argument, 0, 's'},
346       {"verbose", no_argument, 0, 'v'},
347       {0, 0, 0, 0},
348    };
349 
350    char *outfile = NULL, *spv_outfile = NULL, *prefix = NULL;
351    struct util_dynarray clang_args;
352    struct util_dynarray input_files;
353    struct util_dynarray spirv_objs;
354    struct util_dynarray spirv_ptr_objs;
355 
356    void *mem_ctx = ralloc_context(NULL);
357 
358    util_dynarray_init(&clang_args, mem_ctx);
359    util_dynarray_init(&input_files, mem_ctx);
360    util_dynarray_init(&spirv_objs, mem_ctx);
361    util_dynarray_init(&spirv_ptr_objs, mem_ctx);
362 
363    int ch;
364    while ((ch = getopt_long(argc, argv, "he:p:s:i:o:v", long_options, NULL)) !=
365           -1) {
366       switch (ch) {
367       case 'h':
368          print_usage(argv[0], stdout);
369          return 0;
370       case 'o':
371          outfile = optarg;
372          break;
373       case 'i':
374          util_dynarray_append(&input_files, char *, optarg);
375          break;
376       case 's':
377          spv_outfile = optarg;
378          break;
379       case OPT_PREFIX:
380          prefix = optarg;
381          break;
382       default:
383          fprintf(stderr, "Unrecognized option \"%s\".\n", optarg);
384          print_usage(argv[0], stderr);
385          return 1;
386       }
387    }
388 
389    for (int i = optind; i < argc; i++) {
390       util_dynarray_append(&clang_args, char *, argv[i]);
391    }
392 
393    if (util_dynarray_num_elements(&input_files, char *) == 0) {
394       fprintf(stderr, "No input file(s).\n");
395       print_usage(argv[0], stderr);
396       return -1;
397    }
398 
399    if (prefix == NULL) {
400       fprintf(stderr, "No prefix specified.\n");
401       print_usage(argv[0], stderr);
402       return -1;
403    }
404 
405    struct clc_logger logger = {
406       .error = msg_callback,
407       .warning = msg_callback,
408    };
409 
410    util_dynarray_foreach(&input_files, char *, infile) {
411       int fd = open(*infile, O_RDONLY);
412       if (fd < 0) {
413          fprintf(stderr, "Failed to open %s\n", *infile);
414          ralloc_free(mem_ctx);
415          return 1;
416       }
417 
418       off_t len = lseek(fd, 0, SEEK_END);
419       const void *map = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
420       close(fd);
421       if (map == MAP_FAILED) {
422          fprintf(stderr, "Failed to mmap the file: errno=%d, %s\n", errno,
423                  strerror(errno));
424          ralloc_free(mem_ctx);
425          return 1;
426       }
427 
428       const char *allowed_spirv_extensions[] = {
429          "SPV_EXT_shader_atomic_float_add",
430          "SPV_EXT_shader_atomic_float_min_max",
431          "SPV_KHR_float_controls",
432          "SPV_INTEL_subgroups",
433          NULL,
434       };
435 
436       struct clc_compile_args clc_args = {
437          .source =
438             {
439                .name = *infile,
440                .value = map,
441             },
442          .features =
443             {
444                .fp16 = true,
445                .intel_subgroups = true,
446                .subgroups = true,
447                .subgroups_ifp = true,
448             },
449          .args = util_dynarray_begin(&clang_args),
450          .num_args = util_dynarray_num_elements(&clang_args, char *),
451          .allowed_spirv_extensions = allowed_spirv_extensions,
452       };
453 
454       struct clc_binary *spirv_out =
455          util_dynarray_grow(&spirv_objs, struct clc_binary, 1);
456 
457       if (!clc_compile_c_to_spirv(&clc_args, &logger, spirv_out)) {
458          ralloc_free(mem_ctx);
459          return 1;
460       }
461    }
462 
463    util_dynarray_foreach(&spirv_objs, struct clc_binary, p) {
464       util_dynarray_append(&spirv_ptr_objs, struct clc_binary *, p);
465    }
466 
467    /* The SPIRV-Tools linker started checking that all modules have the same
468     * version. But SPIRV-LLVM-Translator picks the lower required version for
469     * each module it compiles. So we have to iterate over all of them and set
470     * the max found to make SPIRV-Tools link our modules.
471     *
472     * TODO: This is not the correct thing to do. We need SPIRV-LLVM-Translator
473     *       to pick a given SPIRV version given to it and have all the modules
474     *       at that version. We should remove this hack when this issue is
475     *       fixed :
476     *       https://github.com/KhronosGroup/SPIRV-LLVM-Translator/issues/1445
477     */
478    uint32_t max_spirv_version = 0;
479    util_dynarray_foreach(&spirv_ptr_objs, struct clc_binary *, module) {
480       max_spirv_version =
481          MAX2(max_spirv_version,
482               get_module_spirv_version((*module)->data, (*module)->size));
483    }
484 
485    assert(max_spirv_version > 0);
486    util_dynarray_foreach(&spirv_ptr_objs, struct clc_binary *, module) {
487       set_module_spirv_version((*module)->data, (*module)->size,
488                                max_spirv_version);
489    }
490 
491    struct clc_linker_args link_args = {
492       .in_objs = util_dynarray_begin(&spirv_ptr_objs),
493       .num_in_objs =
494          util_dynarray_num_elements(&spirv_ptr_objs, struct clc_binary *),
495       .create_library = true,
496    };
497    struct clc_binary final_spirv;
498    if (!clc_link_spirv(&link_args, &logger, &final_spirv)) {
499       ralloc_free(mem_ctx);
500       return 1;
501    }
502 
503    if (spv_outfile) {
504       FILE *fp = fopen(spv_outfile, "w");
505       fwrite(final_spirv.data, final_spirv.size, 1, fp);
506       fclose(fp);
507    }
508 
509    FILE *fp = stdout;
510    if (outfile != NULL)
511       fp = fopen(outfile, "w");
512 
513    glsl_type_singleton_init_or_ref();
514 
515    fprintf(fp, "/*\n");
516    fprintf(fp, " * Copyright The Asahi Linux Contributors\n");
517    fprintf(fp, " * SPDX-License-Identifier: MIT\n");
518    fprintf(fp, " *\n");
519    fprintf(fp, " * Autogenerated file, do not edit\n");
520    fprintf(fp, " */\n");
521    fprintf(fp, " #include <stdint.h>\n");
522 
523    /* Compile SPIR-V to NIR */
524    nir_shader *nir = compile(mem_ctx, final_spirv.data, final_spirv.size);
525 
526    {
527       nir_builder b = nir_builder_init_simple_shader(
528          MESA_SHADER_COMPUTE, &agx_nir_options, "Helper shader");
529 
530       nir_function *func =
531          nir_shader_get_function_for_name(nir, "libagx_helper");
532 
533       nir_call(&b, nir_function_clone(b.shader, func));
534 
535       struct agx_shader_part compiled;
536       struct agx_shader_key key = {
537          .libagx = nir,
538          .is_helper = true,
539       };
540 
541       agx_preprocess_nir(b.shader, nir);
542       agx_compile_shader_nir(b.shader, &key, NULL, &compiled);
543 
544       print_u32_data(fp, "libagx_g13", "helper", compiled.binary,
545                      compiled.binary_size);
546       free(compiled.binary);
547       ralloc_free(b.shader);
548 
549       /* Remove the NIR function, it's compiled, we don't need it at runtime */
550       exec_node_remove(&func->node);
551    }
552 
553    spirv_library_to_nir_builder(fp, final_spirv.data, final_spirv.size / 4,
554                                 &spirv_options);
555 
556    /* Serialize NIR for embedding */
557    struct blob blob;
558    blob_init(&blob);
559    nir_serialize(&blob, nir, false /* strip */);
560    print_u32_data(fp, prefix, "nir", (const uint32_t *)blob.data, blob.size);
561    blob_finish(&blob);
562 
563    glsl_type_singleton_decref();
564 
565    if (fp != stdout)
566       fclose(fp);
567 
568    util_dynarray_foreach(&spirv_objs, struct clc_binary, p) {
569       clc_free_spirv(p);
570    }
571 
572    clc_free_spirv(&final_spirv);
573    ralloc_free(mem_ctx);
574 
575    return 0;
576 }
577