1*61046927SAndroid Build Coastguard Worker# 2*61046927SAndroid Build Coastguard Worker# Copyright (C) 2014 Intel Corporation 3*61046927SAndroid Build Coastguard Worker# 4*61046927SAndroid Build Coastguard Worker# Permission is hereby granted, free of charge, to any person obtaining a 5*61046927SAndroid Build Coastguard Worker# copy of this software and associated documentation files (the "Software"), 6*61046927SAndroid Build Coastguard Worker# to deal in the Software without restriction, including without limitation 7*61046927SAndroid Build Coastguard Worker# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8*61046927SAndroid Build Coastguard Worker# and/or sell copies of the Software, and to permit persons to whom the 9*61046927SAndroid Build Coastguard Worker# Software is furnished to do so, subject to the following conditions: 10*61046927SAndroid Build Coastguard Worker# 11*61046927SAndroid Build Coastguard Worker# The above copyright notice and this permission notice (including the next 12*61046927SAndroid Build Coastguard Worker# paragraph) shall be included in all copies or substantial portions of the 13*61046927SAndroid Build Coastguard Worker# Software. 14*61046927SAndroid Build Coastguard Worker# 15*61046927SAndroid Build Coastguard Worker# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16*61046927SAndroid Build Coastguard Worker# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17*61046927SAndroid Build Coastguard Worker# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 18*61046927SAndroid Build Coastguard Worker# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19*61046927SAndroid Build Coastguard Worker# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20*61046927SAndroid Build Coastguard Worker# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21*61046927SAndroid Build Coastguard Worker# IN THE SOFTWARE. 22*61046927SAndroid Build Coastguard Worker 23*61046927SAndroid Build Coastguard Workerimport ast 24*61046927SAndroid Build Coastguard Workerfrom collections import defaultdict 25*61046927SAndroid Build Coastguard Workerimport itertools 26*61046927SAndroid Build Coastguard Workerimport struct 27*61046927SAndroid Build Coastguard Workerimport sys 28*61046927SAndroid Build Coastguard Workerimport mako.template 29*61046927SAndroid Build Coastguard Workerimport re 30*61046927SAndroid Build Coastguard Workerimport traceback 31*61046927SAndroid Build Coastguard Worker 32*61046927SAndroid Build Coastguard Workerfrom nir_opcodes import opcodes, type_sizes 33*61046927SAndroid Build Coastguard Worker 34*61046927SAndroid Build Coastguard Worker# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 35*61046927SAndroid Build Coastguard Workernir_search_max_comm_ops = 8 36*61046927SAndroid Build Coastguard Worker 37*61046927SAndroid Build Coastguard Worker# These opcodes are only employed by nir_search. This provides a mapping from 38*61046927SAndroid Build Coastguard Worker# opcode to destination type. 
conv_opcode_types = dict(
    i2f='float',
    u2f='float',
    f2f='float',
    f2u='uint',
    f2i='int',
    u2u='uint',
    i2i='int',
    b2f='float',
    b2i='int',
    i2b='bool',
    f2b='bool',
)


def get_cond_index(conds, cond):
    """Return the index of condition *cond* in *conds*, adding it if unseen.

    New conditions are numbered densely in first-seen order: the index given
    to a new condition is the size of *conds* before it is inserted.  A falsy
    *cond* (None or the empty string) means "no condition"; it yields -1 and
    leaves *conds* untouched.
    """
    if not cond:
        return -1
    return conds.setdefault(cond, len(conds))


def get_c_opcode(op):
    """Return the C enumerator name for opcode *op*.

    Opcodes listed in conv_opcode_types get the ``nir_search_op_`` prefix;
    every other opcode gets the plain ``nir_op_`` prefix.
    """
    prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
    return prefix + op


# Splits a type string such as "float32" into its base type and bit size.
# Both groups are optional, so match() on any string succeeds.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
def type_bits(type_str):
    """Return the bit size written in a type string such as "float32".

    *type_str* must begin with one of the base types int/uint/bool/float.  An
    optional decimal number after it is the bit size.  Returns 0 when no bit
    size is given (e.g. a plain "float").
    """
    m = _type_re.match(type_str)
    # Every group in _type_re is optional, so match() always succeeds; the
    # real validation is that a base type was present.
    assert m.group('type'), "Malformed type string \"{}\".".format(type_str)

    if m.group('bits') is None:
        return 0
    return int(m.group('bits'))


# Represents a set of variables, each with a unique id
class VarSet(object):
    """Map variable names to integer ids, allocated in first-lookup order.

    While unlocked, looking up an unknown name allocates the next id.  After
    lock() has been called, looking up an unknown name fails with an
    "Unknown replacement variable" assertion.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid allocating ids for names not already in the set."""
        self.immutable = True


class SearchExpression(object):
    """A raw pattern expression: an opcode string followed by its sources.

    Wraps the tuple form ``(opcode, src0, src1, ...)`` so that an extra flag,
    ``ignore_exact``, can be attached to the pattern.  It defaults to False
    here and is presumably set by a helper elsewhere in the file.
    """

    def __init__(self, expr):
        self.opcode = expr[0]
        self.sources = expr[1:]
        self.ignore_exact = False

    @staticmethod
    def create(val):
        """Return *val* as a SearchExpression, wrapping it if it is a tuple."""
        if isinstance(val, tuple):
            return SearchExpression(val)
        assert isinstance(val, SearchExpression)
        return val

    def __repr__(self):
        items = [self.opcode, *self.sources]
        if self.ignore_exact:
            items.append('ignore_exact')
        return repr(tuple(items))


