# Copyright 2022 Alyssa Rosenzweig
# Copyright 2021 Collabora, Ltd.
# Copyright 2016 Intel Corporation
# SPDX-License-Identifier: MIT

import argparse
import sys
import math

a = 'a'
b = 'b'
c = 'c'
d = 'd'
e = 'e'

lower_sm5_shift = []

# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
# behaviour. Because this happens as a late lowering, NIR won't optimize the
# masking back out (that happens in the main nir_opt_algebraic).
for s in [8, 16, 32, 64]:
    for shift in ["ishl", "ishr", "ushr"]:
        lower_sm5_shift += [((shift, f'a@{s}', b),
                             (shift, a, ('iand', b, s - 1)))]
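
# For illustration, the generated 32-bit ishl rule expands to
#
#    (('ishl', 'a@32', b), ('ishl', a, ('iand', b, 31)))
#
# so any out-of-range shift count is masked before it reaches the backend.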

lower_pack = [
    (('pack_half_2x16_split', a, b),
     ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),

    # We don't have 8-bit ALU, so we need to lower this. But if we lower it like
    # this, we can at least coalesce the pack_32_2x16_split and only pay the
    # cost of the iors and ishl. (u2u16 of 8-bit is assumed free.)
    (('pack_32_4x8_split', a, b, c, d),
     ('pack_32_2x16_split', ('ior', ('u2u16', a), ('ishl', ('u2u16', b), 8)),
                            ('ior', ('u2u16', c), ('ishl', ('u2u16', d), 8)))),
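    # (The two 16-bit halves above assemble to
    # (d << 24) | (c << 16) | (b << 8) | a once pack_32_2x16_split
    # concatenates them, so the lowering is exact while using only 16-bit ALU.)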

    (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
    (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),

    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),

    # For optimizing extract->convert sequences for unpack/pack norm
    (('u2f32', ('u2u32', a)), ('u2f32', a)),
    (('i2f32', ('i2i32', a)), ('i2f32', a)),

    # Chew through some 8-bit before the backend has to deal with it
    (('f2u8', a), ('u2u8', ('f2u16', a))),
    (('f2i8', a), ('i2i8', ('f2i16', a))),

    # Based on the VIR lowering
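    # (the bcsel checks whether round-to-nearest-even overshot the input's
    # magnitude; if so, decrementing the half's bit pattern steps it one ULP
    # back toward zero, which yields round-toward-zero for either sign, since
    # IEEE floats are sign-magnitude).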
    (('f2f16_rtz', 'a@32'),
     ('bcsel', ('flt', ('fabs', a), ('fabs', ('f2f32', ('f2f16_rtne', a)))),
      ('isub', ('f2f16_rtne', a), 1), ('f2f16_rtne', a))),

    # These are based on the lowerings from nir_opt_algebraic, but conditioned
    # on the number of bits not being constant. If the bit count is constant
    # (the happy path) we can use our native instruction instead.
    (('ibitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('bcsel', ('ieq', 0, 'bits'),
      0,
      ('ishr',
       ('ishl', 'value', ('isub', ('isub', 32, 'bits'), 'offset')),
       ('isub', 32, 'bits')))),

    (('ubitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('iand',
      ('ushr', 'value', 'offset'),
      ('bcsel', ('ieq', 'bits', 32),
       0xffffffff,
       ('isub', ('ishl', 1, 'bits'), 1)))),

    # Codegen depends on this trivial case being optimized out.
    (('ubitfield_extract', 'value', 'offset', 0), 0),
    (('ibitfield_extract', 'value', 'offset', 0), 0),

    # At this point, bitfield extracts are constant. We can only do constant
    # unsigned bitfield extract, so lower signed to unsigned + sign extend.
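    # (The shl/shr pair places the extracted field at the top of the 32-bit
    # word and then arithmetic-shifts it back down, replicating the field's
    # sign bit across the upper bits.)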
    (('ibitfield_extract', a, b, '#bits'),
     ('ishr', ('ishl', ('ubitfield_extract', a, b, 'bits'), ('isub', 32, 'bits')),
      ('isub', 32, 'bits'))),
]

lower_selects = []
for T, sizes, one in [('f', [16, 32], 1.0),
                      ('i', [8, 16, 32], 1),
                      ('b', [16, 32], -1)]:
    for size in sizes:
        lower_selects.extend([
            ((f'b2{T}{size}', ('inot', 'a@1')), ('bcsel', a, 0, one)),
            ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)),
        ])
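
# For instance, the f32 instance of the second pattern is
#
#    (('b2f32', 'a@1'), ('bcsel', a, 1.0, 0))
#
# so boolean-to-value conversions become selects the backend can fuse.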

# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
# our compare-and-select fusion rules do a better job, assuming that a and b
# are comparisons themselves.
#
# This needs to be a separate pass that runs after lower_selects, in order to
# pick up patterns like b2f32(iand(...)).
opt_selects = [
        (('bcsel', ('ior(is_used_once)', a, b), c, d),
         ('bcsel', a, c, ('bcsel', b, c, d))),

        (('bcsel', ('iand(is_used_once)', a, b), c, d),
         ('bcsel', a, ('bcsel', b, c, d), d)),
]

# When the ior/iand is used multiple times, we can instead fuse the other way.
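# Since a is a 1-bit boolean, iand(a, b) is b when a is true and False
# otherwise, and dually for ior, so both map directly onto bcsel.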
opt_selects.extend([
        (('iand', 'a@1', b), ('bcsel', a, b, False)),
        (('ior', 'a@1', b), ('bcsel', a, True, b)),
])

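# Fuse the funnel-shift idiom (a >> start) | (b << (32 - start)) into a single
# extr_agx which, judging by the pattern, selects a 32-bit window at bit
# offset `start` out of the 64-bit concatenation b:a.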
fuse_extr = []
for start in range(32):
    fuse_extr.extend([
        (('ior', ('ushr', 'a@32', start), ('ishl', 'b@32', 32 - start)),
         ('extr_agx', a, b, start, 0)),
    ])

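# Masking a right shift with a (1 << bits) - 1 constant is an unsigned
# bitfield extract; for bits = 8 the generated rule is
#
#    (('iand', ('ushr', 'a@32', b), 255), ('ubitfield_extract', a, b, 8))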
fuse_ubfe = []
for bits in range(1, 32):
    fuse_ubfe.extend([
        (('iand', ('ushr', 'a@32', b), (1 << bits) - 1),
         ('ubitfield_extract', a, b, bits))
    ])

# (x * y) + s = (x * y) + (s << 0)
def imad(x, y, z):
    return ('imadshl_agx', x, y, z, 0)

# (x * y) - s = (x * y) - (s << 0)
def imsub(x, y, z):
    return ('imsubshl_agx', x, y, z, 0)

# x + (y << s) = (x * 1) + (y << s)
def iaddshl(x, y, s):
    return ('imadshl_agx', x, 1, y, s)

# x - (y << s) = (x * 1) - (y << s)
def isubshl(x, y, s):
    return ('imsubshl_agx', x, 1, y, s)

fuse_imad = [
    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
    # in compute shader lowering.
    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
              ('imul(is_used_once)', c, d)), e),
     imad(a, b, imad(c, d, e))),

    # Fuse regular imad
    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
]

for s in range(1, 5):
    fuse_imad += [
        # Definitions
        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),

        # ineg(x) is 0 - x
        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),

        # Definitions
        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),

        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),

        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),

        # a - (a << s) = a * (1 - (1 << s)) = -(a * ((1 << s) - 1))
        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),

        # iadd is SCIB, general shift is IC (slower)
        (('ishl', a, s), iaddshl(0, a, s)),
    ]
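
# As a worked example, with s = 2 the multiply strength-reduction rule reads
#
#    (('imul', a, 5), ('imadshl_agx', a, 1, a, 2))
#
# i.e. a * 5 becomes a + (a << 2), a single add-with-shift instead of a
# multiply.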

# Discard lowering generates this pattern, clean it up
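# (b, c, and d are all constants, marked with '#', so the two ixors on the
# replacement side constant-fold away, leaving a bare bcsel).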
ixor_bcsel = [
    (('ixor', ('bcsel', a, '#b', '#c'), '#d'),
     ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
]

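# This script is invoked (typically by the build system) with --import-path
# pointing at the directory containing nir_algebraic.py, so that the import
# in run() resolves.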
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--import-path', required=True)
    args = parser.parse_args()
    sys.path.insert(0, args.import_path)
    run()

def run():
    import nir_algebraic  # pylint: disable=import-error

    print('#include "agx_nir.h"')

    print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                      lower_sm5_shift + lower_pack +
                                      lower_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_selects",
                                      opt_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
                                      fuse_extr + fuse_ubfe +
                                      fuse_imad + ixor_bcsel).render())


if __name__ == '__main__':
    main()