# Copyright 2022 Alyssa Rosenzweig
# Copyright 2021 Collabora, Ltd.
# Copyright 2016 Intel Corporation
# SPDX-License-Identifier: MIT

import argparse
import sys
import math

a = 'a'
b = 'b'
c = 'c'
d = 'd'
e = 'e'

lower_sm5_shift = []

# Our shifts differ from SM5 for the upper bits. Mask to match the NIR
# behaviour. Because this happens as a late lowering, NIR won't optimize the
# masking back out (that happens in the main nir_opt_algebraic).
for s in [8, 16, 32, 64]:
    for shift in ["ishl", "ishr", "ushr"]:
        lower_sm5_shift += [((shift, f'a@{s}', b),
                             (shift, a, ('iand', b, s - 1)))]

lower_pack = [
    (('pack_half_2x16_split', a, b),
     ('pack_32_2x16_split', ('f2f16', a), ('f2f16', b))),

    # We don't have 8-bit ALU, so we need to lower this. But if we lower it like
    # this, we can at least coalesce the pack_32_2x16_split and only pay the
    # cost of the iors and ishl. (u2u16 of 8-bit is assumed free.)
    (('pack_32_4x8_split', a, b, c, d),
     ('pack_32_2x16_split',
      ('ior', ('u2u16', a), ('ishl', ('u2u16', b), 8)),
      ('ior', ('u2u16', c), ('ishl', ('u2u16', d), 8)))),

    (('unpack_half_2x16_split_x', a), ('f2f32', ('unpack_32_2x16_split_x', a))),
    (('unpack_half_2x16_split_y', a), ('f2f32', ('unpack_32_2x16_split_y', a))),

    (('extract_u16', 'a@32', 0), ('u2u32', ('unpack_32_2x16_split_x', a))),
    (('extract_u16', 'a@32', 1), ('u2u32', ('unpack_32_2x16_split_y', a))),
    (('extract_i16', 'a@32', 0), ('i2i32', ('unpack_32_2x16_split_x', a))),
    (('extract_i16', 'a@32', 1), ('i2i32', ('unpack_32_2x16_split_y', a))),

    # For optimizing extract->convert sequences for unpack/pack norm
    (('u2f32', ('u2u32', a)), ('u2f32', a)),
    (('i2f32', ('i2i32', a)), ('i2f32', a)),

    # Chew through some 8-bit before the backend has to deal with it
    (('f2u8', a), ('u2u8', ('f2u16', a))),
    (('f2i8', a), ('i2i8', ('f2i16', a))),

    # Based on the VIR lowering
    (('f2f16_rtz', 'a@32'),
     ('bcsel', ('flt', ('fabs', a), ('fabs', ('f2f32', ('f2f16_rtne', a)))),
      ('isub', ('f2f16_rtne', a), 1), ('f2f16_rtne', a))),

    # These are based on the lowerings from nir_opt_algebraic, but conditioned
    # on the number of bits not being constant. If the bit count is constant
    # (the happy path) we can use our native instruction instead.
    (('ibitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('bcsel', ('ieq', 0, 'bits'),
      0,
      ('ishr',
       ('ishl', 'value', ('isub', ('isub', 32, 'bits'), 'offset')),
       ('isub', 32, 'bits')))),

    (('ubitfield_extract', 'value', 'offset', 'bits(is_not_const)'),
     ('iand',
      ('ushr', 'value', 'offset'),
      ('bcsel', ('ieq', 'bits', 32),
       0xffffffff,
       ('isub', ('ishl', 1, 'bits'), 1)))),

    # Codegen depends on this trivial case being optimized out.
    (('ubitfield_extract', 'value', 'offset', 0), 0),
    (('ibitfield_extract', 'value', 'offset', 0), 0),

    # At this point, bitfield extracts are constant. We can only do constant
    # unsigned bitfield extract, so lower signed to unsigned + sign extend.
    (('ibitfield_extract', a, b, '#bits'),
     ('ishr', ('ishl', ('ubitfield_extract', a, b, 'bits'), ('isub', 32, 'bits')),
      ('isub', 32, 'bits'))),
]

lower_selects = []
for T, sizes, one in [('f', [16, 32], 1.0),
                      ('i', [8, 16, 32], 1),
                      ('b', [16, 32], -1)]:
    for size in sizes:
        lower_selects.extend([
            ((f'b2{T}{size}', ('inot', 'a@1')), ('bcsel', a, 0, one)),
            ((f'b2{T}{size}', 'a@1'), ('bcsel', a, one, 0)),
        ])
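# Illustration of what the loop above generates (a sketch of the T='f',
# size=32 case, not additional patterns):
#
#   (('b2f32', ('inot', 'a@1')), ('bcsel', a, 0, 1.0))
#   (('b2f32', 'a@1'),           ('bcsel', a, 1.0, 0))
#
# i.e. boolean-to-number conversions become ordinary selects between the two
# constants.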
# Rewriting bcsel(a || b, ...) in terms of bcsel(a, ...) and bcsel(b, ...) lets
# our compare-and-select fusion rules do a better job, assuming that a and b
# are comparisons themselves.
#
# This needs to be a separate pass that runs after lower_selects, in order to
# pick up patterns like b2f32(iand(...)).
opt_selects = [
    (('bcsel', ('ior(is_used_once)', a, b), c, d),
     ('bcsel', a, c, ('bcsel', b, c, d))),

    (('bcsel', ('iand(is_used_once)', a, b), c, d),
     ('bcsel', a, ('bcsel', b, c, d), d)),
]

# When the ior/iand is used multiple times, we can instead fuse the other way.
opt_selects.extend([
    (('iand', 'a@1', b), ('bcsel', a, b, False)),
    (('ior', 'a@1', b), ('bcsel', a, True, b)),
])

fuse_extr = []
for start in range(32):
    fuse_extr.extend([
        (('ior', ('ushr', 'a@32', start), ('ishl', 'b@32', 32 - start)),
         ('extr_agx', a, b, start, 0)),
    ])

fuse_ubfe = []
for bits in range(1, 32):
    fuse_ubfe.extend([
        (('iand', ('ushr', 'a@32', b), (1 << bits) - 1),
         ('ubitfield_extract', a, b, bits))
    ])

# (x * y) + s = (x * y) + (s << 0)
def imad(x, y, z):
    return ('imadshl_agx', x, y, z, 0)

# (x * y) - s = (x * y) - (s << 0)
def imsub(x, y, z):
    return ('imsubshl_agx', x, y, z, 0)

# x + (y << s) = (x * 1) + (y << s)
def iaddshl(x, y, s):
    return ('imadshl_agx', x, 1, y, s)

# x - (y << s) = (x * 1) - (y << s)
def isubshl(x, y, s):
    return ('imsubshl_agx', x, 1, y, s)

fuse_imad = [
    # Reassociate imul+iadd chain in order to fuse imads. This pattern comes up
    # in compute shader lowering.
    (('iadd', ('iadd(is_used_once)', ('imul(is_used_once)', a, b),
               ('imul(is_used_once)', c, d)), e),
     imad(a, b, imad(c, d, e))),

    # Fuse regular imad
    (('iadd', ('imul(is_used_once)', a, b), c), imad(a, b, c)),
    (('isub', ('imul(is_used_once)', a, b), c), imsub(a, b, c)),
]

for s in range(1, 5):
    fuse_imad += [
        # Definitions
        (('iadd', a, ('ishl(is_used_once)', b, s)), iaddshl(a, b, s)),
        (('isub', a, ('ishl(is_used_once)', b, s)), isubshl(a, b, s)),

        # ineg(x) is 0 - x
        (('ineg', ('ishl(is_used_once)', b, s)), isubshl(0, b, s)),

        # Definitions
        (imad(a, b, ('ishl(is_used_once)', c, s)), ('imadshl_agx', a, b, c, s)),
        (imsub(a, b, ('ishl(is_used_once)', c, s)), ('imsubshl_agx', a, b, c, s)),

        # a + (a << s) = a + a * (1 << s) = a * (1 + (1 << s))
        (('imul', a, 1 + (1 << s)), iaddshl(a, a, s)),

        # a - (a << s) = a - a * (1 << s) = a * (1 - (1 << s))
        (('imul', a, 1 - (1 << s)), isubshl(a, a, s)),

        # a - (a << s) = a * (1 - (1 << s)) = -(a * ((1 << s) - 1))
        (('ineg', ('imul(is_used_once)', a, (1 << s) - 1)), isubshl(a, a, s)),

        # iadd is SCIB, general shift is IC (slower)
        (('ishl', a, s), iaddshl(0, a, s)),
    ]
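# Worked example of the shift fusions above (illustrative only, not an extra
# rule): with s = 2, 1 + (1 << s) = 5, so ('imul', a, 5) matches the
# "a * (1 + (1 << s))" pattern and is rewritten to iaddshl(a, a, 2), i.e.
# ('imadshl_agx', a, 1, a, 2) = a + (a << 2), turning the constant multiply
# into a fused add-with-shift.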
# Discard lowering generates this pattern, clean it up
ixor_bcsel = [
    (('ixor', ('bcsel', a, '#b', '#c'), '#d'),
     ('bcsel', a, ('ixor', b, d), ('ixor', c, d))),
]

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-p', '--import-path', required=True)
    args = parser.parse_args()
    sys.path.insert(0, args.import_path)
    run()

def run():
    import nir_algebraic  # pylint: disable=import-error

    print('#include "agx_nir.h"')

    print(nir_algebraic.AlgebraicPass("agx_nir_lower_algebraic_late",
                                      lower_sm5_shift + lower_pack +
                                      lower_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_selects",
                                      opt_selects).render())
    print(nir_algebraic.AlgebraicPass("agx_nir_fuse_algebraic_late",
                                      fuse_extr + fuse_ubfe +
                                      fuse_imad + ixor_bcsel).render())


if __name__ == '__main__':
    main()
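# Build-integration sketch (the paths and output file name below are
# illustrative, not taken from this file): -p must point at the directory
# containing nir_algebraic.py, and the generated C pass implementations are
# written to stdout, e.g.
#
#   python3 agx_nir_algebraic.py -p src/compiler/nir > agx_nir_algebraic.c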