class Value(object):
    """Base class for the nodes of a search or replace pattern.

    Subclasses (Constant, Variable, Expression) assign ``_bit_size`` in their
    constructors.  It is None (unknown), a concrete int bit size, or another
    Value whose bit-size class this one belongs to (see get_bit_size and
    set_bit_size).
    """

    @staticmethod
    def create(val, name_base, varset, algebraic_pass):
        """Build the Value subclass matching the Python type of *val*.

        * tuple or SearchExpression -> new Expression
        * Expression                -> returned unchanged
        * str (or UTF-8 bytes)      -> new Variable
        * bool, float or int        -> new Constant
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple) or isinstance(val, SearchExpression):
            return Expression(val, name_base, varset, algebraic_pass)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, str):
            return Variable(val, name_base, varset, algebraic_pass)
        elif isinstance(val, (bool, float, int)):
            return Constant(val, name_base)

        # Anything else is a malformed pattern.  Fail here, naming the bad
        # value, instead of returning None and crashing later far from the
        # offending rule.
        assert False, "Unsupported value in algebraic pattern: {!r}".format(val)

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Follow the chain of _bit_size links until reaching either a concrete
        # int or a Value that is the root of its class (_bit_size is None).
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point directly at the root for later lookups.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() return
        before calling this, or just "other" if it's a concrete bit-size. This is
        the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        """C enumerator naming this value's kind, e.g. nir_search_value_constant."""
        return "nir_search_value_" + self.type_str

    @property
    def c_bit_size(self):
        """Bit size as encoded in the generated C table.

        A concrete size is emitted as-is.  A size tied to a variable with index
        i is emitted as -(i + 1).  Anything else is emitted as 0.
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            # case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    __template = mako.template.Template("""   {  .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${'true' if val.nsz else 'false'},
      ${'true' if val.nnan else 'false'},
      ${'true' if val.ninf else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

    def render(self, cache):
        """Render this value's C struct initializer, deduplicated via *cache*.

        *cache* maps already-rendered initializer text to its array index
        (as a string) and also holds an integer "next_index" counter.  Sets
        self.array_index in both cases.  Returns the initializer text for a
        new entry, or a C comment noting the remap for a duplicate.
        """
        struct_init = self.__template.render(val=self,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if struct_init in cache:
            # If it's in the cache, register a name remap in the cache and render
            # only a comment saying it's been remapped
            self.array_index = cache[struct_init]
            return "   /* {} -> {} in the cache */\n".format(self.name,
                                                          cache[struct_init])
        else:
            self.array_index = str(cache["next_index"])
            cache[struct_init] = self.array_index
            cache["next_index"] += 1
            return struct_init


_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant in a pattern: bool, int or float.

    It may also be built from a string such as "1.0@32".  The part before '@'
    is parsed with ast.literal_eval, and the optional number after it is the
    bit size.
    """

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, str):
            m = _constant_re.match(val)
            assert m, "Malformed constant \"{}\".".format(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        # Boolean constants are forced to a 1-bit size.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 1, \
                "Boolean constant \"{}\" must be 1 bit.".format(val)
            self._bit_size = 1

    def hex(self):
        """Return the C initializer for this constant's value.

        Bools become NIR_TRUE / NIR_FALSE.  Ints become their 64-bit
        two's-complement bit pattern, and floats their IEEE-754 double bit
        pattern, each as a hex literal with a ``ull`` suffix.
        """
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, int):
            # Explicitly sign-extend negative integers to 64-bit, ensuring correct
            # handling of -INT32_MIN which is not representable in 32-bit.
            if self.value < 0:
                return hex(struct.unpack('Q', struct.pack('q', self.value))[0]) + 'ull'
            else:
                return hex(self.value) + 'ull'
        elif isinstance(self.value, float):
            return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) + 'ull'
        else:
            assert False

    def type(self):
        """Return the nir_type_* enumerator for this constant's Python type."""
        if isinstance(self.value, bool):
            return "nir_type_bool"
        elif isinstance(self.value, int):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two constants are equivalent.

        This check is much weaker than equality. One generally cannot be
        used in place of the other. Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.value == other.value


# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")


class Variable(Value):
    """A named pattern variable, e.g. "#a@32(is_pos).xy".

    Optional parts are:
    * a leading '#', which sets is_constant;
    * '@type' and/or '@bits', the required type and bit size;
    * a parenthesized condition, registered in algebraic_pass.variable_cond;
    * a '.' swizzle.
    """

    # Component letter -> component index for swizzle().  xyzw alias abcd.
    _SWIZZLE_COMPONENTS = {'x': 0, 'y': 1, 'z': 2, 'w': 3,
                           'a': 0, 'b': 1, 'c': 2, 'd': 3,
                           'e': 4, 'f': 5, 'g': 6, 'h': 7,
                           'i': 8, 'j': 9, 'k': 10, 'l': 11,
                           'm': 12, 'n': 13, 'o': 14, 'p': 15}

    def __init__(self, val, name, varset, algebraic_pass):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

        self.var_name = m.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant.  If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more flexible.
        assert self.var_name.isalpha(), \
            "Variable name \"{}\" must be purely alphabetic.".format(self.var_name)
        assert self.var_name not in ('True', 'False'), \
            "Variable name \"{}\" looks like a quoted boolean constant.".format(self.var_name)

        self.is_constant = m.group('const') is not None
        self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.swiz = m.group('swiz')

        # Booleans default to 1 bit; an explicit size must be a legal bool size.
        if self.required_type == 'bool':
            if self._bit_size is not None:
                assert self._bit_size in type_sizes(self.required_type), \
                    "Invalid bool bit size in \"{}\".".format(val)
            else:
                self._bit_size = 1

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* enumerator for the required type, or None.

        int and uint both map to nir_type_int.
        """
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two variables are equivalent.

        This check is much weaker than equality. One generally cannot be
        used in place of the other. Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.index == other.index

    def swizzle(self):
        """Return the swizzle as a C initializer list of component indices.

        Without an explicit swizzle this is the identity over 16 components.
        """
        if self.swiz is not None:
            return '{' + ', '.join([str(self._SWIZZLE_COMPONENTS[c])
                                    for c in self.swiz[1:]]) + '}'
        return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
362*61046927SAndroid Build Coastguard Worker 363*61046927SAndroid Build Coastguard Worker_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" 364*61046927SAndroid Build Coastguard Worker r"(?P<cond>\([^\)]+\))?") 365*61046927SAndroid Build Coastguard Worker 366*61046927SAndroid Build Coastguard Workerclass Expression(Value): 367*61046927SAndroid Build Coastguard Worker def __init__(self, expr, name_base, varset, algebraic_pass): 368*61046927SAndroid Build Coastguard Worker Value.__init__(self, expr, name_base, "expression") 369*61046927SAndroid Build Coastguard Worker 370*61046927SAndroid Build Coastguard Worker expr = SearchExpression.create(expr) 371*61046927SAndroid Build Coastguard Worker 372*61046927SAndroid Build Coastguard Worker m = _opcode_re.match(expr.opcode) 373*61046927SAndroid Build Coastguard Worker assert m and m.group('opcode') is not None 374*61046927SAndroid Build Coastguard Worker 375*61046927SAndroid Build Coastguard Worker self.opcode = m.group('opcode') 376*61046927SAndroid Build Coastguard Worker self._bit_size = int(m.group('bits')) if m.group('bits') else None 377*61046927SAndroid Build Coastguard Worker self.inexact = m.group('inexact') is not None 378*61046927SAndroid Build Coastguard Worker self.exact = m.group('exact') is not None 379*61046927SAndroid Build Coastguard Worker self.ignore_exact = expr.ignore_exact 380*61046927SAndroid Build Coastguard Worker self.cond = m.group('cond') 381*61046927SAndroid Build Coastguard Worker 382*61046927SAndroid Build Coastguard Worker assert not self.inexact or not self.exact, \ 383*61046927SAndroid Build Coastguard Worker 'Expression cannot be both exact and inexact.' 384*61046927SAndroid Build Coastguard Worker 385*61046927SAndroid Build Coastguard Worker # "many-comm-expr" isn't really a condition. 
It's notification to the 386*61046927SAndroid Build Coastguard Worker # generator that this pattern is known to have too many commutative 387*61046927SAndroid Build Coastguard Worker # expressions, and an error should not be generated for this case. 388*61046927SAndroid Build Coastguard Worker # nsz, nnan and ninf are special conditions, so we treat them specially too. 389*61046927SAndroid Build Coastguard Worker cond = {k: True for k in self.cond[1:-1].split(",")} if self.cond else {} 390*61046927SAndroid Build Coastguard Worker self.many_commutative_expressions = cond.pop('many-comm-expr', False) 391*61046927SAndroid Build Coastguard Worker self.nsz = cond.pop('nsz', False) 392*61046927SAndroid Build Coastguard Worker self.nnan = cond.pop('nnan', False) 393*61046927SAndroid Build Coastguard Worker self.ninf = cond.pop('ninf', False) 394*61046927SAndroid Build Coastguard Worker 395*61046927SAndroid Build Coastguard Worker assert len(cond) <= 1 396*61046927SAndroid Build Coastguard Worker self.cond = cond.popitem()[0] if cond else None 397*61046927SAndroid Build Coastguard Worker 398*61046927SAndroid Build Coastguard Worker # Deduplicate references to the condition functions for the expressions 399*61046927SAndroid Build Coastguard Worker # and save the index for the order they were added. 
400*61046927SAndroid Build Coastguard Worker self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond) 401*61046927SAndroid Build Coastguard Worker 402*61046927SAndroid Build Coastguard Worker self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass) 403*61046927SAndroid Build Coastguard Worker for (i, src) in enumerate(expr.sources) ] 404*61046927SAndroid Build Coastguard Worker 405*61046927SAndroid Build Coastguard Worker # nir_search_expression::srcs is hard-coded to 4 406*61046927SAndroid Build Coastguard Worker assert len(self.sources) <= 4 407*61046927SAndroid Build Coastguard Worker 408*61046927SAndroid Build Coastguard Worker if self.opcode in conv_opcode_types: 409*61046927SAndroid Build Coastguard Worker assert self._bit_size is None, \ 410*61046927SAndroid Build Coastguard Worker 'Expression cannot use an unsized conversion opcode with ' \ 411*61046927SAndroid Build Coastguard Worker 'an explicit size; that\'s silly.' 412*61046927SAndroid Build Coastguard Worker 413*61046927SAndroid Build Coastguard Worker self.__index_comm_exprs(0) 414*61046927SAndroid Build Coastguard Worker 415*61046927SAndroid Build Coastguard Worker def equivalent(self, other): 416*61046927SAndroid Build Coastguard Worker """Check that two variables are equivalent. 417*61046927SAndroid Build Coastguard Worker 418*61046927SAndroid Build Coastguard Worker This is check is much weaker than equality. One generally cannot be 419*61046927SAndroid Build Coastguard Worker used in place of the other. Using this implementation for the __eq__ 420*61046927SAndroid Build Coastguard Worker will break BitSizeValidator. 421*61046927SAndroid Build Coastguard Worker 422*61046927SAndroid Build Coastguard Worker This implementation does not check for equivalence due to commutativity, 423*61046927SAndroid Build Coastguard Worker but it could. 
424*61046927SAndroid Build Coastguard Worker 425*61046927SAndroid Build Coastguard Worker """ 426*61046927SAndroid Build Coastguard Worker if not isinstance(other, type(self)): 427*61046927SAndroid Build Coastguard Worker return False 428*61046927SAndroid Build Coastguard Worker 429*61046927SAndroid Build Coastguard Worker if len(self.sources) != len(other.sources): 430*61046927SAndroid Build Coastguard Worker return False 431*61046927SAndroid Build Coastguard Worker 432*61046927SAndroid Build Coastguard Worker if self.opcode != other.opcode: 433*61046927SAndroid Build Coastguard Worker return False 434*61046927SAndroid Build Coastguard Worker 435*61046927SAndroid Build Coastguard Worker return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) 436*61046927SAndroid Build Coastguard Worker 437*61046927SAndroid Build Coastguard Worker def __index_comm_exprs(self, base_idx): 438*61046927SAndroid Build Coastguard Worker """Recursively count and index commutative expressions 439*61046927SAndroid Build Coastguard Worker """ 440*61046927SAndroid Build Coastguard Worker self.comm_exprs = 0 441*61046927SAndroid Build Coastguard Worker 442*61046927SAndroid Build Coastguard Worker # A note about the explicit "len(self.sources)" check. The list of 443*61046927SAndroid Build Coastguard Worker # sources comes from user input, and that input might be bad. Check 444*61046927SAndroid Build Coastguard Worker # that the expected second source exists before accessing it. Without 445*61046927SAndroid Build Coastguard Worker # this check, a unit test that does "('iadd', 'a')" will crash. 
446*61046927SAndroid Build Coastguard Worker if self.opcode not in conv_opcode_types and \ 447*61046927SAndroid Build Coastguard Worker "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ 448*61046927SAndroid Build Coastguard Worker len(self.sources) >= 2 and \ 449*61046927SAndroid Build Coastguard Worker not self.sources[0].equivalent(self.sources[1]): 450*61046927SAndroid Build Coastguard Worker self.comm_expr_idx = base_idx 451*61046927SAndroid Build Coastguard Worker self.comm_exprs += 1 452*61046927SAndroid Build Coastguard Worker else: 453*61046927SAndroid Build Coastguard Worker self.comm_expr_idx = -1 454*61046927SAndroid Build Coastguard Worker 455*61046927SAndroid Build Coastguard Worker for s in self.sources: 456*61046927SAndroid Build Coastguard Worker if isinstance(s, Expression): 457*61046927SAndroid Build Coastguard Worker s.__index_comm_exprs(base_idx + self.comm_exprs) 458*61046927SAndroid Build Coastguard Worker self.comm_exprs += s.comm_exprs 459*61046927SAndroid Build Coastguard Worker 460*61046927SAndroid Build Coastguard Worker return self.comm_exprs 461*61046927SAndroid Build Coastguard Worker 462*61046927SAndroid Build Coastguard Worker def c_opcode(self): 463*61046927SAndroid Build Coastguard Worker return get_c_opcode(self.opcode) 464*61046927SAndroid Build Coastguard Worker 465*61046927SAndroid Build Coastguard Worker def render(self, cache): 466*61046927SAndroid Build Coastguard Worker srcs = "".join(src.render(cache) for src in self.sources) 467*61046927SAndroid Build Coastguard Worker return srcs + super(Expression, self).render(cache) 468*61046927SAndroid Build Coastguard Worker 469*61046927SAndroid Build Coastguard Workerclass BitSizeValidator(object): 470*61046927SAndroid Build Coastguard Worker """A class for validating bit sizes of expressions. 
471*61046927SAndroid Build Coastguard Worker 472*61046927SAndroid Build Coastguard Worker NIR supports multiple bit-sizes on expressions in order to handle things 473*61046927SAndroid Build Coastguard Worker such as fp64. The source and destination of every ALU operation is 474*61046927SAndroid Build Coastguard Worker assigned a type and that type may or may not specify a bit size. Sources 475*61046927SAndroid Build Coastguard Worker and destinations whose type does not specify a bit size are considered 476*61046927SAndroid Build Coastguard Worker "unsized" and automatically take on the bit size of the corresponding 477*61046927SAndroid Build Coastguard Worker register or SSA value. NIR has two simple rules for bit sizes that are 478*61046927SAndroid Build Coastguard Worker validated by nir_validator: 479*61046927SAndroid Build Coastguard Worker 480*61046927SAndroid Build Coastguard Worker 1) A given SSA def or register has a single bit size that is respected by 481*61046927SAndroid Build Coastguard Worker everything that reads from it or writes to it. 482*61046927SAndroid Build Coastguard Worker 483*61046927SAndroid Build Coastguard Worker 2) The bit sizes of all unsized inputs/outputs on any given ALU 484*61046927SAndroid Build Coastguard Worker instruction must match. They need not match the sized inputs or 485*61046927SAndroid Build Coastguard Worker outputs but they must match each other. 486*61046927SAndroid Build Coastguard Worker 487*61046927SAndroid Build Coastguard Worker In order to keep nir_algebraic relatively simple and easy-to-use, 488*61046927SAndroid Build Coastguard Worker nir_search supports a type of bit-size inference based on the two rules 489*61046927SAndroid Build Coastguard Worker above. This is similar to type inference in many common programming 490*61046927SAndroid Build Coastguard Worker languages. 
If, for instance, you are constructing an add operation and you 491*61046927SAndroid Build Coastguard Worker know the second source is 16-bit, then you know that the other source and 492*61046927SAndroid Build Coastguard Worker the destination must also be 16-bit. There are, however, cases where this 493*61046927SAndroid Build Coastguard Worker inference can be ambiguous or contradictory. Consider, for instance, the 494*61046927SAndroid Build Coastguard Worker following transformation: 495*61046927SAndroid Build Coastguard Worker 496*61046927SAndroid Build Coastguard Worker (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 497*61046927SAndroid Build Coastguard Worker 498*61046927SAndroid Build Coastguard Worker This transformation can potentially cause a problem because usub_borrow is 499*61046927SAndroid Build Coastguard Worker well-defined for any bit-size of integer. However, b2i always generates a 500*61046927SAndroid Build Coastguard Worker 32-bit result so it could end up replacing a 64-bit expression with one 501*61046927SAndroid Build Coastguard Worker that takes two 64-bit values and produces a 32-bit value. As another 502*61046927SAndroid Build Coastguard Worker example, consider this expression: 503*61046927SAndroid Build Coastguard Worker 504*61046927SAndroid Build Coastguard Worker (('bcsel', a, b, 0), ('iand', a, b)) 505*61046927SAndroid Build Coastguard Worker 506*61046927SAndroid Build Coastguard Worker In this case, in the search expression a must be 32-bit but b can 507*61046927SAndroid Build Coastguard Worker potentially have any bit size. 
If we had a 64-bit b value, we would end up 508*61046927SAndroid Build Coastguard Worker trying to and a 32-bit value with a 64-bit value which would be invalid 509*61046927SAndroid Build Coastguard Worker 510*61046927SAndroid Build Coastguard Worker This class solves that problem by providing a validation layer that proves 511*61046927SAndroid Build Coastguard Worker that a given search-and-replace operation is 100% well-defined before we 512*61046927SAndroid Build Coastguard Worker generate any code. This ensures that bugs are caught at compile time 513*61046927SAndroid Build Coastguard Worker rather than at run time. 514*61046927SAndroid Build Coastguard Worker 515*61046927SAndroid Build Coastguard Worker Each value maintains a "bit-size class", which is either an actual bit size 516*61046927SAndroid Build Coastguard Worker or an equivalence class with other values that must have the same bit size. 517*61046927SAndroid Build Coastguard Worker The validator works by combining bit-size classes with each other according 518*61046927SAndroid Build Coastguard Worker to the NIR rules outlined above, checking that there are no inconsistencies. 519*61046927SAndroid Build Coastguard Worker When doing this for the replacement expression, we make sure to never change 520*61046927SAndroid Build Coastguard Worker the equivalence class of any of the search values. We could make the example 521*61046927SAndroid Build Coastguard Worker transforms above work by doing some extra run-time checking of the search 522*61046927SAndroid Build Coastguard Worker expression, but we make the user specify those constraints themselves, to 523*61046927SAndroid Build Coastguard Worker avoid any surprises. 
Since the replacement bitsizes can only be connected to 524*61046927SAndroid Build Coastguard Worker the source bitsize via variables (variables must have the same bitsize in 525*61046927SAndroid Build Coastguard Worker the source and replacment expressions) or the roots of the expression (the 526*61046927SAndroid Build Coastguard Worker replacement expression must produce the same bit size as the search 527*61046927SAndroid Build Coastguard Worker expression), we prevent merging a variable with anything when processing the 528*61046927SAndroid Build Coastguard Worker replacement expression, or specializing the search bitsize 529*61046927SAndroid Build Coastguard Worker with anything. The former prevents 530*61046927SAndroid Build Coastguard Worker 531*61046927SAndroid Build Coastguard Worker (('bcsel', a, b, 0), ('iand', a, b)) 532*61046927SAndroid Build Coastguard Worker 533*61046927SAndroid Build Coastguard Worker from being allowed, since we'd have to merge the bitsizes for a and b due to 534*61046927SAndroid Build Coastguard Worker the 'iand', while the latter prevents 535*61046927SAndroid Build Coastguard Worker 536*61046927SAndroid Build Coastguard Worker (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 537*61046927SAndroid Build Coastguard Worker 538*61046927SAndroid Build Coastguard Worker from being allowed, since the search expression has the bit size of a and b, 539*61046927SAndroid Build Coastguard Worker which can't be specialized to 32 which is the bitsize of the replace 540*61046927SAndroid Build Coastguard Worker expression. It also prevents something like: 541*61046927SAndroid Build Coastguard Worker 542*61046927SAndroid Build Coastguard Worker (('b2i', ('i2b', a)), ('ineq', a, 0)) 543*61046927SAndroid Build Coastguard Worker 544*61046927SAndroid Build Coastguard Worker since the bitsize of 'b2i', which can be anything, can't be specialized to 545*61046927SAndroid Build Coastguard Worker the bitsize of a. 
546*61046927SAndroid Build Coastguard Worker 547*61046927SAndroid Build Coastguard Worker After doing all this, we check that every subexpression of the replacement 548*61046927SAndroid Build Coastguard Worker was assigned a constant bitsize, the bitsize of a variable, or the bitsize 549*61046927SAndroid Build Coastguard Worker of the search expresssion, since those are the things that are known when 550*61046927SAndroid Build Coastguard Worker constructing the replacement expresssion. Finally, we record the bitsize 551*61046927SAndroid Build Coastguard Worker needed in nir_search_value so that we know what to do when building the 552*61046927SAndroid Build Coastguard Worker replacement expression. 553*61046927SAndroid Build Coastguard Worker """ 554*61046927SAndroid Build Coastguard Worker 555*61046927SAndroid Build Coastguard Worker def __init__(self, varset): 556*61046927SAndroid Build Coastguard Worker self._var_classes = [None] * len(varset.names) 557*61046927SAndroid Build Coastguard Worker 558*61046927SAndroid Build Coastguard Worker def compare_bitsizes(self, a, b): 559*61046927SAndroid Build Coastguard Worker """Determines which bitsize class is a specialization of the other, or 560*61046927SAndroid Build Coastguard Worker whether neither is. When we merge two different bitsizes, the 561*61046927SAndroid Build Coastguard Worker less-specialized bitsize always points to the more-specialized one, so 562*61046927SAndroid Build Coastguard Worker that calling get_bit_size() always gets you the most specialized bitsize. 563*61046927SAndroid Build Coastguard Worker The specialization partial order is given by: 564*61046927SAndroid Build Coastguard Worker - Physical bitsizes are always the most specialized, and a different 565*61046927SAndroid Build Coastguard Worker bitsize can never specialize another. 
566*61046927SAndroid Build Coastguard Worker - In the search expression, variables can always be specialized to each 567*61046927SAndroid Build Coastguard Worker other and to physical bitsizes. In the replace expression, we disallow 568*61046927SAndroid Build Coastguard Worker this to avoid adding extra constraints to the search expression that 569*61046927SAndroid Build Coastguard Worker the user didn't specify. 570*61046927SAndroid Build Coastguard Worker - Expressions and constants without a bitsize can always be specialized to 571*61046927SAndroid Build Coastguard Worker each other and variables, but not the other way around. 572*61046927SAndroid Build Coastguard Worker 573*61046927SAndroid Build Coastguard Worker We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, 574*61046927SAndroid Build Coastguard Worker and None if they are not comparable (neither a <= b nor b <= a). 575*61046927SAndroid Build Coastguard Worker """ 576*61046927SAndroid Build Coastguard Worker if isinstance(a, int): 577*61046927SAndroid Build Coastguard Worker if isinstance(b, int): 578*61046927SAndroid Build Coastguard Worker return 0 if a == b else None 579*61046927SAndroid Build Coastguard Worker elif isinstance(b, Variable): 580*61046927SAndroid Build Coastguard Worker return -1 if self.is_search else None 581*61046927SAndroid Build Coastguard Worker else: 582*61046927SAndroid Build Coastguard Worker return -1 583*61046927SAndroid Build Coastguard Worker elif isinstance(a, Variable): 584*61046927SAndroid Build Coastguard Worker if isinstance(b, int): 585*61046927SAndroid Build Coastguard Worker return 1 if self.is_search else None 586*61046927SAndroid Build Coastguard Worker elif isinstance(b, Variable): 587*61046927SAndroid Build Coastguard Worker return 0 if self.is_search or a.index == b.index else None 588*61046927SAndroid Build Coastguard Worker else: 589*61046927SAndroid Build Coastguard Worker return -1 590*61046927SAndroid Build Coastguard Worker else: 
591*61046927SAndroid Build Coastguard Worker if isinstance(b, int): 592*61046927SAndroid Build Coastguard Worker return 1 593*61046927SAndroid Build Coastguard Worker elif isinstance(b, Variable): 594*61046927SAndroid Build Coastguard Worker return 1 595*61046927SAndroid Build Coastguard Worker else: 596*61046927SAndroid Build Coastguard Worker return 0 597*61046927SAndroid Build Coastguard Worker 598*61046927SAndroid Build Coastguard Worker def unify_bit_size(self, a, b, error_msg): 599*61046927SAndroid Build Coastguard Worker """Record that a must have the same bit-size as b. If both 600*61046927SAndroid Build Coastguard Worker have been assigned conflicting physical bit-sizes, call "error_msg" with 601*61046927SAndroid Build Coastguard Worker the bit-sizes of self and other to get a message and raise an error. 602*61046927SAndroid Build Coastguard Worker In the replace expression, disallow merging variables with other 603*61046927SAndroid Build Coastguard Worker variables and physical bit-sizes as well. 
604*61046927SAndroid Build Coastguard Worker """ 605*61046927SAndroid Build Coastguard Worker a_bit_size = a.get_bit_size() 606*61046927SAndroid Build Coastguard Worker b_bit_size = b if isinstance(b, int) else b.get_bit_size() 607*61046927SAndroid Build Coastguard Worker 608*61046927SAndroid Build Coastguard Worker cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) 609*61046927SAndroid Build Coastguard Worker 610*61046927SAndroid Build Coastguard Worker assert cmp_result is not None, \ 611*61046927SAndroid Build Coastguard Worker error_msg(a_bit_size, b_bit_size) 612*61046927SAndroid Build Coastguard Worker 613*61046927SAndroid Build Coastguard Worker if cmp_result < 0: 614*61046927SAndroid Build Coastguard Worker b_bit_size.set_bit_size(a) 615*61046927SAndroid Build Coastguard Worker elif not isinstance(a_bit_size, int): 616*61046927SAndroid Build Coastguard Worker a_bit_size.set_bit_size(b) 617*61046927SAndroid Build Coastguard Worker 618*61046927SAndroid Build Coastguard Worker def merge_variables(self, val): 619*61046927SAndroid Build Coastguard Worker """Perform the first part of type inference by merging all the different 620*61046927SAndroid Build Coastguard Worker uses of the same variable. We always do this as if we're in the search 621*61046927SAndroid Build Coastguard Worker expression, even if we're actually not, since otherwise we'd get errors 622*61046927SAndroid Build Coastguard Worker if the search expression specified some constraint but the replace 623*61046927SAndroid Build Coastguard Worker expression didn't, because we'd be merging a variable and a constant. 
624*61046927SAndroid Build Coastguard Worker """ 625*61046927SAndroid Build Coastguard Worker if isinstance(val, Variable): 626*61046927SAndroid Build Coastguard Worker if self._var_classes[val.index] is None: 627*61046927SAndroid Build Coastguard Worker self._var_classes[val.index] = val 628*61046927SAndroid Build Coastguard Worker else: 629*61046927SAndroid Build Coastguard Worker other = self._var_classes[val.index] 630*61046927SAndroid Build Coastguard Worker self.unify_bit_size(other, val, 631*61046927SAndroid Build Coastguard Worker lambda other_bit_size, bit_size: 632*61046927SAndroid Build Coastguard Worker 'Variable {} has conflicting bit size requirements: ' \ 633*61046927SAndroid Build Coastguard Worker 'it must have bit size {} and {}'.format( 634*61046927SAndroid Build Coastguard Worker val.var_name, other_bit_size, bit_size)) 635*61046927SAndroid Build Coastguard Worker elif isinstance(val, Expression): 636*61046927SAndroid Build Coastguard Worker for src in val.sources: 637*61046927SAndroid Build Coastguard Worker self.merge_variables(src) 638*61046927SAndroid Build Coastguard Worker 639*61046927SAndroid Build Coastguard Worker def validate_value(self, val): 640*61046927SAndroid Build Coastguard Worker """Validate the an expression by performing classic Hindley-Milner 641*61046927SAndroid Build Coastguard Worker type inference on bitsizes. This will detect if there are any conflicting 642*61046927SAndroid Build Coastguard Worker requirements, and unify variables so that we know which variables must 643*61046927SAndroid Build Coastguard Worker have the same bitsize. If we're operating on the replace expression, we 644*61046927SAndroid Build Coastguard Worker will refuse to merge different variables together or merge a variable 645*61046927SAndroid Build Coastguard Worker with a constant, in order to prevent surprises due to rules unexpectedly 646*61046927SAndroid Build Coastguard Worker not matching at runtime. 
647*61046927SAndroid Build Coastguard Worker """ 648*61046927SAndroid Build Coastguard Worker if not isinstance(val, Expression): 649*61046927SAndroid Build Coastguard Worker return 650*61046927SAndroid Build Coastguard Worker 651*61046927SAndroid Build Coastguard Worker # Generic conversion ops are special in that they have a single unsized 652*61046927SAndroid Build Coastguard Worker # source and an unsized destination and the two don't have to match. 653*61046927SAndroid Build Coastguard Worker # This means there's no validation or unioning to do here besides the 654*61046927SAndroid Build Coastguard Worker # len(val.sources) check. 655*61046927SAndroid Build Coastguard Worker if val.opcode in conv_opcode_types: 656*61046927SAndroid Build Coastguard Worker assert len(val.sources) == 1, \ 657*61046927SAndroid Build Coastguard Worker "Expression {} has {} sources, expected 1".format( 658*61046927SAndroid Build Coastguard Worker val, len(val.sources)) 659*61046927SAndroid Build Coastguard Worker self.validate_value(val.sources[0]) 660*61046927SAndroid Build Coastguard Worker return 661*61046927SAndroid Build Coastguard Worker 662*61046927SAndroid Build Coastguard Worker nir_op = opcodes[val.opcode] 663*61046927SAndroid Build Coastguard Worker assert len(val.sources) == nir_op.num_inputs, \ 664*61046927SAndroid Build Coastguard Worker "Expression {} has {} sources, expected {}".format( 665*61046927SAndroid Build Coastguard Worker val, len(val.sources), nir_op.num_inputs) 666*61046927SAndroid Build Coastguard Worker 667*61046927SAndroid Build Coastguard Worker for src in val.sources: 668*61046927SAndroid Build Coastguard Worker self.validate_value(src) 669*61046927SAndroid Build Coastguard Worker 670*61046927SAndroid Build Coastguard Worker dst_type_bits = type_bits(nir_op.output_type) 671*61046927SAndroid Build Coastguard Worker 672*61046927SAndroid Build Coastguard Worker # First, unify all the sources. 
That way, an error coming up because two 673*61046927SAndroid Build Coastguard Worker # sources have an incompatible bit-size won't produce an error message 674*61046927SAndroid Build Coastguard Worker # involving the destination. 675*61046927SAndroid Build Coastguard Worker first_unsized_src = None 676*61046927SAndroid Build Coastguard Worker for src_type, src in zip(nir_op.input_types, val.sources): 677*61046927SAndroid Build Coastguard Worker src_type_bits = type_bits(src_type) 678*61046927SAndroid Build Coastguard Worker if src_type_bits == 0: 679*61046927SAndroid Build Coastguard Worker if first_unsized_src is None: 680*61046927SAndroid Build Coastguard Worker first_unsized_src = src 681*61046927SAndroid Build Coastguard Worker continue 682*61046927SAndroid Build Coastguard Worker 683*61046927SAndroid Build Coastguard Worker if self.is_search: 684*61046927SAndroid Build Coastguard Worker self.unify_bit_size(first_unsized_src, src, 685*61046927SAndroid Build Coastguard Worker lambda first_unsized_src_bit_size, src_bit_size: 686*61046927SAndroid Build Coastguard Worker 'Source {} of {} must have bit size {}, while source {} ' \ 687*61046927SAndroid Build Coastguard Worker 'must have incompatible bit size {}'.format( 688*61046927SAndroid Build Coastguard Worker first_unsized_src, val, first_unsized_src_bit_size, 689*61046927SAndroid Build Coastguard Worker src, src_bit_size)) 690*61046927SAndroid Build Coastguard Worker else: 691*61046927SAndroid Build Coastguard Worker self.unify_bit_size(first_unsized_src, src, 692*61046927SAndroid Build Coastguard Worker lambda first_unsized_src_bit_size, src_bit_size: 693*61046927SAndroid Build Coastguard Worker 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ 694*61046927SAndroid Build Coastguard Worker 'of {} may not have the same bit size when building the ' \ 695*61046927SAndroid Build Coastguard Worker 'replacement expression.'.format( 696*61046927SAndroid Build Coastguard Worker first_unsized_src, 
first_unsized_src_bit_size, src, 697*61046927SAndroid Build Coastguard Worker src_bit_size, val)) 698*61046927SAndroid Build Coastguard Worker else: 699*61046927SAndroid Build Coastguard Worker if self.is_search: 700*61046927SAndroid Build Coastguard Worker self.unify_bit_size(src, src_type_bits, 701*61046927SAndroid Build Coastguard Worker lambda src_bit_size, unused: 702*61046927SAndroid Build Coastguard Worker '{} must have {} bits, but as a source of nir_op_{} '\ 703*61046927SAndroid Build Coastguard Worker 'it must have {} bits'.format( 704*61046927SAndroid Build Coastguard Worker src, src_bit_size, nir_op.name, src_type_bits)) 705*61046927SAndroid Build Coastguard Worker else: 706*61046927SAndroid Build Coastguard Worker self.unify_bit_size(src, src_type_bits, 707*61046927SAndroid Build Coastguard Worker lambda src_bit_size, unused: 708*61046927SAndroid Build Coastguard Worker '{} has the bit size of {}, but as a source of ' \ 709*61046927SAndroid Build Coastguard Worker 'nir_op_{} it must have {} bits, which may not be the ' \ 710*61046927SAndroid Build Coastguard Worker 'same'.format( 711*61046927SAndroid Build Coastguard Worker src, src_bit_size, nir_op.name, src_type_bits)) 712*61046927SAndroid Build Coastguard Worker 713*61046927SAndroid Build Coastguard Worker if dst_type_bits == 0: 714*61046927SAndroid Build Coastguard Worker if first_unsized_src is not None: 715*61046927SAndroid Build Coastguard Worker if self.is_search: 716*61046927SAndroid Build Coastguard Worker self.unify_bit_size(val, first_unsized_src, 717*61046927SAndroid Build Coastguard Worker lambda val_bit_size, src_bit_size: 718*61046927SAndroid Build Coastguard Worker '{} must have the bit size of {}, while its source {} ' \ 719*61046927SAndroid Build Coastguard Worker 'must have incompatible bit size {}'.format( 720*61046927SAndroid Build Coastguard Worker val, val_bit_size, first_unsized_src, src_bit_size)) 721*61046927SAndroid Build Coastguard Worker else: 722*61046927SAndroid Build 
Coastguard Worker self.unify_bit_size(val, first_unsized_src, 723*61046927SAndroid Build Coastguard Worker lambda val_bit_size, src_bit_size: 724*61046927SAndroid Build Coastguard Worker '{} must have {} bits, but its source {} ' \ 725*61046927SAndroid Build Coastguard Worker '(bit size of {}) may not have that bit size ' \ 726*61046927SAndroid Build Coastguard Worker 'when building the replacement.'.format( 727*61046927SAndroid Build Coastguard Worker val, val_bit_size, first_unsized_src, src_bit_size)) 728*61046927SAndroid Build Coastguard Worker else: 729*61046927SAndroid Build Coastguard Worker self.unify_bit_size(val, dst_type_bits, 730*61046927SAndroid Build Coastguard Worker lambda dst_bit_size, unused: 731*61046927SAndroid Build Coastguard Worker '{} must have {} bits, but as a destination of nir_op_{} ' \ 732*61046927SAndroid Build Coastguard Worker 'it must have {} bits'.format( 733*61046927SAndroid Build Coastguard Worker val, dst_bit_size, nir_op.name, dst_type_bits)) 734*61046927SAndroid Build Coastguard Worker 735*61046927SAndroid Build Coastguard Worker def validate_replace(self, val, search): 736*61046927SAndroid Build Coastguard Worker bit_size = val.get_bit_size() 737*61046927SAndroid Build Coastguard Worker assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ 738*61046927SAndroid Build Coastguard Worker bit_size == search.get_bit_size(), \ 739*61046927SAndroid Build Coastguard Worker 'Ambiguous bit size for replacement value {}: ' \ 740*61046927SAndroid Build Coastguard Worker 'it cannot be deduced from a variable, a fixed bit size ' \ 741*61046927SAndroid Build Coastguard Worker 'somewhere, or the search expression.'.format(val) 742*61046927SAndroid Build Coastguard Worker 743*61046927SAndroid Build Coastguard Worker if isinstance(val, Expression): 744*61046927SAndroid Build Coastguard Worker for src in val.sources: 745*61046927SAndroid Build Coastguard Worker self.validate_replace(src, search) 746*61046927SAndroid Build 
Coastguard Worker elif isinstance(val, Variable): 747*61046927SAndroid Build Coastguard Worker # These catch problems when someone copies and pastes the search 748*61046927SAndroid Build Coastguard Worker # into the replacement. 749*61046927SAndroid Build Coastguard Worker assert not val.is_constant, \ 750*61046927SAndroid Build Coastguard Worker 'Replacement variables must not be marked constant.' 751*61046927SAndroid Build Coastguard Worker 752*61046927SAndroid Build Coastguard Worker assert val.cond_index == -1, \ 753*61046927SAndroid Build Coastguard Worker 'Replacement variables must not have a condition.' 754*61046927SAndroid Build Coastguard Worker 755*61046927SAndroid Build Coastguard Worker assert not val.required_type, \ 756*61046927SAndroid Build Coastguard Worker 'Replacement variables must not have a required type.' 757*61046927SAndroid Build Coastguard Worker 758*61046927SAndroid Build Coastguard Worker def validate(self, search, replace): 759*61046927SAndroid Build Coastguard Worker self.is_search = True 760*61046927SAndroid Build Coastguard Worker self.merge_variables(search) 761*61046927SAndroid Build Coastguard Worker self.merge_variables(replace) 762*61046927SAndroid Build Coastguard Worker self.validate_value(search) 763*61046927SAndroid Build Coastguard Worker 764*61046927SAndroid Build Coastguard Worker self.is_search = False 765*61046927SAndroid Build Coastguard Worker self.validate_value(replace) 766*61046927SAndroid Build Coastguard Worker 767*61046927SAndroid Build Coastguard Worker # Check that search is always more specialized than replace. Note that 768*61046927SAndroid Build Coastguard Worker # we're doing this in replace mode, disallowing merging variables. 
769*61046927SAndroid Build Coastguard Worker search_bit_size = search.get_bit_size() 770*61046927SAndroid Build Coastguard Worker replace_bit_size = replace.get_bit_size() 771*61046927SAndroid Build Coastguard Worker cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) 772*61046927SAndroid Build Coastguard Worker 773*61046927SAndroid Build Coastguard Worker assert cmp_result is not None and cmp_result <= 0, \ 774*61046927SAndroid Build Coastguard Worker 'The search expression bit size {} and replace expression ' \ 775*61046927SAndroid Build Coastguard Worker 'bit size {} may not be the same'.format( 776*61046927SAndroid Build Coastguard Worker search_bit_size, replace_bit_size) 777*61046927SAndroid Build Coastguard Worker 778*61046927SAndroid Build Coastguard Worker replace.set_bit_size(search) 779*61046927SAndroid Build Coastguard Worker 780*61046927SAndroid Build Coastguard Worker self.validate_replace(replace, search) 781*61046927SAndroid Build Coastguard Worker 782*61046927SAndroid Build Coastguard Worker_optimization_ids = itertools.count() 783*61046927SAndroid Build Coastguard Worker 784*61046927SAndroid Build Coastguard Workercondition_list = ['true'] 785*61046927SAndroid Build Coastguard Worker 786*61046927SAndroid Build Coastguard Workerclass SearchAndReplace(object): 787*61046927SAndroid Build Coastguard Worker def __init__(self, transform, algebraic_pass): 788*61046927SAndroid Build Coastguard Worker self.id = next(_optimization_ids) 789*61046927SAndroid Build Coastguard Worker 790*61046927SAndroid Build Coastguard Worker search = transform[0] 791*61046927SAndroid Build Coastguard Worker replace = transform[1] 792*61046927SAndroid Build Coastguard Worker if len(transform) > 2: 793*61046927SAndroid Build Coastguard Worker self.condition = transform[2] 794*61046927SAndroid Build Coastguard Worker else: 795*61046927SAndroid Build Coastguard Worker self.condition = 'true' 796*61046927SAndroid Build Coastguard Worker 797*61046927SAndroid 
# Register this transform's condition in condition_list (at most once) and keep
# its index; the generated transforms table emits condition_index per transform.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)

      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)

      BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFAs and DFAs, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.

   `transforms` is a sequence of objects with a `.search` attribute; the
   search patterns are matched in the order given.

   Attributes read by _algebraic_pass_template after construction:

   - opcodes: IndexMap of search opcodes used by the patterns (sized
     conversion opcodes such as f2i1 are stripped to their unsized form).
   - filter[op]: list mapping a state index to a filtered state index.
   - rep[op]: IndexMap of the distinct filtered states seen for op.
   - table[op]: dict mapping a tuple of filtered source-state indices to the
     resulting state index.
   - state_patterns / state_pattern_offsets: flattened, None-terminated lists
     of matching pattern indices, and each state's offset into that list.
   """
   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      # Debugging aids for inspecting the size of the generated automaton:
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #  print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.

      Objects must be hashable. Indices are assigned in insertion order and
      never change (until clear()).
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         # Rebinds self.objects rather than emptying it in place, so an
         # iterator obtained before clear() keeps walking the old list.
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         """Insert obj if absent; return its (existing or new) index."""
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.

      Items use default identity hashing/equality; deduplication happens in
      TreeAutomaton._compute_items via the (opcode, children) key.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         # Tuple of child Items.
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them.

      Sets self.items, self.opcodes, self.wildcard and self.const.
      """
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         # The length check short-circuits first, so the pseudo-opcodes
         # "__wildcard" and "__const" (no children) are never looked up in
         # the opcodes table.
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         if commutative:
            # Also key the same item by the swapped first two sources, so the
            # Comp_a lookup in _build_table finds it for either source order.
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         # Returns the (deduplicated) Item for src; `pattern` is the pattern
         # index when src is the root of a search pattern.
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2i1, are tricky. Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2i1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2i, or we
               # construct it to match the search opcode, in which case we
               # need to map f2i1 to f2i when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Items that match a
      given instruction, and builds up the transition table between states.

      The order in which states are added determines their indices, which are
      emitted verbatim into the generated C tables.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None. Index 0 is a shared None
      # sentinel used by every state that matches no pattern.
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]
            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))

            if patterns:
               # Add our patterns to the global table.
               self.state_pattern_offsets.append(len(self.state_patterns))
               self.state_patterns.extend(patterns)
               self.state_patterns.append(None)
            else:
               # Point to the initial sentinel in the global table.
               self.state_pattern_offsets.append(0)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               # Keep only the items that can appear as a source of op.
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               # filter[op] grows in lockstep with self.states: entry k is
               # the filtered index of state k.
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            # Generic conversion opcodes (stripped of their size in
            # _compute_items) are not in the opcodes table; they take 1 source.
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                            itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            # rep does not change inside this loop (process_new_states runs
            # only after it), so every filtered state of op is now processed.
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()


