xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_bool_to_int32.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 
27 static bool
assert_ssa_def_is_not_1bit(nir_def * def,UNUSED void * unused)28 assert_ssa_def_is_not_1bit(nir_def *def, UNUSED void *unused)
29 {
30    assert(def->bit_size > 1);
31    return true;
32 }
33 
34 static bool
rewrite_1bit_ssa_def_to_32bit(nir_def * def,void * _progress)35 rewrite_1bit_ssa_def_to_32bit(nir_def *def, void *_progress)
36 {
37    bool *progress = _progress;
38    if (def->bit_size == 1) {
39       def->bit_size = 32;
40       *progress = true;
41    }
42    return true;
43 }
44 
45 static bool
lower_alu_instr(nir_alu_instr * alu)46 lower_alu_instr(nir_alu_instr *alu)
47 {
48    const nir_op_info *op_info = &nir_op_infos[alu->op];
49 
50    switch (alu->op) {
51    case nir_op_mov:
52    case nir_op_vec2:
53    case nir_op_vec3:
54    case nir_op_vec4:
55    case nir_op_vec5:
56    case nir_op_vec8:
57    case nir_op_vec16:
58    case nir_op_inot:
59    case nir_op_iand:
60    case nir_op_ior:
61    case nir_op_ixor:
62       if (alu->def.bit_size != 1)
63          return false;
64       /* These we expect to have booleans but the opcode doesn't change */
65       break;
66 
67    case nir_op_b2b32:
68    case nir_op_b2b1:
69       /* We're mutating instructions in a dominance-preserving order so our
70        * source boolean should be 32-bit by now.
71        */
72       assert(nir_src_bit_size(alu->src[0].src) == 32);
73       alu->op = nir_op_mov;
74       break;
75 
76    case nir_op_flt:
77       alu->op = nir_op_flt32;
78       break;
79    case nir_op_fge:
80       alu->op = nir_op_fge32;
81       break;
82    case nir_op_feq:
83       alu->op = nir_op_feq32;
84       break;
85    case nir_op_fneu:
86       alu->op = nir_op_fneu32;
87       break;
88    case nir_op_ilt:
89       alu->op = nir_op_ilt32;
90       break;
91    case nir_op_ige:
92       alu->op = nir_op_ige32;
93       break;
94    case nir_op_ieq:
95       alu->op = nir_op_ieq32;
96       break;
97    case nir_op_ine:
98       alu->op = nir_op_ine32;
99       break;
100    case nir_op_ult:
101       alu->op = nir_op_ult32;
102       break;
103    case nir_op_uge:
104       alu->op = nir_op_uge32;
105       break;
106 
107    case nir_op_ball_fequal2:
108       alu->op = nir_op_b32all_fequal2;
109       break;
110    case nir_op_ball_fequal3:
111       alu->op = nir_op_b32all_fequal3;
112       break;
113    case nir_op_ball_fequal4:
114       alu->op = nir_op_b32all_fequal4;
115       break;
116    case nir_op_bany_fnequal2:
117       alu->op = nir_op_b32any_fnequal2;
118       break;
119    case nir_op_bany_fnequal3:
120       alu->op = nir_op_b32any_fnequal3;
121       break;
122    case nir_op_bany_fnequal4:
123       alu->op = nir_op_b32any_fnequal4;
124       break;
125    case nir_op_ball_iequal2:
126       alu->op = nir_op_b32all_iequal2;
127       break;
128    case nir_op_ball_iequal3:
129       alu->op = nir_op_b32all_iequal3;
130       break;
131    case nir_op_ball_iequal4:
132       alu->op = nir_op_b32all_iequal4;
133       break;
134    case nir_op_bany_inequal2:
135       alu->op = nir_op_b32any_inequal2;
136       break;
137    case nir_op_bany_inequal3:
138       alu->op = nir_op_b32any_inequal3;
139       break;
140    case nir_op_bany_inequal4:
141       alu->op = nir_op_b32any_inequal4;
142       break;
143 
144    case nir_op_bcsel:
145       alu->op = nir_op_b32csel;
146       break;
147 
148    case nir_op_fisfinite:
149       alu->op = nir_op_fisfinite32;
150       break;
151 
152    default:
153       assert(alu->def.bit_size > 1);
154       for (unsigned i = 0; i < op_info->num_inputs; i++)
155          assert(alu->src[i].src.ssa->bit_size > 1);
156       return false;
157    }
158 
159    if (alu->def.bit_size == 1)
160       alu->def.bit_size = 32;
161 
162    return true;
163 }
164 
165 static bool
lower_tex_instr(nir_tex_instr * tex)166 lower_tex_instr(nir_tex_instr *tex)
167 {
168    bool progress = false;
169    rewrite_1bit_ssa_def_to_32bit(&tex->def, &progress);
170    if (tex->dest_type == nir_type_bool1) {
171       tex->dest_type = nir_type_bool32;
172       progress = true;
173    }
174    return progress;
175 }
176 
177 static bool
nir_lower_bool_to_int32_instr(UNUSED nir_builder * b,nir_instr * instr,UNUSED void * cb_data)178 nir_lower_bool_to_int32_instr(UNUSED nir_builder *b,
179                               nir_instr *instr,
180                               UNUSED void *cb_data)
181 {
182    switch (instr->type) {
183    case nir_instr_type_alu:
184       return lower_alu_instr(nir_instr_as_alu(instr));
185 
186    case nir_instr_type_load_const: {
187       nir_load_const_instr *load = nir_instr_as_load_const(instr);
188       if (load->def.bit_size == 1) {
189          nir_const_value *value = load->value;
190          for (unsigned i = 0; i < load->def.num_components; i++)
191             load->value[i].u32 = value[i].b ? NIR_TRUE : NIR_FALSE;
192          load->def.bit_size = 32;
193          return true;
194       }
195       return false;
196    }
197 
198    case nir_instr_type_intrinsic:
199    case nir_instr_type_undef:
200    case nir_instr_type_phi: {
201       bool progress = false;
202       nir_foreach_def(instr, rewrite_1bit_ssa_def_to_32bit, &progress);
203       return progress;
204    }
205 
206    case nir_instr_type_tex:
207       return lower_tex_instr(nir_instr_as_tex(instr));
208 
209    default:
210       nir_foreach_def(instr, assert_ssa_def_is_not_1bit, NULL);
211       return false;
212    }
213 }
214 
215 bool
nir_lower_bool_to_int32(nir_shader * shader)216 nir_lower_bool_to_int32(nir_shader *shader)
217 {
218    bool progress = false;
219    nir_foreach_function(func, shader) {
220       for (unsigned idx = 0; idx < func->num_params; idx++) {
221          nir_parameter *param = &func->params[idx];
222          if (param->bit_size == 1) {
223             param->bit_size = 32;
224             progress = true;
225          }
226       }
227    }
228 
229    progress |=
230       nir_shader_instructions_pass(shader, nir_lower_bool_to_int32_instr,
231                                    nir_metadata_control_flow,
232                                    NULL);
233    return progress;
234 }
235