xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_algebraic.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1*61046927SAndroid Build Coastguard Worker#
2*61046927SAndroid Build Coastguard Worker# Copyright (C) 2014 Intel Corporation
3*61046927SAndroid Build Coastguard Worker#
4*61046927SAndroid Build Coastguard Worker# Permission is hereby granted, free of charge, to any person obtaining a
5*61046927SAndroid Build Coastguard Worker# copy of this software and associated documentation files (the "Software"),
6*61046927SAndroid Build Coastguard Worker# to deal in the Software without restriction, including without limitation
7*61046927SAndroid Build Coastguard Worker# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8*61046927SAndroid Build Coastguard Worker# and/or sell copies of the Software, and to permit persons to whom the
9*61046927SAndroid Build Coastguard Worker# Software is furnished to do so, subject to the following conditions:
10*61046927SAndroid Build Coastguard Worker#
11*61046927SAndroid Build Coastguard Worker# The above copyright notice and this permission notice (including the next
12*61046927SAndroid Build Coastguard Worker# paragraph) shall be included in all copies or substantial portions of the
13*61046927SAndroid Build Coastguard Worker# Software.
14*61046927SAndroid Build Coastguard Worker#
15*61046927SAndroid Build Coastguard Worker# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16*61046927SAndroid Build Coastguard Worker# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17*61046927SAndroid Build Coastguard Worker# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18*61046927SAndroid Build Coastguard Worker# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19*61046927SAndroid Build Coastguard Worker# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20*61046927SAndroid Build Coastguard Worker# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21*61046927SAndroid Build Coastguard Worker# IN THE SOFTWARE.
22*61046927SAndroid Build Coastguard Worker
23*61046927SAndroid Build Coastguard Workerimport ast
24*61046927SAndroid Build Coastguard Workerfrom collections import defaultdict
25*61046927SAndroid Build Coastguard Workerimport itertools
26*61046927SAndroid Build Coastguard Workerimport struct
27*61046927SAndroid Build Coastguard Workerimport sys
28*61046927SAndroid Build Coastguard Workerimport mako.template
29*61046927SAndroid Build Coastguard Workerimport re
30*61046927SAndroid Build Coastguard Workerimport traceback
31*61046927SAndroid Build Coastguard Worker
32*61046927SAndroid Build Coastguard Workerfrom nir_opcodes import opcodes, type_sizes
33*61046927SAndroid Build Coastguard Worker
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
# NOTE(review): presumably the cap on commutative expressions per search
# pattern (cf. the 'many-comm-expr' marker parsed by Expression) -- confirm
# against the code that reads this value.
nir_search_max_comm_ops = 8
36*61046927SAndroid Build Coastguard Worker
# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
#
# Membership in this table is what makes get_c_opcode() emit a
# nir_search_op_* name instead of nir_op_*, and Expression refuses an
# explicit "@<bits>" size on any opcode listed here.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}
52*61046927SAndroid Build Coastguard Worker
def get_cond_index(conds, cond):
    """Return the index registered for ``cond`` in ``conds``.

    ``conds`` maps a condition string to the position at which it was first
    seen.  An unseen condition is registered with the next free position
    (the current size of ``conds``).  A falsy ``cond`` (``None`` or an empty
    string) is not registered and yields -1.
    """
    if not cond:
        return -1

    # len(conds) is evaluated before the insertion, so a new entry receives
    # the next sequential index; an existing entry is returned untouched.
    return conds.setdefault(cond, len(conds))
63*61046927SAndroid Build Coastguard Worker
def get_c_opcode(op):
   """Return the C enumerator name for opcode ``op``.

   Opcodes listed in conv_opcode_types get the 'nir_search_op_' prefix; every
   other opcode gets the 'nir_op_' prefix.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
69*61046927SAndroid Build Coastguard Worker
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size embedded in a type string such as "int32".

   The string must start with one of the base type names accepted by
   _type_re (asserted).  If no digits follow the base type, 0 is returned.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
80*61046927SAndroid Build Coastguard Worker
81*61046927SAndroid Build Coastguard Worker# Represents a set of variables, each with a unique id
class VarSet(object):
   """A set of named variables, each assigned a unique sequential id.

   Looking up an unknown name allocates the next id for it, until lock() is
   called; after that an unknown name trips an assertion.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      known = self.names
      if name in known:
         return known[name]

      # Only an unlocked set may grow.
      assert not self.immutable, "Unknown replacement variable: " + name
      index = known[name] = next(self.ids)
      return index

   def lock(self):
      """Forbid the allocation of ids for names not already present."""
      self.immutable = True
97*61046927SAndroid Build Coastguard Worker
class SearchExpression(object):
   """A search pattern held as an opcode plus its sources.

   ``expr`` is a sequence whose first element is the opcode string and whose
   remaining elements are the sources.  ``ignore_exact`` starts out False and
   may be set by the caller afterwards.
   """

   def __init__(self, expr):
      self.opcode = expr[0]
      self.sources = expr[1:]
      self.ignore_exact = False

   @staticmethod
   def create(val):
      """Wrap a tuple in a SearchExpression; pass an existing one through."""
      if not isinstance(val, tuple):
         assert isinstance(val, SearchExpression)
         return val
      return SearchExpression(val)

   def __repr__(self):
      # Rendered as the tuple form of the pattern, with a trailing
      # 'ignore_exact' marker when that flag is set.
      parts = [self.opcode]
      parts.extend(self.sources)
      if self.ignore_exact:
         parts.append('ignore_exact')
      return repr(tuple(parts))
