/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_kernel.h"
#include "brw_nir.h"
#include "elk/elk_nir_options.h"
#include "intel_nir.h"

#include "nir_clc_helpers.h"
#include "compiler/nir/nir_builder.h"
#include "compiler/spirv/nir_spirv.h"
#include "compiler/spirv/spirv_info.h"
#include "dev/intel_debug.h"
#include "util/u_atomic.h"
#include "util/u_dynarray.h"

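/* Returns the compiler's shared libclc shader, loading it on first use.
 * Multiple callers may race to build it; the compare-and-swap below keeps
 * exactly one copy and frees the loser's.
 */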
static const nir_shader *
load_clc_shader(struct brw_compiler *compiler, struct disk_cache *disk_cache,
                const nir_shader_compiler_options *nir_options,
                const struct spirv_to_nir_options *spirv_options)
{
   if (compiler->clc_shader)
      return compiler->clc_shader;

   nir_shader *nir = nir_load_libclc_shader(64, disk_cache,
                                            spirv_options, nir_options,
                                            disk_cache != NULL);
   if (nir == NULL)
      return NULL;

   const nir_shader *old_nir =
      p_atomic_cmpxchg(&compiler->clc_shader, NULL, nir);
   if (old_nir == NULL) {
      /* We won the race */
      ralloc_steal(compiler, nir);
      return nir;
   } else {
      /* Someone else built the shader first */
      ralloc_free(nir);
      return old_nir;
   }
}

static nir_builder
builder_init_new_impl(nir_function *func)
{
   nir_function_impl *impl = nir_function_impl_create(func);
   return nir_builder_at(nir_before_impl(impl));
}

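/* Builds the body of a mangled OpenCL atomic builtin: parameter 0 receives
 * the result through a function_temp deref and the remaining parameters map
 * onto the sources of a deref_atomic intrinsic with the given op, data type,
 * and memory mode.
 */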
static void
implement_atomic_builtin(nir_function *func, nir_atomic_op atomic_op,
                         enum glsl_base_type data_base_type,
                         nir_variable_mode mode)
{
   nir_builder b = builder_init_new_impl(func);
   const struct glsl_type *data_type = glsl_scalar_type(data_base_type);

   unsigned p = 0;

   nir_deref_instr *ret = NULL;
   ret = nir_build_deref_cast(&b, nir_load_param(&b, p++),
                              nir_var_function_temp, data_type, 0);

   nir_intrinsic_op op = nir_intrinsic_deref_atomic;
   nir_intrinsic_instr *atomic = nir_intrinsic_instr_create(b.shader, op);
   nir_intrinsic_set_atomic_op(atomic, atomic_op);

   for (unsigned i = 0; i < nir_intrinsic_infos[op].num_srcs; i++) {
      nir_def *src = nir_load_param(&b, p++);
      if (i == 0) {
         /* The first source is our deref */
         assert(nir_intrinsic_infos[op].src_components[i] == -1);
         src = &nir_build_deref_cast(&b, src, mode, data_type, 0)->def;
      }
      atomic->src[i] = nir_src_for_ssa(src);
   }

   nir_def_init_for_type(&atomic->instr, &atomic->def, data_type);

   nir_builder_instr_insert(&b, &atomic->instr);
   nir_store_deref(&b, ret, &atomic->def, ~0);
}

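/* Builds the body of intel_sub_group_ballot(): parameter 0 receives the
 * 32-bit ballot mask through a function_temp deref, parameter 1 is the
 * condition.
 */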
static void
implement_sub_group_ballot_builtin(nir_function *func)
{
   nir_builder b = builder_init_new_impl(func);
   nir_deref_instr *ret =
      nir_build_deref_cast(&b, nir_load_param(&b, 0),
                           nir_var_function_temp, glsl_uint_type(), 0);
   nir_def *cond = nir_load_param(&b, 1);

   nir_intrinsic_instr *ballot =
      nir_intrinsic_instr_create(b.shader, nir_intrinsic_ballot);
   ballot->src[0] = nir_src_for_ssa(cond);
   ballot->num_components = 1;
   nir_def_init(&ballot->instr, &ballot->def, 1, 32);
   nir_builder_instr_insert(&b, &ballot->instr);

   nir_store_deref(&b, ret, &ballot->def, ~0);
}

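/* Provides implementations for the Intel-specific OpenCL builtins that the
 * kernel may call, matched by their mangled (or plain) names.
 */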
static bool
implement_intel_builtins(nir_shader *nir)
{
   bool progress = false;

   nir_foreach_function(func, nir) {
      if (strcmp(func->name, "_Z10atomic_minPU3AS1Vff") == 0) {
         /* float atom_min(__global float volatile *p, float val) */
         implement_atomic_builtin(func, nir_atomic_op_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS1Vff") == 0) {
         /* float atom_max(__global float volatile *p, float val) */
         implement_atomic_builtin(func, nir_atomic_op_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_global);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_minPU3AS3Vff") == 0) {
         /* float atomic_min(__shared float volatile *, float) */
         implement_atomic_builtin(func, nir_atomic_op_fmin,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "_Z10atomic_maxPU3AS3Vff") == 0) {
         /* float atomic_max(__shared float volatile *, float) */
         implement_atomic_builtin(func, nir_atomic_op_fmax,
                                  GLSL_TYPE_FLOAT, nir_var_mem_shared);
         progress = true;
      } else if (strcmp(func->name, "intel_sub_group_ballot") == 0) {
         implement_sub_group_ballot_builtin(func);
         progress = true;
      }
   }

   nir_shader_preserve_all_metadata(nir);

   return progress;
}

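/* Rewrites kernel-level intrinsics into uniform loads using this kernel's
 * uniform layout: struct brw_kernel_sysvals sits at offset 0 and the kernel
 * arguments immediately follow it.
 */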
static bool
lower_kernel_intrinsics(nir_shader *nir)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(nir);

   bool progress = false;

   unsigned kernel_sysvals_start = 0;
   unsigned kernel_arg_start = sizeof(struct brw_kernel_sysvals);
   nir->num_uniforms += kernel_arg_start;

   nir_builder b = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         switch (intrin->intrinsic) {
         case nir_intrinsic_load_kernel_input: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = intrin->num_components;
            load->src[0] = nir_src_for_ssa(nir_u2u32(&b, intrin->src[0].ssa));
            nir_intrinsic_set_base(load, kernel_arg_start);
            nir_intrinsic_set_range(load, nir->num_uniforms);
            nir_def_init(&load->instr, &load->def,
                         intrin->def.num_components,
                         intrin->def.bit_size);
            nir_builder_instr_insert(&b, &load->instr);

            nir_def_rewrite_uses(&intrin->def, &load->def);
            progress = true;
            break;
         }

         case nir_intrinsic_load_constant_base_ptr: {
            b.cursor = nir_instr_remove(&intrin->instr);
            nir_def *const_data_base_addr = nir_pack_64_2x32_split(&b,
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_LOW),
               nir_load_reloc_const_intel(&b, BRW_SHADER_RELOC_CONST_DATA_ADDR_HIGH));
            nir_def_rewrite_uses(&intrin->def, const_data_base_addr);
            progress = true;
            break;
         }

         case nir_intrinsic_load_num_workgroups: {
            b.cursor = nir_instr_remove(&intrin->instr);

            nir_intrinsic_instr *load =
               nir_intrinsic_instr_create(nir, nir_intrinsic_load_uniform);
            load->num_components = 3;
            load->src[0] = nir_src_for_ssa(nir_imm_int(&b, 0));
            nir_intrinsic_set_base(load, kernel_sysvals_start +
               offsetof(struct brw_kernel_sysvals, num_work_groups));
            nir_intrinsic_set_range(load, 3 * 4);
            nir_def_init(&load->instr, &load->def, 3, 32);
            nir_builder_instr_insert(&b, &load->instr);
            nir_def_rewrite_uses(&intrin->def, &load->def);
            progress = true;
            break;
         }

         default:
            break;
         }
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

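/* The set of SPIR-V capabilities we accept when translating OpenCL kernels. */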
static const struct spirv_capabilities spirv_caps = {
   .Addresses = true,
   .Float16 = true,
   .Float64 = true,
   .Groups = true,
   .StorageImageWriteWithoutFormat = true,
   .Int8 = true,
   .Int16 = true,
   .Int64 = true,
   .Int64Atomics = true,
   .Kernel = true,
   .Linkage = true, /* We receive a linked kernel from clc */
   .DenormFlushToZero = true,
   .DenormPreserve = true,
   .SignedZeroInfNanPreserve = true,
   .RoundingModeRTE = true,
   .RoundingModeRTZ = true,
   .GenericPointer = true,
   .GroupNonUniform = true,
   .GroupNonUniformArithmetic = true,
   .GroupNonUniformClustered = true,
   .GroupNonUniformBallot = true,
   .GroupNonUniformQuad = true,
   .GroupNonUniformShuffle = true,
   .GroupNonUniformVote = true,
   .SubgroupDispatch = true,

   .SubgroupShuffleINTEL = true,
   .SubgroupBufferBlockIOINTEL = true,
};

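/* Compiles an OpenCL kernel from SPIR-V to native code: translate with
 * spirv_to_nir, link against libclc, inline everything into the single
 * entrypoint, assign explicit layouts, lower I/O, and hand the result to
 * brw_compile_cs.  Returns true on success with the code and argument
 * metadata filled into *kernel.
 */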
bool
brw_kernel_from_spirv(struct brw_compiler *compiler,
                      struct disk_cache *disk_cache,
                      struct brw_kernel *kernel,
                      void *log_data, void *mem_ctx,
                      const uint32_t *spirv, size_t spirv_size,
                      const char *entrypoint_name,
                      char **error_str)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->nir_options[MESA_SHADER_KERNEL];

   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .capabilities = &spirv_caps,
      .printf = true,
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
   };

   spirv_options.clc_shader = load_clc_shader(compiler, disk_cache,
                                              nir_options, &spirv_options);
   if (spirv_options.clc_shader == NULL) {
      fprintf(stderr, "ERROR: libclc shader missing."
              " Consider installing the libclc package\n");
      abort();
   }

   assert(spirv_size % 4 == 0);
   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   entrypoint_name, &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, entrypoint_name);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   nir_lower_printf_options printf_opts = {
      .ptr_bit_size               = 64,
      .use_printf_base_identifier = true,
   };
   NIR_PASS_V(nir, nir_lower_printf, &printf_opts);

   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_link_shader_functions, spirv_options.clc_shader);

   /* We have to lower away local constant initializers right before we
    * inline functions.  That way they get properly initialized at the top
    * of the function and not at the top of its caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);
   NIR_PASS_V(nir, nir_copy_prop);
   NIR_PASS_V(nir, nir_opt_deref);

   /* Pick off the single entrypoint that we want */
   nir_remove_non_entrypoints(nir);

   /* Now that we've deleted all but the main function, we can go ahead and
    * lower the rest of the constant initializers.  We do this here so that
    * nir_remove_dead_variables and split_per_member_structs below see the
    * corresponding stores.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~0);

   /* LLVM loves to take advantage of the fact that vec3s in OpenCL are 16B
    * aligned and so it can just read/write them as vec4s.  This results in a
    * LOT of vec4->vec3 casts on loads and stores.  One solution to this
    * problem is to get rid of all vec3 variables.
    */
   NIR_PASS_V(nir, nir_lower_vec3_to_vec4,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global |
              nir_var_mem_constant);

   /* We assign explicit types early so that the optimizer can take advantage
    * of that information and hopefully get rid of some of our memcpys.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_uniform |
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              glsl_get_cl_type_size_align);

   struct brw_nir_compiler_opts opts = {};
   brw_preprocess_nir(compiler, nir, &opts);

   int max_arg_idx = -1;
   nir_foreach_uniform_variable(var, nir) {
      assert(var->data.location < 256);
      max_arg_idx = MAX2(max_arg_idx, var->data.location);
   }

   kernel->args_size = nir->num_uniforms;
   kernel->arg_count = max_arg_idx + 1;

   /* No bindings */
   struct brw_kernel_arg_desc *args =
      rzalloc_array(mem_ctx, struct brw_kernel_arg_desc, kernel->arg_count);
   kernel->args = args;

   nir_foreach_uniform_variable(var, nir) {
      struct brw_kernel_arg_desc arg_desc = {
         .offset = var->data.driver_location,
         .size = glsl_get_explicit_size(var->type, false),
      };
      assert(arg_desc.offset + arg_desc.size <= nir->num_uniforms);

      assert(var->data.location >= 0);
      args[var->data.location] = arg_desc;
   }

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, NULL);

   /* Lower again, this time after dead-variables to get more compact variable
    * layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (before I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics, devinfo, NULL);
   NIR_PASS_V(nir, lower_kernel_intrinsics);

   struct brw_cs_prog_key key = { };

   memset(&kernel->prog_data, 0, sizeof(kernel->prog_data));
   kernel->prog_data.base.nr_params = DIV_ROUND_UP(nir->num_uniforms, 4);

   struct brw_compile_cs_params params = {
      .base = {
         .nir = nir,
         .stats = kernel->stats,
         .log_data = log_data,
         .mem_ctx = mem_ctx,
      },
      .key = &key,
      .prog_data = &kernel->prog_data,
   };

   kernel->code = brw_compile_cs(compiler, &params);

   if (error_str)
      *error_str = params.base.error_str;

   return kernel->code != NULL;
}

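/* Finds the constant-offset scratch store whose bytes cover the given reload
 * and returns the stored SSA value.  Only a dominating store of exactly the
 * reloaded size is accepted.
 */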
static nir_def *
rebuild_value_from_store(struct util_dynarray *stores,
                         nir_def *value, unsigned read_offset)
{
   unsigned read_size = value->num_components * value->bit_size / 8;

   util_dynarray_foreach(stores, nir_intrinsic_instr *, _store) {
      nir_intrinsic_instr *store = *_store;

      unsigned write_offset = nir_src_as_uint(store->src[1]);
      unsigned write_size = nir_src_num_components(store->src[0]) *
                            nir_src_bit_size(store->src[0]) / 8;
      if (write_offset <= read_offset &&
          (write_offset + write_size) >= (read_offset + read_size)) {
         assert(nir_block_dominates(store->instr.block, value->parent_instr->block));
         assert(write_size == read_size);
         return store->src[0].ssa;
      }
   }
   unreachable("Matching scratch store not found");
}

/**
 * Remove temporary variables that are stored to scratch only to be
 * immediately reloaded, remapping each load to the stored SSA value.
 *
 * This workaround is only meant to be applied to shaders in src/intel/shaders
 * where we know there should be no issue. More complex cases might not work
 * with this approach.
 */
static bool
nir_remove_llvm17_scratch(nir_shader *nir)
{
   struct util_dynarray scratch_stores;
   void *mem_ctx = ralloc_context(NULL);

   util_dynarray_init(&scratch_stores, mem_ctx);

   nir_foreach_function_impl(func, nir) {
      nir_foreach_block(block, func) {
         nir_foreach_instr(instr, block) {
            if (instr->type != nir_instr_type_intrinsic)
               continue;

            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

            if (intrin->intrinsic != nir_intrinsic_store_scratch)
               continue;

            nir_const_value *offset = nir_src_as_const_value(intrin->src[1]);
            if (offset != NULL) {
               util_dynarray_append(&scratch_stores, nir_intrinsic_instr *, intrin);
            }
         }
      }
   }

   bool progress = false;
   if (util_dynarray_num_elements(&scratch_stores, nir_intrinsic_instr *) > 0) {
      nir_foreach_function_impl(func, nir) {
         nir_foreach_block(block, func) {
            nir_foreach_instr_safe(instr, block) {
               if (instr->type != nir_instr_type_intrinsic)
                  continue;

               nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

               if (intrin->intrinsic != nir_intrinsic_load_scratch)
                  continue;

               nir_const_value *offset = nir_src_as_const_value(intrin->src[0]);
               if (offset == NULL)
                  continue;

               nir_def_replace(&intrin->def,
                               rebuild_value_from_store(&scratch_stores, &intrin->def,
                                                        nir_src_as_uint(intrin->src[0])));

               progress = true;
            }
         }
      }
   }

   util_dynarray_foreach(&scratch_stores, nir_intrinsic_instr *, _store) {
      nir_intrinsic_instr *store = *_store;
      nir_instr_remove(&store->instr);
   }

   /* Quick sanity check */
   assert(util_dynarray_num_elements(&scratch_stores, nir_intrinsic_instr *) == 0 ||
          progress);

   ralloc_free(mem_ctx);

   return progress;
}

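/* Runs a small optimization loop so scratch offsets fold to constants, strips
 * the redundant scratch traffic LLVM 17 generates, then cleans up whatever
 * the removal exposed.
 */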
static void
cleanup_llvm17_scratch(nir_shader *nir)
{
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }

   nir_remove_llvm17_scratch(nir);

   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }
}

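/* Translates a SPIR-V library (as opposed to a single kernel entrypoint) into
 * NIR for the given Gfx version, applying the OpenCL lowering and
 * optimization loops but stopping before backend compilation.  llvm17_wa
 * enables the scratch-removal workaround above.
 */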
nir_shader *
brw_nir_from_spirv(void *mem_ctx, unsigned gfx_version, const uint32_t *spirv,
                   size_t spirv_size, bool llvm17_wa)
{
   struct spirv_to_nir_options spirv_options = {
      .environment = NIR_SPIRV_OPENCL,
      .capabilities = &spirv_caps,
      .printf = true,
      .shared_addr_format = nir_address_format_62bit_generic,
      .global_addr_format = nir_address_format_62bit_generic,
      .temp_addr_format = nir_address_format_62bit_generic,
      .constant_addr_format = nir_address_format_64bit_global,
      .create_library = true,
   };

   assert(spirv_size % 4 == 0);

   assert(gfx_version);
   const nir_shader_compiler_options *nir_options =
      gfx_version >= 9 ? &brw_scalar_nir_options
                       : &elk_scalar_nir_options;

   nir_shader *nir =
      spirv_to_nir(spirv, spirv_size / 4, NULL, 0, MESA_SHADER_KERNEL,
                   "library", &spirv_options, nir_options);
   nir_validate_shader(nir, "after spirv_to_nir");
   nir_validate_ssa_dominance(nir, "after spirv_to_nir");
   ralloc_steal(mem_ctx, nir);
   nir->info.name = ralloc_strdup(nir, "library");

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (from SPIR-V) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   nir_lower_printf_options printf_opts = {
      .ptr_bit_size               = 64,
      .use_printf_base_identifier = true,
   };
   NIR_PASS_V(nir, nir_lower_printf, &printf_opts);

   NIR_PASS_V(nir, implement_intel_builtins);
   NIR_PASS_V(nir, nir_link_shader_functions, spirv_options.clc_shader);

   /* We have to lower away local constant initializers right before we
    * inline functions.  That way they get properly initialized at the top
    * of the function and not at the top of its caller.
    */
   NIR_PASS_V(nir, nir_lower_variable_initializers, ~(nir_var_shader_temp |
                                                      nir_var_function_temp));
   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
              nir_var_mem_constant | nir_var_function_temp | nir_var_image, NULL);
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
      } while (progress);
   }

   NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
   NIR_PASS_V(nir, nir_lower_returns);
   NIR_PASS_V(nir, nir_inline_functions);

   assert(nir->scratch_size == 0);
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_function_temp,
              glsl_get_cl_type_size_align);

   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_opt_undef);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
         NIR_PASS(progress, nir, nir_split_var_copies);
         NIR_PASS(progress, nir, nir_lower_var_copies);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_algebraic);
         NIR_PASS(progress, nir, nir_opt_if, nir_opt_if_optimize_phi_true_false);
         NIR_PASS(progress, nir, nir_opt_dead_cf);
         NIR_PASS(progress, nir, nir_opt_remove_phis);
         NIR_PASS(progress, nir, nir_opt_peephole_select, 8, true, true);
         NIR_PASS(progress, nir, nir_lower_vec3_to_vec4,
                  nir_var_mem_generic | nir_var_uniform);
         NIR_PASS(progress, nir, nir_opt_memcpy);
      } while (progress);
   }

   NIR_PASS_V(nir, nir_scale_fdiv);

   NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_uniform | nir_var_mem_ubo |
              nir_var_mem_constant | nir_var_function_temp | nir_var_image, NULL);

   NIR_PASS_V(nir, nir_remove_dead_variables,
              nir_var_mem_shared | nir_var_function_temp, NULL);

   nir->scratch_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_mem_shared | nir_var_function_temp | nir_var_shader_temp |
              nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);

   /* Lower memcpy - needs to wait until types are sized */
   {
      bool progress;
      do {
         progress = false;
         NIR_PASS(progress, nir, nir_opt_memcpy);
         NIR_PASS(progress, nir, nir_copy_prop);
         NIR_PASS(progress, nir, nir_opt_copy_prop_vars);
         NIR_PASS(progress, nir, nir_opt_deref);
         NIR_PASS(progress, nir, nir_opt_dce);
         NIR_PASS(progress, nir, nir_split_var_copies);
         NIR_PASS(progress, nir, nir_lower_var_copies);
         NIR_PASS(progress, nir, nir_lower_vars_to_ssa);
         NIR_PASS(progress, nir, nir_opt_constant_folding);
         NIR_PASS(progress, nir, nir_opt_cse);
      } while (progress);
   }
   NIR_PASS_V(nir, nir_lower_memcpy);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_mem_shared | nir_var_function_temp | nir_var_shader_temp |
              nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_system_values);

   /* Hopefully we can drop this once lower_vars_to_ssa has improved to not
    * lower everything to scratch.
    */
   if (llvm17_wa)
      cleanup_llvm17_scratch(nir);

   /* Lower again, this time after dead-variables to get more compact variable
    * layouts.
    */
   nir->global_mem_size = 0;
   nir->scratch_size = 0;
   nir->info.shared_size = 0;
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              nir_var_mem_shared | nir_var_mem_global | nir_var_mem_constant,
              glsl_get_cl_type_size_align);
   if (nir->constant_data_size > 0) {
      assert(nir->constant_data == NULL);
      nir->constant_data = rzalloc_size(nir, nir->constant_data_size);
      nir_gather_explicit_io_initializers(nir, nir->constant_data,
                                          nir->constant_data_size,
                                          nir_var_mem_constant);
   }

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_constant,
              nir_address_format_64bit_global);

   NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
              nir_address_format_32bit_offset_as_64bit);

   NIR_PASS_V(nir, nir_lower_explicit_io,
              nir_var_shader_temp | nir_var_function_temp |
              nir_var_mem_shared | nir_var_mem_global,
              nir_address_format_62bit_generic);

   if (INTEL_DEBUG(DEBUG_CS)) {
      /* Re-index SSA defs so we print more sensible numbers. */
      nir_foreach_function_impl(impl, nir) {
         nir_index_ssa_defs(impl);
      }

      fprintf(stderr, "NIR (after I/O lowering) for kernel\n");
      nir_print_shader(nir, stderr);
   }

   return nir;
}