xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/tests/algebraic_tests.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Intel Corporation
3  * Copyright © 2021 Valve Corporation
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22  * DEALINGS IN THE SOFTWARE.
23  */
24 
25 #include "nir_test.h"
26 
27 namespace {
28 
29 class algebraic_test_base : public nir_test {
30 protected:
31    algebraic_test_base();
32 
33    virtual void run_pass()=0;
34 
35    void test_op(nir_op op, nir_def *src0, nir_def *src1, nir_def *src2,
36                 nir_def *src3, const char *desc);
37 
38    void test_2src_op(nir_op op, int64_t src0, int64_t src1);
39 
40    void require_one_alu(nir_op op);
41 
42    nir_variable *res_var;
43 };
44 
algebraic_test_base()45 algebraic_test_base::algebraic_test_base()
46    : nir_test::nir_test("nir_opt_algebraic_test")
47 {
48    res_var = nir_local_variable_create(b->impl, glsl_int_type(), "res");
49 }
50 
test_op(nir_op op,nir_def * src0,nir_def * src1,nir_def * src2,nir_def * src3,const char * desc)51 void algebraic_test_base::test_op(nir_op op, nir_def *src0, nir_def *src1,
52                                      nir_def *src2, nir_def *src3, const char *desc)
53 {
54    nir_def *res_deref = &nir_build_deref_var(b, res_var)->def;
55 
56    /* create optimized expression */
57    nir_intrinsic_instr *optimized_instr = nir_build_store_deref(
58       b, res_deref, nir_build_alu(b, op, src0, src1, src2, src3), 0x1);
59 
60    run_pass();
61    b->cursor = nir_after_cf_list(&b->impl->body);
62 
63    /* create reference expression */
64    nir_intrinsic_instr *ref_instr = nir_build_store_deref(
65       b, res_deref, nir_build_alu(b, op, src0, src1, src2, src3), 0x1);
66 
67    /* test equality */
68    nir_opt_constant_folding(b->shader);
69 
70    ASSERT_TRUE(nir_src_is_const(ref_instr->src[1]));
71    ASSERT_TRUE(nir_src_is_const(optimized_instr->src[1]));
72 
73    int32_t ref = nir_src_as_int(ref_instr->src[1]);
74    int32_t optimized = nir_src_as_int(optimized_instr->src[1]);
75 
76    EXPECT_EQ(ref, optimized) << "Test input: " << desc;
77 
78    /* reset shader */
79    exec_list_make_empty(&nir_start_block(b->impl)->instr_list);
80    b->cursor = nir_after_cf_list(&b->impl->body);
81 }
82 
test_2src_op(nir_op op,int64_t src0,int64_t src1)83 void algebraic_test_base::test_2src_op(nir_op op, int64_t src0, int64_t src1)
84 {
85    char desc[128];
86    snprintf(desc, sizeof(desc), "%s(%" PRId64 ", %" PRId64 ")", nir_op_infos[op].name, src0, src1);
87    test_op(op, nir_imm_int(b, src0), nir_imm_int(b, src1), NULL, NULL, desc);
88 }
89 
require_one_alu(nir_op op)90 void algebraic_test_base::require_one_alu(nir_op op)
91 {
92    unsigned count = 0;
93    nir_foreach_instr(instr, nir_start_block(b->impl)) {
94       if (instr->type == nir_instr_type_alu) {
95          ASSERT_TRUE(nir_instr_as_alu(instr)->op == op);
96          ASSERT_EQ(count, 0);
97          count++;
98       }
99    }
100 }
101 
102 class nir_opt_algebraic_test : public algebraic_test_base {
103 protected:
run_pass()104    virtual void run_pass() {
105       nir_opt_algebraic(b->shader);
106    }
107 };
108 
109 class nir_opt_idiv_const_test : public algebraic_test_base {
110 protected:
run_pass()111    virtual void run_pass() {
112       nir_opt_idiv_const(b->shader, 8);
113    }
114 };
115 
116 class nir_opt_mqsad_test : public algebraic_test_base {
117 protected:
run_pass()118    virtual void run_pass() {
119       nir_opt_mqsad(b->shader);
120    }
121 };
122 
/* umod by the constant 4: dividends 0..9 plus the largest 32-bit value. */
TEST_F(nir_opt_algebraic_test, umod_pow2_src2)
{
   for (int dividend = 0; dividend < 10; ++dividend) {
      test_2src_op(nir_op_umod, dividend, 4);
   }

   test_2src_op(nir_op_umod, UINT32_MAX, 4);
}
129 
/* imod by the constants 4, -4 and INT32_MIN: dividends -9..9, followed by
 * the extreme 32-bit dividends.
 */
TEST_F(nir_opt_algebraic_test, imod_pow2_src2)
{
   for (int dividend = -9; dividend < 10; ++dividend) {
      for (int32_t divisor : {4, -4, INT32_MIN})
         test_2src_op(nir_op_imod, dividend, divisor);
   }

   for (int32_t dividend : {INT32_MAX, INT32_MIN}) {
      for (int32_t divisor : {4, -4})
         test_2src_op(nir_op_imod, dividend, divisor);
   }

   test_2src_op(nir_op_imod, INT32_MIN, INT32_MIN);
}
143 
/* irem by the constants 4 and -4: dividends -9..9, followed by the extreme
 * 32-bit dividends.
 */
TEST_F(nir_opt_algebraic_test, irem_pow2_src2)
{
   for (int dividend = -9; dividend < 10; ++dividend) {
      for (int32_t divisor : {4, -4})
         test_2src_op(nir_op_irem, dividend, divisor);
   }

   for (int32_t dividend : {INT32_MAX, INT32_MIN}) {
      for (int32_t divisor : {4, -4})
         test_2src_op(nir_op_irem, dividend, divisor);
   }
}
155 
/* Builds a per-byte masked sum of absolute differences out of generic ALU
 * instructions and expects the optimization loop to leave a single
 * msad_4x8 behind.
 */
TEST_F(nir_opt_algebraic_test, msad)
{
   options.lower_bitfield_extract = true;
   options.has_bfe = true;
   options.has_msad = true;

   nir_variable *src0_var = nir_local_variable_create(b->impl, glsl_int_type(), "src0");
   nir_def *src0 = nir_load_var(b, src0_var);
   nir_variable *src1_var = nir_local_variable_create(b->impl, glsl_int_type(), "src1");
   nir_def *src1 = nir_load_var(b, src1_var);

   /* This mimics the sequence created by vkd3d-proton. */
   nir_def *sum = NULL;
   for (unsigned byte = 0; byte < 4; ++byte) {
      /* Extract the same 8-bit field from both sources. */
      nir_def *ref = nir_ubitfield_extract(b, src0, nir_imm_int(b, byte * 8), nir_imm_int(b, 8));
      nir_def *src = nir_ubitfield_extract(b, src1, nir_imm_int(b, byte * 8), nir_imm_int(b, 8));

      /* The difference only counts when the reference byte is non-zero. */
      nir_def *ref_is_zero = nir_ieq_imm(b, ref, 0);
      nir_def *diff = nir_isub(b, ref, src);
      nir_def *abs_diff = nir_iabs(b, diff);
      nir_def *zero = nir_imm_int(b, 0);
      nir_def *masked_diff = nir_bcsel(b, ref_is_zero, zero, abs_diff);

      sum = sum ? nir_iadd(b, sum, masked_diff) : masked_diff;
   }

   nir_store_var(b, res_var, sum, 0x1);

   /* Keep cleaning up for as long as nir_opt_algebraic reports progress. */
   bool progress = nir_opt_algebraic(b->shader);
   while (progress) {
      nir_opt_constant_folding(b->shader);
      nir_opt_dce(b->shader);
      progress = nir_opt_algebraic(b->shader);
   }

   require_one_alu(nir_op_msad_4x8);
}
188 
/* Emits four msad_4x8 instructions that share `ref`, each with a different
 * second source assembled from srcx/srcy and its own channel of `accum`,
 * then expects nir_opt_mqsad to report progress and, after cleanup, only a
 * mqsad_4x8 ALU to remain.
 */
TEST_F(nir_opt_mqsad_test, mqsad)
{
   options.lower_bitfield_extract = true;
   options.has_bfe = true;
   options.has_msad = true;
   options.has_shfr32 = true;

   nir_variable *ref_var = nir_local_variable_create(b->impl, glsl_int_type(), "ref");
   nir_def *ref = nir_load_var(b, ref_var);
   nir_variable *src_var = nir_local_variable_create(b->impl, glsl_ivec_type(2), "src");
   nir_def *src = nir_load_var(b, src_var);
   nir_variable *accum_var = nir_local_variable_create(b->impl, glsl_ivec_type(4), "accum");
   nir_def *accum = nir_load_var(b, accum_var);

   nir_def *srcx = nir_channel(b, src, 0);
   nir_def *srcy = nir_channel(b, src, 1);

   nir_def *res[4];
   for (unsigned i = 0; i < 4; i++) {
      /* Iteration 0 uses srcx unchanged; the others combine shifted pieces
       * of srcy and srcx through a bitfield_select mask.
       */
      nir_def *src1 = srcx;
      if (i == 1) {
         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xff000000), nir_ishl_imm(b, srcy, 24),
                                    nir_ushr_imm(b, srcx, 8));
      } else if (i == 2) {
         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xffff0000), nir_ishl_imm(b, srcy, 16),
                                    nir_extract_u16(b, srcx, nir_imm_int(b, 1)));
      } else if (i == 3) {
         src1 = nir_bitfield_select(b, nir_imm_int(b, 0xffffff00), nir_ishl_imm(b, srcy, 8),
                                    nir_extract_u8_imm(b, srcx, 3));
      }

      res[i] = nir_msad_4x8(b, ref, src1, nir_channel(b, accum, i));
   }

   nir_store_var(b, nir_local_variable_create(b->impl, glsl_ivec_type(4), "res"),
                 nir_vec(b, res, 4), 0xf);

   /* Keep cleaning up for as long as nir_opt_algebraic reports progress. */
   bool progress = nir_opt_algebraic(b->shader);
   while (progress) {
      nir_opt_constant_folding(b->shader);
      nir_opt_dce(b->shader);
      progress = nir_opt_algebraic(b->shader);
   }

   ASSERT_TRUE(nir_opt_mqsad(b->shader));
   nir_copy_prop(b->shader);
   nir_opt_dce(b->shader);

   require_one_alu(nir_op_mqsad_4x8);
}
239 
/* umod by each constant divisor in the list: dividends 0..40 and the twenty
 * largest 32-bit values.
 */