117*61046927SAndroid Build Coastguard Worker
class Value(object):
   """Base class for the nodes of a search/replace pattern.

   Subclasses are expected to provide a ``_bit_size`` attribute, which is
   either None, a concrete integer bit size, or another Value whose bit-size
   class this one has been merged into (see get_bit_size/set_bit_size).
   """

   @staticmethod
   def create(val, name_base, varset, algebraic_pass):
      """Build the Value subclass matching the Python type of ``val``.

      bytes are decoded as UTF-8 first.  Tuples and SearchExpressions become
      Expressions, an existing Expression is returned unchanged, strings
      become Variables, and bool/float/int become Constants.  Any other type
      falls through and yields None.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, (tuple, SearchExpression)):
         return Expression(val, name_base, varset, algebraic_pass)
      if isinstance(val, Expression):
         return val
      if isinstance(val, str):
         return Variable(val, name_base, varset, algebraic_pass)
      if isinstance(val, (bool, float, int)):
         return Constant(val, name_base)
      return None

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Return the representative of this value's bit-size class.

      That is either the physical bit size chosen for this value or, when
      none has been chosen, the canonical Value currently standing for the
      class.  Variables are preferred as the canonical value because the
      replacement expression needs to know which variable each value is
      equivalent to.  This is the "find" half of union-find.
      """
      current = self
      while isinstance(current, Value) and current._bit_size is not None:
         current = current._bit_size

      # Remember the result on self so later lookups take a single hop.
      if current is not self:
         self._bit_size = current
      return current

   def set_bit_size(self, other):
      """Merge this value's bit-size class into ``other``'s.

      Afterwards self.get_bit_size() returns what other.get_bit_size()
      returned before the call, or ``other`` itself when it is a concrete
      integer bit size.  This is the "union" half of union-find.
      """
      own_root = self.get_bit_size()
      if isinstance(other, int):
         other_root = other
      else:
         other_root = other.get_bit_size()

      if own_root == other_root:
         return

      own_root._bit_size = other_root

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_bit_size(self):
      """Bit size as encoded in the generated C initialiser.

      A concrete size is emitted as is and a Variable representative as
      ``-index - 1``.
      """
      resolved = self.get_bit_size()
      if isinstance(resolved, int):
         return resolved
      if isinstance(resolved, Variable):
         return -resolved.index - 1

      # If the bit-size class is neither a variable, nor an actual bit-size, then
      # - If it's in the search expression, we don't need to check anything
      # - If it's in the replace expression, either it's ambiguous (in which
      # case we'd reject it), or it equals the bit-size of the search value
      # We represent these cases with a 0 bit-size.
      return 0

   __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${'true' if val.nsz else 'false'},
      ${'true' if val.nnan else 'false'},
      ${'true' if val.ninf else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

   def render(self, cache):
      """Render this value's C initialiser, de-duplicating through ``cache``.

      ``cache`` maps rendered initialiser text to the array index (as a
      string) under which it was first emitted, and holds the next free
      index under the "next_index" key.  self.array_index is set either
      way.  A first occurrence returns the initialiser text; a repeat
      returns only a comment naming the index being reused.
      """
      rendered = self.__template.render(val=self,
                                        Constant=Constant,
                                        Variable=Variable,
                                        Expression=Expression)
      if rendered not in cache:
         index = str(cache["next_index"])
         self.array_index = index
         cache[rendered] = index
         cache["next_index"] += 1
         return rendered

      # Already emitted: register the remap and render only a comment
      # saying so.
      self.array_index = cache[rendered]
      return "   /* {} -> {} in the cache */\n".format(self.name,
                                                    self.array_index)
235*61046927SAndroid Build Coastguard Worker
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool, int or float appearing in a pattern.

   ``val`` is either the Python literal itself or a string of the form
   "<literal>[@<bits>]", in which case the literal part is parsed with
   ast.literal_eval and the optional suffix fixes the bit size.  Booleans
   always end up with a bit size of 1.
   """

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         parsed = _constant_re.match(val)
         bits = parsed.group('bits')
         self.value = ast.literal_eval(parsed.group('value'))
         self._bit_size = int(bits) if bits else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Return the C literal for this constant.

      Booleans map to NIR_TRUE / NIR_FALSE; ints and floats are emitted as a
      64-bit hexadecimal literal with an 'ull' suffix.
      """
      value = self.value
      # bool is tested first because it is a subclass of int.
      if isinstance(value, bool):
         return 'NIR_TRUE' if value else 'NIR_FALSE'

      if isinstance(value, int):
         if value >= 0:
            return hex(value) + 'ull'
         # Explicitly sign-extend negative integers to 64-bit, ensuring
         # correct handling of -INT32_MIN which is not representable in
         # 32-bit.
         packed = struct.pack('q', value)
      elif isinstance(value, float):
         # Reinterpret the bits of the double as an unsigned 64-bit integer.
         packed = struct.pack('d', value)
      else:
         assert False

      return hex(struct.unpack('Q', packed)[0]) + 'ull'

   def type(self):
      """Return the nir_type_* name for the value, or None if unsupported."""
      value = self.value
      if isinstance(value, bool):
         return "nir_type_bool"
      if isinstance(value, int):
         return "nir_type_int"
      if isinstance(value, float):
         return "nir_type_float"
      return None

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality: one constant generally cannot
      be used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.value == other.value
289*61046927SAndroid Build Coastguard Worker
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")

class Variable(Value):
   """A named placeholder in a pattern.

   ``val`` has the form "[#]name[@[type][bits]][(cond)][.swizzle]": a leading
   '#' marks the variable as constant-only, the '@' suffix constrains its
   type and/or bit size, the parenthesised part names a condition and the
   trailing part is a swizzle.
   """

   def __init__(self, val, name, varset, algebraic_pass):
      Value.__init__(self, val, name, "variable")

      parsed = _var_name_re.match(val)
      assert parsed and parsed.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = parsed.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name not in ('True', 'False')

      bits = parsed.group('bits')
      self.is_constant = parsed.group('const') is not None
      self.cond_index = get_cond_index(algebraic_pass.variable_cond,
                                       parsed.group('cond'))
      self.required_type = parsed.group('type')
      self._bit_size = int(bits) if bits else None
      self.swiz = parsed.group('swiz')

      # A bool variable defaults to 1 bit; an explicit size must be one of
      # the sizes type_sizes() reports for bool.
      if self.required_type == 'bool':
         if self._bit_size is None:
            self._bit_size = 1
         else:
            assert self._bit_size in type_sizes(self.required_type)

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* name for the required type, or None."""
      nir_types = {
         'bool': "nir_type_bool",
         'int': "nir_type_int",
         'uint': "nir_type_int",
         'float': "nir_type_float",
      }
      return nir_types.get(self.required_type)

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality: one variable generally cannot
      be used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.index == other.index

   def swizzle(self):
      """Return the swizzle as a C initialiser list of component indices.

      Without an explicit swizzle the identity over 16 components is used.
      """
      if self.swiz is None:
         return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

      # 'xyzw' name components 0-3; 'a'..'p' name components 0-15.
      lanes = {c: i for i, c in enumerate('xyzw')}
      lanes.update((c, i) for i, c in enumerate('abcdefghijklmnop'))

      # self.swiz still carries its leading '.', hence the [1:].
      return '{' + ', '.join(str(lanes[c]) for c in self.swiz[1:]) + '}'
362*61046927SAndroid Build Coastguard Worker
# Parses the opcode string of an expression: an optional '~' (inexact) and/or
# '!' (exact) marker, the opcode name, an optional "@<bits>" size suffix and
# an optional parenthesised condition list.  The regex itself accepts both
# markers together; Expression asserts that they are not combined.  Unlike
# _var_name_re there is no trailing '$', so text after the recognised fields
# is not rejected here.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")
365*61046927SAndroid Build Coastguard Worker
366*61046927SAndroid Build Coastguard Workerclass Expression(Value):
367*61046927SAndroid Build Coastguard Worker   def __init__(self, expr, name_base, varset, algebraic_pass):
368*61046927SAndroid Build Coastguard Worker      Value.__init__(self, expr, name_base, "expression")
369*61046927SAndroid Build Coastguard Worker
370*61046927SAndroid Build Coastguard Worker      expr = SearchExpression.create(expr)
371*61046927SAndroid Build Coastguard Worker
372*61046927SAndroid Build Coastguard Worker      m = _opcode_re.match(expr.opcode)
373*61046927SAndroid Build Coastguard Worker      assert m and m.group('opcode') is not None
374*61046927SAndroid Build Coastguard Worker
375*61046927SAndroid Build Coastguard Worker      self.opcode = m.group('opcode')
376*61046927SAndroid Build Coastguard Worker      self._bit_size = int(m.group('bits')) if m.group('bits') else None
377*61046927SAndroid Build Coastguard Worker      self.inexact = m.group('inexact') is not None
378*61046927SAndroid Build Coastguard Worker      self.exact = m.group('exact') is not None
379*61046927SAndroid Build Coastguard Worker      self.ignore_exact = expr.ignore_exact
380*61046927SAndroid Build Coastguard Worker      self.cond = m.group('cond')
381*61046927SAndroid Build Coastguard Worker
382*61046927SAndroid Build Coastguard Worker      assert not self.inexact or not self.exact, \
383*61046927SAndroid Build Coastguard Worker            'Expression cannot be both exact and inexact.'
384*61046927SAndroid Build Coastguard Worker
385*61046927SAndroid Build Coastguard Worker      # "many-comm-expr" isn't really a condition.  It's notification to the
386*61046927SAndroid Build Coastguard Worker      # generator that this pattern is known to have too many commutative
387*61046927SAndroid Build Coastguard Worker      # expressions, and an error should not be generated for this case.
388*61046927SAndroid Build Coastguard Worker      # nsz, nnan and ninf are special conditions, so we treat them specially too.
389*61046927SAndroid Build Coastguard Worker      cond = {k: True for k in self.cond[1:-1].split(",")} if self.cond else {}
390*61046927SAndroid Build Coastguard Worker      self.many_commutative_expressions = cond.pop('many-comm-expr', False)
391*61046927SAndroid Build Coastguard Worker      self.nsz = cond.pop('nsz', False)
392*61046927SAndroid Build Coastguard Worker      self.nnan = cond.pop('nnan', False)
393*61046927SAndroid Build Coastguard Worker      self.ninf = cond.pop('ninf', False)
394*61046927SAndroid Build Coastguard Worker
395*61046927SAndroid Build Coastguard Worker      assert len(cond) <= 1
396*61046927SAndroid Build Coastguard Worker      self.cond = cond.popitem()[0] if cond else None
397*61046927SAndroid Build Coastguard Worker
398*61046927SAndroid Build Coastguard Worker      # Deduplicate references to the condition functions for the expressions
399*61046927SAndroid Build Coastguard Worker      # and save the index for the order they were added.
400*61046927SAndroid Build Coastguard Worker      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
401*61046927SAndroid Build Coastguard Worker
402*61046927SAndroid Build Coastguard Worker      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
403*61046927SAndroid Build Coastguard Worker                       for (i, src) in enumerate(expr.sources) ]
404*61046927SAndroid Build Coastguard Worker
405*61046927SAndroid Build Coastguard Worker      # nir_search_expression::srcs is hard-coded to 4
406*61046927SAndroid Build Coastguard Worker      assert len(self.sources) <= 4
407*61046927SAndroid Build Coastguard Worker
408*61046927SAndroid Build Coastguard Worker      if self.opcode in conv_opcode_types:
409*61046927SAndroid Build Coastguard Worker         assert self._bit_size is None, \
410*61046927SAndroid Build Coastguard Worker                'Expression cannot use an unsized conversion opcode with ' \
411*61046927SAndroid Build Coastguard Worker                'an explicit size; that\'s silly.'
412*61046927SAndroid Build Coastguard Worker
413*61046927SAndroid Build Coastguard Worker      self.__index_comm_exprs(0)
414*61046927SAndroid Build Coastguard Worker
415*61046927SAndroid Build Coastguard Worker   def equivalent(self, other):
416*61046927SAndroid Build Coastguard Worker      """Check that two variables are equivalent.
417*61046927SAndroid Build Coastguard Worker
418*61046927SAndroid Build Coastguard Worker      This is check is much weaker than equality.  One generally cannot be
419*61046927SAndroid Build Coastguard Worker      used in place of the other.  Using this implementation for the __eq__
420*61046927SAndroid Build Coastguard Worker      will break BitSizeValidator.
421*61046927SAndroid Build Coastguard Worker
422*61046927SAndroid Build Coastguard Worker      This implementation does not check for equivalence due to commutativity,
423*61046927SAndroid Build Coastguard Worker      but it could.
424*61046927SAndroid Build Coastguard Worker
425*61046927SAndroid Build Coastguard Worker      """
426*61046927SAndroid Build Coastguard Worker      if not isinstance(other, type(self)):
427*61046927SAndroid Build Coastguard Worker         return False
428*61046927SAndroid Build Coastguard Worker
429*61046927SAndroid Build Coastguard Worker      if len(self.sources) != len(other.sources):
430*61046927SAndroid Build Coastguard Worker         return False
431*61046927SAndroid Build Coastguard Worker
432*61046927SAndroid Build Coastguard Worker      if self.opcode != other.opcode:
433*61046927SAndroid Build Coastguard Worker         return False
434*61046927SAndroid Build Coastguard Worker
435*61046927SAndroid Build Coastguard Worker      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
436*61046927SAndroid Build Coastguard Worker
437*61046927SAndroid Build Coastguard Worker   def __index_comm_exprs(self, base_idx):
438*61046927SAndroid Build Coastguard Worker      """Recursively count and index commutative expressions
439*61046927SAndroid Build Coastguard Worker      """
440*61046927SAndroid Build Coastguard Worker      self.comm_exprs = 0
441*61046927SAndroid Build Coastguard Worker
442*61046927SAndroid Build Coastguard Worker      # A note about the explicit "len(self.sources)" check. The list of
443*61046927SAndroid Build Coastguard Worker      # sources comes from user input, and that input might be bad.  Check
444*61046927SAndroid Build Coastguard Worker      # that the expected second source exists before accessing it. Without
445*61046927SAndroid Build Coastguard Worker      # this check, a unit test that does "('iadd', 'a')" will crash.
446*61046927SAndroid Build Coastguard Worker      if self.opcode not in conv_opcode_types and \
447*61046927SAndroid Build Coastguard Worker         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
448*61046927SAndroid Build Coastguard Worker         len(self.sources) >= 2 and \
449*61046927SAndroid Build Coastguard Worker         not self.sources[0].equivalent(self.sources[1]):
450*61046927SAndroid Build Coastguard Worker         self.comm_expr_idx = base_idx
451*61046927SAndroid Build Coastguard Worker         self.comm_exprs += 1
452*61046927SAndroid Build Coastguard Worker      else:
453*61046927SAndroid Build Coastguard Worker         self.comm_expr_idx = -1
454*61046927SAndroid Build Coastguard Worker
455*61046927SAndroid Build Coastguard Worker      for s in self.sources:
456*61046927SAndroid Build Coastguard Worker         if isinstance(s, Expression):
457*61046927SAndroid Build Coastguard Worker            s.__index_comm_exprs(base_idx + self.comm_exprs)
458*61046927SAndroid Build Coastguard Worker            self.comm_exprs += s.comm_exprs
459*61046927SAndroid Build Coastguard Worker
460*61046927SAndroid Build Coastguard Worker      return self.comm_exprs
461*61046927SAndroid Build Coastguard Worker
462*61046927SAndroid Build Coastguard Worker   def c_opcode(self):
463*61046927SAndroid Build Coastguard Worker      return get_c_opcode(self.opcode)
464*61046927SAndroid Build Coastguard Worker
465*61046927SAndroid Build Coastguard Worker   def render(self, cache):
466*61046927SAndroid Build Coastguard Worker      srcs = "".join(src.render(cache) for src in self.sources)
467*61046927SAndroid Build Coastguard Worker      return srcs + super(Expression, self).render(cache)
468*61046927SAndroid Build Coastguard Worker
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      """Create a validator for one search/replace pair.

      varset is only consulted for len(varset.names), i.e. the number of
      variables that may show up in the expressions being validated.
      """
      # One slot per variable, indexed by Variable.index.  Each slot holds
      # the first occurrence of that variable seen by merge_variables().
      self._var_classes = [None] * len(varset.names)

      # compare_bitsizes() and validate_value() read this flag.  validate()
      # sets it before each phase, but give it a defined value here as well
      # so the helpers do not raise AttributeError when called on their own.
      self.is_search = True

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if
      a >= b, and None if they are not comparable (neither a <= b nor b <= a).

      a and b are bit-size classes: either an int (a physical bit size), a
      Variable, or any other value (treated as an unsized expression or
      constant).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-size classes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.

      a is a value; b is either a value or an int giving a physical bit
      size.  error_msg is only called when the two cannot be unified, and
      the failure is reported as an AssertionError.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less specialized class at the more specialized one.  A
      # negative result can only come back when b_bit_size is not an int
      # (see compare_bitsizes), so calling set_bit_size() on it is safe.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            # First use of this variable: it represents the class.
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.

      Values that are not Expressions are ignored.  The self.is_search flag
      selects search or replace mode; the two modes differ in the error
      messages used here and in what compare_bitsizes() allows.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      # Validate bottom-up so every source has its class settled before it
      # is unified with its siblings and with the destination.
      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            # Unsized source: it must agree with every other unsized source.
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            # Sized source: it must have exactly the size the opcode demands.
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         # Unsized destination: it shares the size of the unsized sources,
         # if the opcode has any.
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that every value in the replacement has a known bit size.

      val is the replacement value to check (its sources are checked
      recursively) and search is the root of the search expression.  A bit
      size is known when it is a physical size, a variable, or the bit size
      of the search expression.  Variables in the replacement must also be
      free of the constant marker, condition and required type.
      """
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, (int, Variable)) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)
      elif isinstance(val, Variable):
         # These catch problems when someone copies and pastes the search
         # into the replacement.
         assert not val.is_constant, \
            'Replacement variables must not be marked constant.'

         assert val.cond_index == -1, \
            'Replacement variables must not have a condition.'

         assert not val.required_type, \
            'Replacement variables must not have a required type.'

   def validate(self, search, replace):
      """Validate the bit sizes of a complete search/replace pair.

      Any inconsistency is reported as an AssertionError.  On success the
      replacement root has been given the bit-size class of the search root.
      """
      # Both variable merges and the search validation run in search mode;
      # see merge_variables() for why the replacement is merged this way.
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
781*61046927SAndroid Build Coastguard Worker
# Shared counter; each SearchAndReplace draws its unique id from it.
_optimization_ids = itertools.count()