# Mako template for the C source of a generated algebraic pass; it is rendered
# by AlgebraicPass.render().
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 * ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union
${pass_name}_values[] = { 1099*61046927SAndroid Build Coastguard Worker% for xform in xforms: 1100*61046927SAndroid Build Coastguard Worker /* ${xform.search} => ${xform.replace} */ 1101*61046927SAndroid Build Coastguard Worker${xform.search.render(cache)} 1102*61046927SAndroid Build Coastguard Worker${xform.replace.render(cache)} 1103*61046927SAndroid Build Coastguard Worker% endfor 1104*61046927SAndroid Build Coastguard Worker}; 1105*61046927SAndroid Build Coastguard Worker 1106*61046927SAndroid Build Coastguard Worker% if expression_cond: 1107*61046927SAndroid Build Coastguard Workerstatic const nir_search_expression_cond ${pass_name}_expression_cond[] = { 1108*61046927SAndroid Build Coastguard Worker% for cond in expression_cond: 1109*61046927SAndroid Build Coastguard Worker ${cond[0]}, 1110*61046927SAndroid Build Coastguard Worker% endfor 1111*61046927SAndroid Build Coastguard Worker}; 1112*61046927SAndroid Build Coastguard Worker% endif 1113*61046927SAndroid Build Coastguard Worker 1114*61046927SAndroid Build Coastguard Worker% if variable_cond: 1115*61046927SAndroid Build Coastguard Workerstatic const nir_search_variable_cond ${pass_name}_variable_cond[] = { 1116*61046927SAndroid Build Coastguard Worker% for cond in variable_cond: 1117*61046927SAndroid Build Coastguard Worker ${cond[0]}, 1118*61046927SAndroid Build Coastguard Worker% endfor 1119*61046927SAndroid Build Coastguard Worker}; 1120*61046927SAndroid Build Coastguard Worker% endif 1121*61046927SAndroid Build Coastguard Worker 1122*61046927SAndroid Build Coastguard Workerstatic const struct transform ${pass_name}_transforms[] = { 1123*61046927SAndroid Build Coastguard Worker% for i in automaton.state_patterns: 1124*61046927SAndroid Build Coastguard Worker% if i is not None: 1125*61046927SAndroid Build Coastguard Worker { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} }, 1126*61046927SAndroid Build Coastguard Worker% else: 1127*61046927SAndroid Build 
Coastguard Worker { ~0, ~0, ~0 }, /* Sentinel */ 1128*61046927SAndroid Build Coastguard Worker 1129*61046927SAndroid Build Coastguard Worker% endif 1130*61046927SAndroid Build Coastguard Worker% endfor 1131*61046927SAndroid Build Coastguard Worker}; 1132*61046927SAndroid Build Coastguard Worker 1133*61046927SAndroid Build Coastguard Workerstatic const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = { 1134*61046927SAndroid Build Coastguard Worker% for op in automaton.opcodes: 1135*61046927SAndroid Build Coastguard Worker [${get_c_opcode(op)}] = { 1136*61046927SAndroid Build Coastguard Worker% if all(e == 0 for e in automaton.filter[op]): 1137*61046927SAndroid Build Coastguard Worker .filter = NULL, 1138*61046927SAndroid Build Coastguard Worker% else: 1139*61046927SAndroid Build Coastguard Worker .filter = (const uint16_t []) { 1140*61046927SAndroid Build Coastguard Worker % for e in automaton.filter[op]: 1141*61046927SAndroid Build Coastguard Worker ${e}, 1142*61046927SAndroid Build Coastguard Worker % endfor 1143*61046927SAndroid Build Coastguard Worker }, 1144*61046927SAndroid Build Coastguard Worker% endif 1145*61046927SAndroid Build Coastguard Worker <% 1146*61046927SAndroid Build Coastguard Worker num_filtered = len(automaton.rep[op]) 1147*61046927SAndroid Build Coastguard Worker %> 1148*61046927SAndroid Build Coastguard Worker .num_filtered_states = ${num_filtered}, 1149*61046927SAndroid Build Coastguard Worker .table = (const uint16_t []) { 1150*61046927SAndroid Build Coastguard Worker <% 1151*61046927SAndroid Build Coastguard Worker num_srcs = len(next(iter(automaton.table[op]))) 1152*61046927SAndroid Build Coastguard Worker %> 1153*61046927SAndroid Build Coastguard Worker % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1154*61046927SAndroid Build Coastguard Worker ${automaton.table[op][indices]}, 1155*61046927SAndroid Build Coastguard Worker % endfor 1156*61046927SAndroid Build Coastguard Worker }, 
1157*61046927SAndroid Build Coastguard Worker }, 1158*61046927SAndroid Build Coastguard Worker% endfor 1159*61046927SAndroid Build Coastguard Worker}; 1160*61046927SAndroid Build Coastguard Worker 1161*61046927SAndroid Build Coastguard Worker/* Mapping from state index to offset in transforms (0 being no transforms) */ 1162*61046927SAndroid Build Coastguard Workerstatic const uint16_t ${pass_name}_transform_offsets[] = { 1163*61046927SAndroid Build Coastguard Worker% for offset in automaton.state_pattern_offsets: 1164*61046927SAndroid Build Coastguard Worker ${offset}, 1165*61046927SAndroid Build Coastguard Worker% endfor 1166*61046927SAndroid Build Coastguard Worker}; 1167*61046927SAndroid Build Coastguard Worker 1168*61046927SAndroid Build Coastguard Workerstatic const nir_algebraic_table ${pass_name}_table = { 1169*61046927SAndroid Build Coastguard Worker .transforms = ${pass_name}_transforms, 1170*61046927SAndroid Build Coastguard Worker .transform_offsets = ${pass_name}_transform_offsets, 1171*61046927SAndroid Build Coastguard Worker .pass_op_table = ${pass_name}_pass_op_table, 1172*61046927SAndroid Build Coastguard Worker .values = ${pass_name}_values, 1173*61046927SAndroid Build Coastguard Worker .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, 1174*61046927SAndroid Build Coastguard Worker .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, 1175*61046927SAndroid Build Coastguard Worker}; 1176*61046927SAndroid Build Coastguard Worker 1177*61046927SAndroid Build Coastguard Workerbool 1178*61046927SAndroid Build Coastguard Worker${pass_name}( 1179*61046927SAndroid Build Coastguard Worker nir_shader *shader 1180*61046927SAndroid Build Coastguard Worker% for type, name in params: 1181*61046927SAndroid Build Coastguard Worker , ${type} ${name} 1182*61046927SAndroid Build Coastguard Worker% endfor 1183*61046927SAndroid Build Coastguard Worker) { 1184*61046927SAndroid Build Coastguard Worker bool 
progress = false; 1185*61046927SAndroid Build Coastguard Worker bool condition_flags[${len(condition_list)}]; 1186*61046927SAndroid Build Coastguard Worker const nir_shader_compiler_options *options = shader->options; 1187*61046927SAndroid Build Coastguard Worker const shader_info *info = &shader->info; 1188*61046927SAndroid Build Coastguard Worker (void) options; 1189*61046927SAndroid Build Coastguard Worker (void) info; 1190*61046927SAndroid Build Coastguard Worker 1191*61046927SAndroid Build Coastguard Worker STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values)); 1192*61046927SAndroid Build Coastguard Worker % for index, condition in enumerate(condition_list): 1193*61046927SAndroid Build Coastguard Worker condition_flags[${index}] = ${condition}; 1194*61046927SAndroid Build Coastguard Worker % endfor 1195*61046927SAndroid Build Coastguard Worker 1196*61046927SAndroid Build Coastguard Worker nir_foreach_function_impl(impl, shader) { 1197*61046927SAndroid Build Coastguard Worker progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table); 1198*61046927SAndroid Build Coastguard Worker } 1199*61046927SAndroid Build Coastguard Worker 1200*61046927SAndroid Build Coastguard Worker return progress; 1201*61046927SAndroid Build Coastguard Worker} 1202*61046927SAndroid Build Coastguard Worker""") 1203*61046927SAndroid Build Coastguard Worker 1204*61046927SAndroid Build Coastguard Worker 1205*61046927SAndroid Build Coastguard Workerclass AlgebraicPass(object): 1206*61046927SAndroid Build Coastguard Worker # params is a list of `("type", "name")` tuples 1207*61046927SAndroid Build Coastguard Worker def __init__(self, pass_name, transforms, params=[]): 1208*61046927SAndroid Build Coastguard Worker self.xforms = [] 1209*61046927SAndroid Build Coastguard Worker self.opcode_xforms = defaultdict(lambda : []) 1210*61046927SAndroid Build Coastguard Worker self.pass_name = pass_name 1211*61046927SAndroid Build Coastguard Worker 
self.expression_cond = {} 1212*61046927SAndroid Build Coastguard Worker self.variable_cond = {} 1213*61046927SAndroid Build Coastguard Worker self.params = params 1214*61046927SAndroid Build Coastguard Worker 1215*61046927SAndroid Build Coastguard Worker error = False 1216*61046927SAndroid Build Coastguard Worker 1217*61046927SAndroid Build Coastguard Worker for xform in transforms: 1218*61046927SAndroid Build Coastguard Worker if not isinstance(xform, SearchAndReplace): 1219*61046927SAndroid Build Coastguard Worker try: 1220*61046927SAndroid Build Coastguard Worker xform = SearchAndReplace(xform, self) 1221*61046927SAndroid Build Coastguard Worker except: 1222*61046927SAndroid Build Coastguard Worker print("Failed to parse transformation:", file=sys.stderr) 1223*61046927SAndroid Build Coastguard Worker print(" " + str(xform), file=sys.stderr) 1224*61046927SAndroid Build Coastguard Worker traceback.print_exc(file=sys.stderr) 1225*61046927SAndroid Build Coastguard Worker print('', file=sys.stderr) 1226*61046927SAndroid Build Coastguard Worker error = True 1227*61046927SAndroid Build Coastguard Worker continue 1228*61046927SAndroid Build Coastguard Worker 1229*61046927SAndroid Build Coastguard Worker self.xforms.append(xform) 1230*61046927SAndroid Build Coastguard Worker if xform.search.opcode in conv_opcode_types: 1231*61046927SAndroid Build Coastguard Worker dst_type = conv_opcode_types[xform.search.opcode] 1232*61046927SAndroid Build Coastguard Worker for size in type_sizes(dst_type): 1233*61046927SAndroid Build Coastguard Worker sized_opcode = xform.search.opcode + str(size) 1234*61046927SAndroid Build Coastguard Worker self.opcode_xforms[sized_opcode].append(xform) 1235*61046927SAndroid Build Coastguard Worker else: 1236*61046927SAndroid Build Coastguard Worker self.opcode_xforms[xform.search.opcode].append(xform) 1237*61046927SAndroid Build Coastguard Worker 1238*61046927SAndroid Build Coastguard Worker # Check to make sure the search pattern does not 
unexpectedly contain 1239*61046927SAndroid Build Coastguard Worker # more commutative expressions than match_expression (nir_search.c) 1240*61046927SAndroid Build Coastguard Worker # can handle. 1241*61046927SAndroid Build Coastguard Worker comm_exprs = xform.search.comm_exprs 1242*61046927SAndroid Build Coastguard Worker 1243*61046927SAndroid Build Coastguard Worker if xform.search.many_commutative_expressions: 1244*61046927SAndroid Build Coastguard Worker if comm_exprs <= nir_search_max_comm_ops: 1245*61046927SAndroid Build Coastguard Worker print("Transform expected to have too many commutative " \ 1246*61046927SAndroid Build Coastguard Worker "expression but did not " \ 1247*61046927SAndroid Build Coastguard Worker "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1248*61046927SAndroid Build Coastguard Worker file=sys.stderr) 1249*61046927SAndroid Build Coastguard Worker print(" " + str(xform), file=sys.stderr) 1250*61046927SAndroid Build Coastguard Worker traceback.print_exc(file=sys.stderr) 1251*61046927SAndroid Build Coastguard Worker print('', file=sys.stderr) 1252*61046927SAndroid Build Coastguard Worker error = True 1253*61046927SAndroid Build Coastguard Worker else: 1254*61046927SAndroid Build Coastguard Worker if comm_exprs > nir_search_max_comm_ops: 1255*61046927SAndroid Build Coastguard Worker print("Transformation with too many commutative expressions " \ 1256*61046927SAndroid Build Coastguard Worker "({} > {}). 
Modify pattern or annotate with " \ 1257*61046927SAndroid Build Coastguard Worker "\"many-comm-expr\".".format(comm_exprs, 1258*61046927SAndroid Build Coastguard Worker nir_search_max_comm_ops), 1259*61046927SAndroid Build Coastguard Worker file=sys.stderr) 1260*61046927SAndroid Build Coastguard Worker print(" " + str(xform.search), file=sys.stderr) 1261*61046927SAndroid Build Coastguard Worker print("{}".format(xform.search.cond), file=sys.stderr) 1262*61046927SAndroid Build Coastguard Worker error = True 1263*61046927SAndroid Build Coastguard Worker 1264*61046927SAndroid Build Coastguard Worker self.automaton = TreeAutomaton(self.xforms) 1265*61046927SAndroid Build Coastguard Worker 1266*61046927SAndroid Build Coastguard Worker if error: 1267*61046927SAndroid Build Coastguard Worker sys.exit(1) 1268*61046927SAndroid Build Coastguard Worker 1269*61046927SAndroid Build Coastguard Worker 1270*61046927SAndroid Build Coastguard Worker def render(self): 1271*61046927SAndroid Build Coastguard Worker return _algebraic_pass_template.render(pass_name=self.pass_name, 1272*61046927SAndroid Build Coastguard Worker xforms=self.xforms, 1273*61046927SAndroid Build Coastguard Worker opcode_xforms=self.opcode_xforms, 1274*61046927SAndroid Build Coastguard Worker condition_list=condition_list, 1275*61046927SAndroid Build Coastguard Worker automaton=self.automaton, 1276*61046927SAndroid Build Coastguard Worker expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1277*61046927SAndroid Build Coastguard Worker variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1278*61046927SAndroid Build Coastguard Worker get_c_opcode=get_c_opcode, 1279*61046927SAndroid Build Coastguard Worker itertools=itertools, 1280*61046927SAndroid Build Coastguard Worker params=self.params) 1281*61046927SAndroid Build Coastguard Worker 1282*61046927SAndroid Build Coastguard Worker# The replacement expression isn't necessarily exact if the search expression is exact. 
def ignore_exact(*expr):
    """Build a search expression whose replacement need not be exact.

    Normally an exact search expression would imply an exact replacement;
    wrapping a pattern with this helper marks it so that this implication
    is not made.

    The positional arguments are collected into a tuple and handed to
    SearchExpression.create() unchanged.  The created expression has its
    ``ignore_exact`` attribute set to True, and that same object is
    returned.
    """
    pattern = SearchExpression.create(expr)
    # Flag read later when deciding exactness of the replacement.
    pattern.ignore_exact = True
    return pattern