#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import ast
from collections import defaultdict
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes, type_sizes

# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_cond_index(conds, cond):
    if cond:
        if cond in conds:
            return conds[cond]
        else:
            cond_index = len(conds)
            conds[cond] = cond_index
            return cond_index
    else:
        return -1

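# get_cond_index() behaviour sketch (the condition names below are
# illustrative, not a fixed list): it deduplicates condition strings and hands
# out stable indices in first-use order.
#    conds = {}
#    get_cond_index(conds, '(is_used_once)')    # -> 0
#    get_cond_index(conds, '(is_not_const)')    # -> 1
#    get_cond_index(conds, '(is_used_once)')    # -> 0 (already seen)
#    get_cond_index(conds, None)                # -> -1 (no condition)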
def get_c_opcode(op):
      if op in conv_opcode_types:
         return 'nir_search_op_' + op
      else:
         return 'nir_op_' + op

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   m = _type_re.match(type_str)
   assert m.group('type')

   if m.group('bits') is None:
      return 0
   else:
      return int(m.group('bits'))

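# type_bits() behaviour sketch; type strings come from nir_opcodes, e.g.:
#    type_bits('float32')   # -> 32
#    type_bits('int')       # -> 0 (unsized)
#    type_bits('bool1')     # -> 1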
# Represents a set of variables, each with a unique id
class VarSet(object):
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      self.immutable = True

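# VarSet usage sketch (hypothetical variable names): ids are assigned on first
# lookup, and lock() freezes the set so the replacement expression cannot
# introduce variables that never appeared in the search expression.
#    varset = VarSet()
#    varset['a']    # -> 0
#    varset['b']    # -> 1
#    varset['a']    # -> 0 (same id)
#    varset.lock()
#    varset['c']    # -> AssertionError: unknown replacement variable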
class SearchExpression(object):
   def __init__(self, expr):
      self.opcode = expr[0]
      self.sources = expr[1:]
      self.ignore_exact = False

   @staticmethod
   def create(val):
      if isinstance(val, tuple):
         return SearchExpression(val)
      else:
         assert(isinstance(val, SearchExpression))
         return val

   def __repr__(self):
      l = [self.opcode, *self.sources]
      if self.ignore_exact:
         l.append('ignore_exact')
      return repr((*l,))

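# SearchExpression sketch: search patterns are written as nested tuples, which
# SearchExpression.create() wraps without copying existing instances, e.g.
#    SearchExpression.create(('fadd', 'a', ('fneg', 'b')))
# yields an object with opcode 'fadd' and sources ('a', ('fneg', 'b')).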
class Value(object):
   @staticmethod
   def create(val, name_base, varset, algebraic_pass):
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple) or isinstance(val, SearchExpression):
         return Expression(val, name_base, varset, algebraic_pass)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, str):
         return Variable(val, name_base, varset, algebraic_pass)
      elif isinstance(val, (bool, float, int)):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${'true' if val.nsz else 'false'},
      ${'true' if val.nnan else 'false'},
      ${'true' if val.ninf else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

   def render(self, cache):
      struct_init = self.__template.render(val=self,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         self.array_index = cache[struct_init]
         return "   /* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         self.array_index = str(cache["next_index"])
         cache[struct_init] = self.array_index
         cache["next_index"] += 1
         return struct_init

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, (str)):
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      if isinstance(self.value, (bool)):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, int):
         # Explicitly sign-extend negative integers to 64-bit, ensuring correct
         # handling of -INT32_MIN which is not representable in 32-bit.
         if self.value < 0:
            return hex(struct.unpack('Q', struct.pack('q', self.value))[0]) + 'ull'
         else:
            return hex(self.value) + 'ull'
      elif isinstance(self.value, float):
         return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) + 'ull'
      else:
         assert False

   def type(self):
      if isinstance(self.value, (bool)):
         return "nir_type_bool"
      elif isinstance(self.value, int):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.value == other.value

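# Constant encoding sketch (values chosen for illustration):
#    Constant(1.0, 'c0').hex()    # -> '0x3ff0000000000000ull' (IEEE-754 double bits)
#    Constant(-1, 'c1').hex()     # -> '0xffffffffffffffffull' (sign-extended to 64-bit)
#    Constant('0x3@8', 'c2')      # value 3 with an explicit 8-bit size
#    Constant(True, 'c3').hex()   # -> 'NIR_TRUE'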
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")

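# A few examples of the variable syntax this regex accepts (the condition names
# are illustrative; the real ones live in nir_search_helpers.h):
#    'a'                  plain variable
#    '#b'                 variable that must be constant
#    'c@32'               variable that must be 32-bit
#    'd@uint16'           variable that must be a 16-bit unsigned integer
#    'e(is_used_once)'    variable with a match-time condition
#    'f.xxyy'             variable with an explicit swizzle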
class Variable(Value):
   def __init__(self, val, name, varset, algebraic_pass):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      if self.swiz is not None:
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

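# Examples of the opcode syntax this regex accepts (the flag/condition
# combinations below are illustrative):
#    'fadd'                    plain match of nir_op_fadd
#    '~fmul'                   inexact: the pattern is skipped for fmuls marked 'exact'
#    '!ffma'                   only matches instructions marked 'exact'
#    'iadd@32'                 only matches a 32-bit iadd
#    'fadd(many-comm-expr)'    annotated with a generator hint/condition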
class Expression(Value):
   def __init__(self, expr, name_base, varset, algebraic_pass):
      Value.__init__(self, expr, name_base, "expression")

      expr = SearchExpression.create(expr)

      m = _opcode_re.match(expr.opcode)
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.ignore_exact = expr.ignore_exact
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's a notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      # nsz, nnan and ninf are special conditions, so we treat them specially too.
      cond = {k: True for k in self.cond[1:-1].split(",")} if self.cond else {}
      self.many_commutative_expressions = cond.pop('many-comm-expr', False)
      self.nsz = cond.pop('nsz', False)
      self.nnan = cond.pop('nnan', False)
      self.ninf = cond.pop('ninf', False)

      assert len(cond) <= 1
      self.cond = cond.popitem()[0] if cond else None

      # Deduplicate references to the condition functions for the expressions
      # and save the index for the order they were added.
      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
                       for (i, src) in enumerate(expr.sources) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      srcs = "".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize with anything.
   The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32, which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)
      elif isinstance(val, Variable):
          # These catch problems when someone copies and pastes the search
          # into the replacement.
          assert not val.is_constant, \
              'Replacement variables must not be marked constant.'

          assert val.cond_index == -1, \
              'Replacement variables must not have a condition.'

          assert not val.required_type, \
              'Replacement variables must not have a required type.'

   def validate(self, search, replace):
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

class SearchAndReplace(object):
   def __init__(self, transform, algebraic_pass):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)

      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)

      BitSizeValidator(varset).validate(self.search, self.replace)

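# For reference, each transform handed to SearchAndReplace is a 2- or 3-tuple
# of (search, replace[, condition]). Two illustrative (not real) entries, with
# a hypothetical option flag as the condition:
#    (('fmul', 'a', 1.0), 'a')
#    (('fadd', ('fmul', 'a', 'b'), 'c'), ('ffma', 'a', 'b', 'c'), 'options->fuse_ffma')
# The condition string is C code evaluated once per shader; its index into
# condition_list is stored on the transform.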
class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

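   # IndexMap behaviour sketch (opcode names are just examples): add() is
   # idempotent and returns the stable index of the object.
   #    m = TreeAutomaton.IndexMap(('fadd', 'fmul'))
   #    m.add('fadd')      # -> 0 (already present)
   #    m.add('iadd')      # -> 2 (appended)
   #    m.index('fmul')    # -> 1
   #    list(m)            # -> ['fadd', 'fmul', 'iadd'], in insertion order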
   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         if commutative:
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2i1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2i1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2i, or we
               # construct it to match the search opcode, in which case we
               # need to map f2i1 to f2i when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]
            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))

            if patterns:
                # Add our patterns to the global table.
                self.state_pattern_offsets.append(len(self.state_patterns))
                self.state_patterns.extend(patterns)
                self.state_patterns.append(None)
            else:
                # Point to the initial sentinel in the global table.
                self.state_pattern_offsets.append(0)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()

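# At run time the tables emitted below are walked roughly like this (a sketch
# of the C side in nir_search.c, written here as Python-ish pseudocode):
#    filtered = [per_op[op].filter[state(src)] for src in instr.srcs]
#    new_state = per_op[op].table[flatten(filtered)]
# i.e. one filter lookup per source plus one table lookup per instruction,
# which is the "n + 1 table lookups" mentioned in the TreeAutomaton docstring.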
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union ${pass_name}_values[] = {
% for xform in xforms:
   /* ${xform.search} => ${xform.replace} */
${xform.search.render(cache)}
${xform.replace.render(cache)}
% endfor
};

% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
   ${cond[0]},
% endfor
};
% endif

% if variable_cond:
static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
   ${cond[0]},
% endfor
};
% endif

static const struct transform ${pass_name}_transforms[] = {
% for i in automaton.state_patterns:
% if i is not None:
   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
% else:
   { ~0, ~0, ~0 }, /* Sentinel */

% endif
% endfor
};

static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
% if all(e == 0 for e in automaton.filter[op]):
      .filter = NULL,
% else:
      .filter = (const uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
% endif
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (const uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

/* Mapping from state index to offset in transforms (0 being no transforms) */
static const uint16_t ${pass_name}_transform_offsets[] = {
% for offset in automaton.state_pattern_offsets:
   ${offset},
% endfor
};

static const nir_algebraic_table ${pass_name}_table = {
   .transforms = ${pass_name}_transforms,
   .transform_offsets = ${pass_name}_transform_offsets,
   .pass_op_table = ${pass_name}_pass_op_table,
   .values = ${pass_name}_values,
   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
};

bool
${pass_name}(
   nir_shader *shader
% for type, name in params:
   , ${type} ${name}
% endfor
) {
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function_impl(impl, shader) {
     progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table);
   }

   return progress;
}
""")


class AlgebraicPass(object):
   # params is a list of `("type", "name")` tuples
   def __init__(self, pass_name, transforms, params=[]):
      self.xforms = []
      self.opcode_xforms = defaultdict(lambda : [])
      self.pass_name = pass_name
      self.expression_cond = {}
      self.variable_cond = {}
      self.params = params

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform, self)
            except:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               print("Transform expected to have too many commutative " \
                     "expressions but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools,
                                             params=self.params)

# The replacement expression isn't necessarily exact if the search expression is exact.
def ignore_exact(*expr):
   expr = SearchExpression.create(expr)
   expr.ignore_exact = True
   return expr

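# Sketch of how a generator script (e.g. nir_opt_algebraic.py) typically drives
# this module; the pass name and the transform list below are illustrative:
#
#    import nir_algebraic
#
#    optimizations = [
#       (('fmul', 'a', 1.0), 'a'),
#       (('iadd', 'a', 0), 'a'),
#    ]
#    print(nir_algebraic.AlgebraicPass("nir_opt_algebraic", optimizations).render())
#
# ignore_exact() wraps a search (sub)expression so it matches regardless of the
# instruction's exact flag, e.g. ignore_exact('fadd', 'a', 'b') used in place of
# ('fadd', 'a', 'b') inside a search pattern.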