# Distinct transform condition strings, in order of first appearance.
# SearchAndReplace appends unseen conditions and stores the list index of its
# own condition; index 0 is 'true', the default when no condition is given.
condition_list = ['true']
785*61046927SAndroid Build Coastguard Worker
class SearchAndReplace(object):
   """One algebraic transform: a search pattern, its replacement and an
   optional condition.

   `transform` is indexed positionally: element 0 is the search pattern,
   element 1 is the replacement and, when present, element 2 is the condition
   string (otherwise the condition is 'true'). A search that is not already
   an Expression is wrapped with Expression(...); a replacement that is not
   already a Value is built with Value.create(...). The finished pair is
   handed to BitSizeValidator for checking.
   """
   def __init__(self, transform, algebraic_pass):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Register the condition in the module-wide list (once) and remember
      # which slot it occupies.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
      self.search = search

      # The variable set is locked after the search side is built and before
      # the replacement is.
      # NOTE(review): presumably this stops the replacement from introducing
      # variables the search never bound -- confirm against VarSet.lock().
      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
      self.replace = replace

      BitSizeValidator(varset).validate(self.search, self.replace)
815*61046927SAndroid Build Coastguard Worker
class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      """Build the automaton for the search patterns of `transforms`.

      Each element of `transforms` must expose a `search` attribute. The
      position of a transform in the sequence is the pattern index that ends
      up in `state_patterns`.
      """
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         """Create the map, adding each object of `iterable` in order."""
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         """Return the object stored at index `i`."""
         return self.objects[i]

      def __contains__(self, obj):
         """Return whether `obj` has been added."""
         return obj in self.map

      def __len__(self):
         """Return the number of distinct objects stored."""
         return len(self.objects)

      def __iter__(self):
         """Iterate over the objects in insertion (index) order."""
         return iter(self.objects)

      def clear(self):
         """Remove every object, resetting indices to start from 0."""
         self.objects = []
         self.map.clear()

      def index(self, obj):
         """Return the index of `obj`; raises KeyError if it was never added."""
         return self.map[obj]

      def add(self, obj):
         """Add `obj` if it is new and return its index either way."""
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed
         # up filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them.

      Populates `self.items`, `self.opcodes`, `self.wildcard` and
      `self.const` from `self.patterns`.
      """
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         # Return the shared Item for (opcode, children), creating it on
         # first use, and record `pattern` as rooted at it when given.
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         key = (opcode, children)
         item = self.items.get(key)
         if item is None:
            # Allocate only on a miss; setdefault() would construct (and
            # throw away) a fresh Item on every lookup of an existing key.
            item = self.items[key] = self.Item(opcode, children)
         if commutative:
            # The same item also answers for the first two sources swapped.
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         # Map a pattern (sub)tree to its Item, recursing through sources.
         # `pattern` is the pattern index and is only passed for the root.
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2i1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2i1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2i, or we
               # construct it to match the search opcode, in which case we
               # need to map f2i1 to f2i when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for child in children:
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is
      # part of the worklist of newly created states. There is also a worklist
      # of newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(int)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]
            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = sorted(p for item in state for p in item.patterns)

            if patterns:
               # Add our patterns to the global table.
               self.state_pattern_offsets.append(len(self.state_patterns))
               self.state_patterns.extend(patterns)
               self.state_patterns.append(None)
            else:
               # Point to the initial sentinel in the global table.
               self.state_pattern_offsets.append(0)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               # The filter list gains exactly one entry per processed state.
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()
1083*61046927SAndroid Build Coastguard Worker
1084*61046927SAndroid Build Coastguard Worker_algebraic_pass_template = mako.template.Template("""
1085*61046927SAndroid Build Coastguard Worker#include "nir.h"
1086*61046927SAndroid Build Coastguard Worker#include "nir_builder.h"
1087*61046927SAndroid Build Coastguard Worker#include "nir_search.h"
1088*61046927SAndroid Build Coastguard Worker#include "nir_search_helpers.h"
1089*61046927SAndroid Build Coastguard Worker
1090*61046927SAndroid Build Coastguard Worker/* What follows is NIR algebraic transform code for the following ${len(xforms)}
1091*61046927SAndroid Build Coastguard Worker * transforms:
1092*61046927SAndroid Build Coastguard Worker% for xform in xforms:
1093*61046927SAndroid Build Coastguard Worker *    ${xform.search} => ${xform.replace}
1094*61046927SAndroid Build Coastguard Worker% endfor
1095*61046927SAndroid Build Coastguard Worker */
1096*61046927SAndroid Build Coastguard Worker
1097*61046927SAndroid Build Coastguard Worker<% cache = {"next_index": 0} %>
1098*61046927SAndroid Build Coastguard Workerstatic const nir_search_value_union ${pass_name}_values[] = {
1099*61046927SAndroid Build Coastguard Worker% for xform in xforms:
1100*61046927SAndroid Build Coastguard Worker   /* ${xform.search} => ${xform.replace} */
1101*61046927SAndroid Build Coastguard Worker${xform.search.render(cache)}
1102*61046927SAndroid Build Coastguard Worker${xform.replace.render(cache)}
1103*61046927SAndroid Build Coastguard Worker% endfor
1104*61046927SAndroid Build Coastguard Worker};
1105*61046927SAndroid Build Coastguard Worker
1106*61046927SAndroid Build Coastguard Worker% if expression_cond:
1107*61046927SAndroid Build Coastguard Workerstatic const nir_search_expression_cond ${pass_name}_expression_cond[] = {
1108*61046927SAndroid Build Coastguard Worker% for cond in expression_cond:
1109*61046927SAndroid Build Coastguard Worker   ${cond[0]},
1110*61046927SAndroid Build Coastguard Worker% endfor
1111*61046927SAndroid Build Coastguard Worker};
1112*61046927SAndroid Build Coastguard Worker% endif
1113*61046927SAndroid Build Coastguard Worker
1114*61046927SAndroid Build Coastguard Worker% if variable_cond:
1115*61046927SAndroid Build Coastguard Workerstatic const nir_search_variable_cond ${pass_name}_variable_cond[] = {
1116*61046927SAndroid Build Coastguard Worker% for cond in variable_cond:
1117*61046927SAndroid Build Coastguard Worker   ${cond[0]},
1118*61046927SAndroid Build Coastguard Worker% endfor
1119*61046927SAndroid Build Coastguard Worker};
1120*61046927SAndroid Build Coastguard Worker% endif
1121*61046927SAndroid Build Coastguard Worker
1122*61046927SAndroid Build Coastguard Workerstatic const struct transform ${pass_name}_transforms[] = {
1123*61046927SAndroid Build Coastguard Worker% for i in automaton.state_patterns:
1124*61046927SAndroid Build Coastguard Worker% if i is not None:
1125*61046927SAndroid Build Coastguard Worker   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
1126*61046927SAndroid Build Coastguard Worker% else:
1127*61046927SAndroid Build Coastguard Worker   { ~0, ~0, ~0 }, /* Sentinel */
1128*61046927SAndroid Build Coastguard Worker
1129*61046927SAndroid Build Coastguard Worker% endif
1130*61046927SAndroid Build Coastguard Worker% endfor
1131*61046927SAndroid Build Coastguard Worker};
1132*61046927SAndroid Build Coastguard Worker
1133*61046927SAndroid Build Coastguard Workerstatic const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
1134*61046927SAndroid Build Coastguard Worker% for op in automaton.opcodes:
1135*61046927SAndroid Build Coastguard Worker   [${get_c_opcode(op)}] = {
1136*61046927SAndroid Build Coastguard Worker% if all(e == 0 for e in automaton.filter[op]):
1137*61046927SAndroid Build Coastguard Worker      .filter = NULL,
1138*61046927SAndroid Build Coastguard Worker% else:
1139*61046927SAndroid Build Coastguard Worker      .filter = (const uint16_t []) {
1140*61046927SAndroid Build Coastguard Worker      % for e in automaton.filter[op]:
1141*61046927SAndroid Build Coastguard Worker         ${e},
1142*61046927SAndroid Build Coastguard Worker      % endfor
1143*61046927SAndroid Build Coastguard Worker      },
1144*61046927SAndroid Build Coastguard Worker% endif
1145*61046927SAndroid Build Coastguard Worker      <%
1146*61046927SAndroid Build Coastguard Worker        num_filtered = len(automaton.rep[op])
1147*61046927SAndroid Build Coastguard Worker      %>
1148*61046927SAndroid Build Coastguard Worker      .num_filtered_states = ${num_filtered},
1149*61046927SAndroid Build Coastguard Worker      .table = (const uint16_t []) {
1150*61046927SAndroid Build Coastguard Worker      <%
1151*61046927SAndroid Build Coastguard Worker        num_srcs = len(next(iter(automaton.table[op])))
1152*61046927SAndroid Build Coastguard Worker      %>
1153*61046927SAndroid Build Coastguard Worker      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
1154*61046927SAndroid Build Coastguard Worker         ${automaton.table[op][indices]},
1155*61046927SAndroid Build Coastguard Worker      % endfor
1156*61046927SAndroid Build Coastguard Worker      },
1157*61046927SAndroid Build Coastguard Worker   },
1158*61046927SAndroid Build Coastguard Worker% endfor
1159*61046927SAndroid Build Coastguard Worker};
1160*61046927SAndroid Build Coastguard Worker
1161*61046927SAndroid Build Coastguard Worker/* Mapping from state index to offset in transforms (0 being no transforms) */
1162*61046927SAndroid Build Coastguard Workerstatic const uint16_t ${pass_name}_transform_offsets[] = {
1163*61046927SAndroid Build Coastguard Worker% for offset in automaton.state_pattern_offsets:
1164*61046927SAndroid Build Coastguard Worker   ${offset},
1165*61046927SAndroid Build Coastguard Worker% endfor
1166*61046927SAndroid Build Coastguard Worker};
1167*61046927SAndroid Build Coastguard Worker
1168*61046927SAndroid Build Coastguard Workerstatic const nir_algebraic_table ${pass_name}_table = {
1169*61046927SAndroid Build Coastguard Worker   .transforms = ${pass_name}_transforms,
1170*61046927SAndroid Build Coastguard Worker   .transform_offsets = ${pass_name}_transform_offsets,
1171*61046927SAndroid Build Coastguard Worker   .pass_op_table = ${pass_name}_pass_op_table,
1172*61046927SAndroid Build Coastguard Worker   .values = ${pass_name}_values,
1173*61046927SAndroid Build Coastguard Worker   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
1174*61046927SAndroid Build Coastguard Worker   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
1175*61046927SAndroid Build Coastguard Worker};
1176*61046927SAndroid Build Coastguard Worker
1177*61046927SAndroid Build Coastguard Workerbool
1178*61046927SAndroid Build Coastguard Worker${pass_name}(
1179*61046927SAndroid Build Coastguard Worker   nir_shader *shader
1180*61046927SAndroid Build Coastguard Worker% for type, name in params:
1181*61046927SAndroid Build Coastguard Worker   , ${type} ${name}
1182*61046927SAndroid Build Coastguard Worker% endfor
1183*61046927SAndroid Build Coastguard Worker) {
1184*61046927SAndroid Build Coastguard Worker   bool progress = false;
1185*61046927SAndroid Build Coastguard Worker   bool condition_flags[${len(condition_list)}];
1186*61046927SAndroid Build Coastguard Worker   const nir_shader_compiler_options *options = shader->options;
1187*61046927SAndroid Build Coastguard Worker   const shader_info *info = &shader->info;
1188*61046927SAndroid Build Coastguard Worker   (void) options;
1189*61046927SAndroid Build Coastguard Worker   (void) info;
1190*61046927SAndroid Build Coastguard Worker
1191*61046927SAndroid Build Coastguard Worker   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
1192*61046927SAndroid Build Coastguard Worker   % for index, condition in enumerate(condition_list):
1193*61046927SAndroid Build Coastguard Worker   condition_flags[${index}] = ${condition};
1194*61046927SAndroid Build Coastguard Worker   % endfor
1195*61046927SAndroid Build Coastguard Worker
1196*61046927SAndroid Build Coastguard Worker   nir_foreach_function_impl(impl, shader) {
1197*61046927SAndroid Build Coastguard Worker     progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table);
1198*61046927SAndroid Build Coastguard Worker   }
1199*61046927SAndroid Build Coastguard Worker
1200*61046927SAndroid Build Coastguard Worker   return progress;
1201*61046927SAndroid Build Coastguard Worker}
1202*61046927SAndroid Build Coastguard Worker""")
1203*61046927SAndroid Build Coastguard Worker
1204*61046927SAndroid Build Coastguard Worker
class AlgebraicPass(object):
   """A named set of algebraic search-and-replace transforms.

   Construction parses every transform, files it under its search opcode,
   checks the number of commutative expressions in its search pattern and
   finally builds a TreeAutomaton over the accepted transforms.  render()
   feeds the collected state to _algebraic_pass_template.
   """

   # params is a list of `("type", "name")` tuples
   def __init__(self, pass_name, transforms, params=None):
      """Parse and validate `transforms`, then build the automaton.

      pass_name  -- handed to the template as `pass_name`.
      transforms -- iterable whose items are either SearchAndReplace
                    objects or raw descriptions that
                    SearchAndReplace(xform, self) can parse.
      params     -- optional list of ("type", "name") tuples; omitted or
                    None means no extra parameters.

      Every problem found is reported on stderr.  If at least one
      transform failed to parse or failed the commutative-expression
      check, the process exits with status 1 once all transforms have
      been examined.
      """
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name
      # Not filled in here; render() sorts their items by value.
      # NOTE(review): presumably populated by SearchAndReplace, which
      # receives this pass object -- confirm against that class.
      self.expression_cond = {}
      self.variable_cond = {}
      # Use a fresh list per instance: a `params=[]` default would be one
      # list object shared by every pass created without params.
      self.params = [] if params is None else params

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform, self)
            except Exception:
               # Only real errors are treated as parse failures;
               # KeyboardInterrupt/SystemExit must still propagate.
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         opcode = xform.search.opcode
         if opcode in conv_opcode_types:
            # Opcodes listed in conv_opcode_types are registered once per
            # size of their destination type, under "<opcode><size>".
            dst_type = conv_opcode_types[opcode]
            for size in type_sizes(dst_type):
               self.opcode_xforms[opcode + str(size)].append(xform)
         else:
            self.opcode_xforms[opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly
         # contain more commutative expressions than match_expression
         # (nir_search.c) can handle.
         problems = self._comm_expr_diagnostics(xform,
                                                nir_search_max_comm_ops)
         for line in problems:
            print(line, file=sys.stderr)
         if problems:
            error = True

      # Bail out before building the automaton: there is no point in
      # constructing it for a transform list that is known to be bad.
      if error:
         sys.exit(1)

      self.automaton = TreeAutomaton(self.xforms)

   @staticmethod
   def _comm_expr_diagnostics(xform, max_comm_ops):
      """Return the stderr lines describing a commutative-count problem.

      The result is an empty list when `xform` is consistent:
      either its search pattern is flagged many_commutative_expressions
      and really has more than `max_comm_ops` commutative expressions, or
      it is not flagged and stays within the limit.
      """
      search = xform.search
      comm_exprs = search.comm_exprs

      if search.many_commutative_expressions:
         if comm_exprs > max_comm_ops:
            return []
         # Flagged as "many" but within the limit: the flag is stale.
         return [
            "Transform expected to have too many commutative "
            "expression but did not "
            "({} <= {}).".format(comm_exprs, max_comm_ops),
            "  " + str(xform),
            "",
         ]

      if comm_exprs <= max_comm_ops:
         return []
      return [
         "Transformation with too many commutative expressions "
         "({} > {}).  Modify pattern or annotate with "
         "\"many-comm-expr\".".format(comm_exprs, max_comm_ops),
         "  " + str(search),
         "{}".format(search.cond),
      ]

   def render(self):
      """Render _algebraic_pass_template for this pass and return the result.

      expression_cond and variable_cond are passed as lists of
      (key, value) pairs ordered by value.
      """
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools,
                                             params=self.params)
1281*61046927SAndroid Build Coastguard Worker
def ignore_exact(*expr):
   """Wrap `expr` in a SearchExpression marked with ignore_exact.

   The replacement expression isn't necessarily exact if the search
   expression is exact.  The positional arguments are handed, as one
   tuple, to SearchExpression.create(); the resulting object gets
   ignore_exact = True and is returned.
   """
   marked = SearchExpression.create(expr)
   marked.ignore_exact = True
   return marked
1287