TEST_F(nir_opt_idiv_const_test, umod)
{
   for (uint32_t divisor : {16u, 17u, 0u, UINT32_MAX}) {
      for (int dividend = 0; dividend < 41; ++dividend) {
         test_2src_op(nir_op_umod, dividend, divisor);
      }

      for (int below_max = 0; below_max < 20; ++below_max) {
         test_2src_op(nir_op_umod, UINT32_MAX - below_max, divisor);
      }
   }
}
249 
/* imod by each constant divisor in the list: dividends -40..40, the twenty
 * smallest and the twenty largest 32-bit signed values.
 */
TEST_F(nir_opt_idiv_const_test, imod)
{
   for (int32_t divisor : {16, -16, 17, -17, 0, INT32_MIN, INT32_MAX}) {
      for (int dividend = -40; dividend < 41; ++dividend) {
         test_2src_op(nir_op_imod, dividend, divisor);
      }

      for (int above_min = 0; above_min < 20; ++above_min) {
         test_2src_op(nir_op_imod, INT32_MIN + above_min, divisor);
      }

      for (int below_max = 0; below_max < 20; ++below_max) {
         test_2src_op(nir_op_imod, INT32_MAX - below_max, divisor);
      }
   }
}
261 
/* irem by each constant divisor in the list: dividends -40..40, the twenty
 * smallest and the twenty largest 32-bit signed values.
 */
TEST_F(nir_opt_idiv_const_test, irem)
{
   for (int32_t divisor : {16, -16, 17, -17, 0, INT32_MIN, INT32_MAX}) {
      for (int dividend = -40; dividend < 41; ++dividend) {
         test_2src_op(nir_op_irem, dividend, divisor);
      }

      for (int above_min = 0; above_min < 20; ++above_min) {
         test_2src_op(nir_op_irem, INT32_MIN + above_min, divisor);
      }

      for (int below_max = 0; below_max < 20; ++below_max) {
         test_2src_op(nir_op_irem, INT32_MAX - below_max, divisor);
      }
   }
}
273 
274 }
275