1# 2# Copyright (C) 2014 Intel Corporation 3# 4# Permission is hereby granted, free of charge, to any person obtaining a 5# copy of this software and associated documentation files (the "Software"), 6# to deal in the Software without restriction, including without limitation 7# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8# and/or sell copies of the Software, and to permit persons to whom the 9# Software is furnished to do so, subject to the following conditions: 10# 11# The above copyright notice and this permission notice (including the next 12# paragraph) shall be included in all copies or substantial portions of the 13# Software. 14# 15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21# IN THE SOFTWARE. 22 23import ast 24from collections import defaultdict 25import itertools 26import struct 27import sys 28import mako.template 29import re 30import traceback 31 32from nir_opcodes import opcodes, type_sizes 33 34# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 35nir_search_max_comm_ops = 8 36 37# These opcodes are only employed by nir_search. This provides a mapping from 38# opcode to destination type. 
39conv_opcode_types = { 40 'i2f' : 'float', 41 'u2f' : 'float', 42 'f2f' : 'float', 43 'f2u' : 'uint', 44 'f2i' : 'int', 45 'u2u' : 'uint', 46 'i2i' : 'int', 47 'b2f' : 'float', 48 'b2i' : 'int', 49 'i2b' : 'bool', 50 'f2b' : 'bool', 51} 52 53def get_cond_index(conds, cond): 54 if cond: 55 if cond in conds: 56 return conds[cond] 57 else: 58 cond_index = len(conds) 59 conds[cond] = cond_index 60 return cond_index 61 else: 62 return -1 63 64def get_c_opcode(op): 65 if op in conv_opcode_types: 66 return 'nir_search_op_' + op 67 else: 68 return 'nir_op_' + op 69 70_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?") 71 72def type_bits(type_str): 73 m = _type_re.match(type_str) 74 assert m.group('type') 75 76 if m.group('bits') is None: 77 return 0 78 else: 79 return int(m.group('bits')) 80 81# Represents a set of variables, each with a unique id 82class VarSet(object): 83 def __init__(self): 84 self.names = {} 85 self.ids = itertools.count() 86 self.immutable = False; 87 88 def __getitem__(self, name): 89 if name not in self.names: 90 assert not self.immutable, "Unknown replacement variable: " + name 91 self.names[name] = next(self.ids) 92 93 return self.names[name] 94 95 def lock(self): 96 self.immutable = True 97 98class SearchExpression(object): 99 def __init__(self, expr): 100 self.opcode = expr[0] 101 self.sources = expr[1:] 102 self.ignore_exact = False 103 104 @staticmethod 105 def create(val): 106 if isinstance(val, tuple): 107 return SearchExpression(val) 108 else: 109 assert(isinstance(val, SearchExpression)) 110 return val 111 112 def __repr__(self): 113 l = [self.opcode, *self.sources] 114 if self.ignore_exact: 115 l.append('ignore_exact') 116 return repr((*l,)) 117 118class Value(object): 119 @staticmethod 120 def create(val, name_base, varset, algebraic_pass): 121 if isinstance(val, bytes): 122 val = val.decode('utf-8') 123 124 if isinstance(val, tuple) or isinstance(val, SearchExpression): 125 return Expression(val, name_base, varset, 
algebraic_pass) 126 elif isinstance(val, Expression): 127 return val 128 elif isinstance(val, str): 129 return Variable(val, name_base, varset, algebraic_pass) 130 elif isinstance(val, (bool, float, int)): 131 return Constant(val, name_base) 132 133 def __init__(self, val, name, type_str): 134 self.in_val = str(val) 135 self.name = name 136 self.type_str = type_str 137 138 def __str__(self): 139 return self.in_val 140 141 def get_bit_size(self): 142 """Get the physical bit-size that has been chosen for this value, or if 143 there is none, the canonical value which currently represents this 144 bit-size class. Variables will be preferred, i.e. if there are any 145 variables in the equivalence class, the canonical value will be a 146 variable. We do this since we'll need to know which variable each value 147 is equivalent to when constructing the replacement expression. This is 148 the "find" part of the union-find algorithm. 149 """ 150 bit_size = self 151 152 while isinstance(bit_size, Value): 153 if bit_size._bit_size is None: 154 break 155 bit_size = bit_size._bit_size 156 157 if bit_size is not self: 158 self._bit_size = bit_size 159 return bit_size 160 161 def set_bit_size(self, other): 162 """Make self.get_bit_size() return what other.get_bit_size() return 163 before calling this, or just "other" if it's a concrete bit-size. This is 164 the "union" part of the union-find algorithm. 
165 """ 166 167 self_bit_size = self.get_bit_size() 168 other_bit_size = other if isinstance(other, int) else other.get_bit_size() 169 170 if self_bit_size == other_bit_size: 171 return 172 173 self_bit_size._bit_size = other_bit_size 174 175 @property 176 def type_enum(self): 177 return "nir_search_value_" + self.type_str 178 179 @property 180 def c_bit_size(self): 181 bit_size = self.get_bit_size() 182 if isinstance(bit_size, int): 183 return bit_size 184 elif isinstance(bit_size, Variable): 185 return -bit_size.index - 1 186 else: 187 # If the bit-size class is neither a variable, nor an actual bit-size, then 188 # - If it's in the search expression, we don't need to check anything 189 # - If it's in the replace expression, either it's ambiguous (in which 190 # case we'd reject it), or it equals the bit-size of the search value 191 # We represent these cases with a 0 bit-size. 192 return 0 193 194 __template = mako.template.Template(""" { .${val.type_str} = { 195 { ${val.type_enum}, ${val.c_bit_size} }, 196% if isinstance(val, Constant): 197 ${val.type()}, { ${val.hex()} /* ${val.value} */ }, 198% elif isinstance(val, Variable): 199 ${val.index}, /* ${val.var_name} */ 200 ${'true' if val.is_constant else 'false'}, 201 ${val.type() or 'nir_type_invalid' }, 202 ${val.cond_index}, 203 ${val.swizzle()}, 204% elif isinstance(val, Expression): 205 ${'true' if val.inexact else 'false'}, 206 ${'true' if val.exact else 'false'}, 207 ${'true' if val.ignore_exact else 'false'}, 208 ${'true' if val.nsz else 'false'}, 209 ${'true' if val.nnan else 'false'}, 210 ${'true' if val.ninf else 'false'}, 211 ${val.c_opcode()}, 212 ${val.comm_expr_idx}, ${val.comm_exprs}, 213 { ${', '.join(src.array_index for src in val.sources)} }, 214 ${val.cond_index}, 215% endif 216 } }, 217""") 218 219 def render(self, cache): 220 struct_init = self.__template.render(val=self, 221 Constant=Constant, 222 Variable=Variable, 223 Expression=Expression) 224 if struct_init in cache: 225 # If it's in 
the cache, register a name remap in the cache and render 226 # only a comment saying it's been remapped 227 self.array_index = cache[struct_init] 228 return " /* {} -> {} in the cache */\n".format(self.name, 229 cache[struct_init]) 230 else: 231 self.array_index = str(cache["next_index"]) 232 cache[struct_init] = self.array_index 233 cache["next_index"] += 1 234 return struct_init 235 236_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?") 237 238class Constant(Value): 239 def __init__(self, val, name): 240 Value.__init__(self, val, name, "constant") 241 242 if isinstance(val, (str)): 243 m = _constant_re.match(val) 244 self.value = ast.literal_eval(m.group('value')) 245 self._bit_size = int(m.group('bits')) if m.group('bits') else None 246 else: 247 self.value = val 248 self._bit_size = None 249 250 if isinstance(self.value, bool): 251 assert self._bit_size is None or self._bit_size == 1 252 self._bit_size = 1 253 254 def hex(self): 255 if isinstance(self.value, (bool)): 256 return 'NIR_TRUE' if self.value else 'NIR_FALSE' 257 if isinstance(self.value, int): 258 # Explicitly sign-extend negative integers to 64-bit, ensuring correct 259 # handling of -INT32_MIN which is not representable in 32-bit. 260 if self.value < 0: 261 return hex(struct.unpack('Q', struct.pack('q', self.value))[0]) + 'ull' 262 else: 263 return hex(self.value) + 'ull' 264 elif isinstance(self.value, float): 265 return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) + 'ull' 266 else: 267 assert False 268 269 def type(self): 270 if isinstance(self.value, (bool)): 271 return "nir_type_bool" 272 elif isinstance(self.value, int): 273 return "nir_type_int" 274 elif isinstance(self.value, float): 275 return "nir_type_float" 276 277 def equivalent(self, other): 278 """Check that two constants are equivalent. 279 280 This is check is much weaker than equality. One generally cannot be 281 used in place of the other. 
Using this implementation for the __eq__ 282 will break BitSizeValidator. 283 284 """ 285 if not isinstance(other, type(self)): 286 return False 287 288 return self.value == other.value 289 290# The $ at the end forces there to be an error if any part of the string 291# doesn't match one of the field patterns. 292_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" 293 r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" 294 r"(?P<cond>\([^\)]+\))?" 295 r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?" 296 r"$") 297 298class Variable(Value): 299 def __init__(self, val, name, varset, algebraic_pass): 300 Value.__init__(self, val, name, "variable") 301 302 m = _var_name_re.match(val) 303 assert m and m.group('name') is not None, \ 304 "Malformed variable name \"{}\".".format(val) 305 306 self.var_name = m.group('name') 307 308 # Prevent common cases where someone puts quotes around a literal 309 # constant. If we want to support names that have numeric or 310 # punctuation characters, we can me the first assertion more flexible. 
311 assert self.var_name.isalpha() 312 assert self.var_name != 'True' 313 assert self.var_name != 'False' 314 315 self.is_constant = m.group('const') is not None 316 self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond')) 317 self.required_type = m.group('type') 318 self._bit_size = int(m.group('bits')) if m.group('bits') else None 319 self.swiz = m.group('swiz') 320 321 if self.required_type == 'bool': 322 if self._bit_size is not None: 323 assert self._bit_size in type_sizes(self.required_type) 324 else: 325 self._bit_size = 1 326 327 if self.required_type is not None: 328 assert self.required_type in ('float', 'bool', 'int', 'uint') 329 330 self.index = varset[self.var_name] 331 332 def type(self): 333 if self.required_type == 'bool': 334 return "nir_type_bool" 335 elif self.required_type in ('int', 'uint'): 336 return "nir_type_int" 337 elif self.required_type == 'float': 338 return "nir_type_float" 339 340 def equivalent(self, other): 341 """Check that two variables are equivalent. 342 343 This is check is much weaker than equality. One generally cannot be 344 used in place of the other. Using this implementation for the __eq__ 345 will break BitSizeValidator. 346 347 """ 348 if not isinstance(other, type(self)): 349 return False 350 351 return self.index == other.index 352 353 def swizzle(self): 354 if self.swiz is not None: 355 swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3, 356 'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3, 357 'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7, 358 'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11, 359 'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 } 360 return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}' 361 return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}' 362 363_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" 
364 r"(?P<cond>\([^\)]+\))?") 365 366class Expression(Value): 367 def __init__(self, expr, name_base, varset, algebraic_pass): 368 Value.__init__(self, expr, name_base, "expression") 369 370 expr = SearchExpression.create(expr) 371 372 m = _opcode_re.match(expr.opcode) 373 assert m and m.group('opcode') is not None 374 375 self.opcode = m.group('opcode') 376 self._bit_size = int(m.group('bits')) if m.group('bits') else None 377 self.inexact = m.group('inexact') is not None 378 self.exact = m.group('exact') is not None 379 self.ignore_exact = expr.ignore_exact 380 self.cond = m.group('cond') 381 382 assert not self.inexact or not self.exact, \ 383 'Expression cannot be both exact and inexact.' 384 385 # "many-comm-expr" isn't really a condition. It's notification to the 386 # generator that this pattern is known to have too many commutative 387 # expressions, and an error should not be generated for this case. 388 # nsz, nnan and ninf are special conditions, so we treat them specially too. 389 cond = {k: True for k in self.cond[1:-1].split(",")} if self.cond else {} 390 self.many_commutative_expressions = cond.pop('many-comm-expr', False) 391 self.nsz = cond.pop('nsz', False) 392 self.nnan = cond.pop('nnan', False) 393 self.ninf = cond.pop('ninf', False) 394 395 assert len(cond) <= 1 396 self.cond = cond.popitem()[0] if cond else None 397 398 # Deduplicate references to the condition functions for the expressions 399 # and save the index for the order they were added. 400 self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond) 401 402 self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass) 403 for (i, src) in enumerate(expr.sources) ] 404 405 # nir_search_expression::srcs is hard-coded to 4 406 assert len(self.sources) <= 4 407 408 if self.opcode in conv_opcode_types: 409 assert self._bit_size is None, \ 410 'Expression cannot use an unsized conversion opcode with ' \ 411 'an explicit size; that\'s silly.' 
412 413 self.__index_comm_exprs(0) 414 415 def equivalent(self, other): 416 """Check that two variables are equivalent. 417 418 This is check is much weaker than equality. One generally cannot be 419 used in place of the other. Using this implementation for the __eq__ 420 will break BitSizeValidator. 421 422 This implementation does not check for equivalence due to commutativity, 423 but it could. 424 425 """ 426 if not isinstance(other, type(self)): 427 return False 428 429 if len(self.sources) != len(other.sources): 430 return False 431 432 if self.opcode != other.opcode: 433 return False 434 435 return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) 436 437 def __index_comm_exprs(self, base_idx): 438 """Recursively count and index commutative expressions 439 """ 440 self.comm_exprs = 0 441 442 # A note about the explicit "len(self.sources)" check. The list of 443 # sources comes from user input, and that input might be bad. Check 444 # that the expected second source exists before accessing it. Without 445 # this check, a unit test that does "('iadd', 'a')" will crash. 446 if self.opcode not in conv_opcode_types and \ 447 "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ 448 len(self.sources) >= 2 and \ 449 not self.sources[0].equivalent(self.sources[1]): 450 self.comm_expr_idx = base_idx 451 self.comm_exprs += 1 452 else: 453 self.comm_expr_idx = -1 454 455 for s in self.sources: 456 if isinstance(s, Expression): 457 s.__index_comm_exprs(base_idx + self.comm_exprs) 458 self.comm_exprs += s.comm_exprs 459 460 return self.comm_exprs 461 462 def c_opcode(self): 463 return get_c_opcode(self.opcode) 464 465 def render(self, cache): 466 srcs = "".join(src.render(cache) for src in self.sources) 467 return srcs + super(Expression, self).render(cache) 468 469class BitSizeValidator(object): 470 """A class for validating bit sizes of expressions. 
471 472 NIR supports multiple bit-sizes on expressions in order to handle things 473 such as fp64. The source and destination of every ALU operation is 474 assigned a type and that type may or may not specify a bit size. Sources 475 and destinations whose type does not specify a bit size are considered 476 "unsized" and automatically take on the bit size of the corresponding 477 register or SSA value. NIR has two simple rules for bit sizes that are 478 validated by nir_validator: 479 480 1) A given SSA def or register has a single bit size that is respected by 481 everything that reads from it or writes to it. 482 483 2) The bit sizes of all unsized inputs/outputs on any given ALU 484 instruction must match. They need not match the sized inputs or 485 outputs but they must match each other. 486 487 In order to keep nir_algebraic relatively simple and easy-to-use, 488 nir_search supports a type of bit-size inference based on the two rules 489 above. This is similar to type inference in many common programming 490 languages. If, for instance, you are constructing an add operation and you 491 know the second source is 16-bit, then you know that the other source and 492 the destination must also be 16-bit. There are, however, cases where this 493 inference can be ambiguous or contradictory. Consider, for instance, the 494 following transformation: 495 496 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 497 498 This transformation can potentially cause a problem because usub_borrow is 499 well-defined for any bit-size of integer. However, b2i always generates a 500 32-bit result so it could end up replacing a 64-bit expression with one 501 that takes two 64-bit values and produces a 32-bit value. As another 502 example, consider this expression: 503 504 (('bcsel', a, b, 0), ('iand', a, b)) 505 506 In this case, in the search expression a must be 32-bit but b can 507 potentially have any bit size. 
If we had a 64-bit b value, we would end up 508 trying to and a 32-bit value with a 64-bit value which would be invalid 509 510 This class solves that problem by providing a validation layer that proves 511 that a given search-and-replace operation is 100% well-defined before we 512 generate any code. This ensures that bugs are caught at compile time 513 rather than at run time. 514 515 Each value maintains a "bit-size class", which is either an actual bit size 516 or an equivalence class with other values that must have the same bit size. 517 The validator works by combining bit-size classes with each other according 518 to the NIR rules outlined above, checking that there are no inconsistencies. 519 When doing this for the replacement expression, we make sure to never change 520 the equivalence class of any of the search values. We could make the example 521 transforms above work by doing some extra run-time checking of the search 522 expression, but we make the user specify those constraints themselves, to 523 avoid any surprises. Since the replacement bitsizes can only be connected to 524 the source bitsize via variables (variables must have the same bitsize in 525 the source and replacment expressions) or the roots of the expression (the 526 replacement expression must produce the same bit size as the search 527 expression), we prevent merging a variable with anything when processing the 528 replacement expression, or specializing the search bitsize 529 with anything. The former prevents 530 531 (('bcsel', a, b, 0), ('iand', a, b)) 532 533 from being allowed, since we'd have to merge the bitsizes for a and b due to 534 the 'iand', while the latter prevents 535 536 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 537 538 from being allowed, since the search expression has the bit size of a and b, 539 which can't be specialized to 32 which is the bitsize of the replace 540 expression. 
It also prevents something like: 541 542 (('b2i', ('i2b', a)), ('ineq', a, 0)) 543 544 since the bitsize of 'b2i', which can be anything, can't be specialized to 545 the bitsize of a. 546 547 After doing all this, we check that every subexpression of the replacement 548 was assigned a constant bitsize, the bitsize of a variable, or the bitsize 549 of the search expresssion, since those are the things that are known when 550 constructing the replacement expresssion. Finally, we record the bitsize 551 needed in nir_search_value so that we know what to do when building the 552 replacement expression. 553 """ 554 555 def __init__(self, varset): 556 self._var_classes = [None] * len(varset.names) 557 558 def compare_bitsizes(self, a, b): 559 """Determines which bitsize class is a specialization of the other, or 560 whether neither is. When we merge two different bitsizes, the 561 less-specialized bitsize always points to the more-specialized one, so 562 that calling get_bit_size() always gets you the most specialized bitsize. 563 The specialization partial order is given by: 564 - Physical bitsizes are always the most specialized, and a different 565 bitsize can never specialize another. 566 - In the search expression, variables can always be specialized to each 567 other and to physical bitsizes. In the replace expression, we disallow 568 this to avoid adding extra constraints to the search expression that 569 the user didn't specify. 570 - Expressions and constants without a bitsize can always be specialized to 571 each other and variables, but not the other way around. 572 573 We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, 574 and None if they are not comparable (neither a <= b nor b <= a). 
575 """ 576 if isinstance(a, int): 577 if isinstance(b, int): 578 return 0 if a == b else None 579 elif isinstance(b, Variable): 580 return -1 if self.is_search else None 581 else: 582 return -1 583 elif isinstance(a, Variable): 584 if isinstance(b, int): 585 return 1 if self.is_search else None 586 elif isinstance(b, Variable): 587 return 0 if self.is_search or a.index == b.index else None 588 else: 589 return -1 590 else: 591 if isinstance(b, int): 592 return 1 593 elif isinstance(b, Variable): 594 return 1 595 else: 596 return 0 597 598 def unify_bit_size(self, a, b, error_msg): 599 """Record that a must have the same bit-size as b. If both 600 have been assigned conflicting physical bit-sizes, call "error_msg" with 601 the bit-sizes of self and other to get a message and raise an error. 602 In the replace expression, disallow merging variables with other 603 variables and physical bit-sizes as well. 604 """ 605 a_bit_size = a.get_bit_size() 606 b_bit_size = b if isinstance(b, int) else b.get_bit_size() 607 608 cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) 609 610 assert cmp_result is not None, \ 611 error_msg(a_bit_size, b_bit_size) 612 613 if cmp_result < 0: 614 b_bit_size.set_bit_size(a) 615 elif not isinstance(a_bit_size, int): 616 a_bit_size.set_bit_size(b) 617 618 def merge_variables(self, val): 619 """Perform the first part of type inference by merging all the different 620 uses of the same variable. We always do this as if we're in the search 621 expression, even if we're actually not, since otherwise we'd get errors 622 if the search expression specified some constraint but the replace 623 expression didn't, because we'd be merging a variable and a constant. 
624 """ 625 if isinstance(val, Variable): 626 if self._var_classes[val.index] is None: 627 self._var_classes[val.index] = val 628 else: 629 other = self._var_classes[val.index] 630 self.unify_bit_size(other, val, 631 lambda other_bit_size, bit_size: 632 'Variable {} has conflicting bit size requirements: ' \ 633 'it must have bit size {} and {}'.format( 634 val.var_name, other_bit_size, bit_size)) 635 elif isinstance(val, Expression): 636 for src in val.sources: 637 self.merge_variables(src) 638 639 def validate_value(self, val): 640 """Validate the an expression by performing classic Hindley-Milner 641 type inference on bitsizes. This will detect if there are any conflicting 642 requirements, and unify variables so that we know which variables must 643 have the same bitsize. If we're operating on the replace expression, we 644 will refuse to merge different variables together or merge a variable 645 with a constant, in order to prevent surprises due to rules unexpectedly 646 not matching at runtime. 647 """ 648 if not isinstance(val, Expression): 649 return 650 651 # Generic conversion ops are special in that they have a single unsized 652 # source and an unsized destination and the two don't have to match. 653 # This means there's no validation or unioning to do here besides the 654 # len(val.sources) check. 655 if val.opcode in conv_opcode_types: 656 assert len(val.sources) == 1, \ 657 "Expression {} has {} sources, expected 1".format( 658 val, len(val.sources)) 659 self.validate_value(val.sources[0]) 660 return 661 662 nir_op = opcodes[val.opcode] 663 assert len(val.sources) == nir_op.num_inputs, \ 664 "Expression {} has {} sources, expected {}".format( 665 val, len(val.sources), nir_op.num_inputs) 666 667 for src in val.sources: 668 self.validate_value(src) 669 670 dst_type_bits = type_bits(nir_op.output_type) 671 672 # First, unify all the sources. 
That way, an error coming up because two 673 # sources have an incompatible bit-size won't produce an error message 674 # involving the destination. 675 first_unsized_src = None 676 for src_type, src in zip(nir_op.input_types, val.sources): 677 src_type_bits = type_bits(src_type) 678 if src_type_bits == 0: 679 if first_unsized_src is None: 680 first_unsized_src = src 681 continue 682 683 if self.is_search: 684 self.unify_bit_size(first_unsized_src, src, 685 lambda first_unsized_src_bit_size, src_bit_size: 686 'Source {} of {} must have bit size {}, while source {} ' \ 687 'must have incompatible bit size {}'.format( 688 first_unsized_src, val, first_unsized_src_bit_size, 689 src, src_bit_size)) 690 else: 691 self.unify_bit_size(first_unsized_src, src, 692 lambda first_unsized_src_bit_size, src_bit_size: 693 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ 694 'of {} may not have the same bit size when building the ' \ 695 'replacement expression.'.format( 696 first_unsized_src, first_unsized_src_bit_size, src, 697 src_bit_size, val)) 698 else: 699 if self.is_search: 700 self.unify_bit_size(src, src_type_bits, 701 lambda src_bit_size, unused: 702 '{} must have {} bits, but as a source of nir_op_{} '\ 703 'it must have {} bits'.format( 704 src, src_bit_size, nir_op.name, src_type_bits)) 705 else: 706 self.unify_bit_size(src, src_type_bits, 707 lambda src_bit_size, unused: 708 '{} has the bit size of {}, but as a source of ' \ 709 'nir_op_{} it must have {} bits, which may not be the ' \ 710 'same'.format( 711 src, src_bit_size, nir_op.name, src_type_bits)) 712 713 if dst_type_bits == 0: 714 if first_unsized_src is not None: 715 if self.is_search: 716 self.unify_bit_size(val, first_unsized_src, 717 lambda val_bit_size, src_bit_size: 718 '{} must have the bit size of {}, while its source {} ' \ 719 'must have incompatible bit size {}'.format( 720 val, val_bit_size, first_unsized_src, src_bit_size)) 721 else: 722 self.unify_bit_size(val, first_unsized_src, 723 
lambda val_bit_size, src_bit_size: 724 '{} must have {} bits, but its source {} ' \ 725 '(bit size of {}) may not have that bit size ' \ 726 'when building the replacement.'.format( 727 val, val_bit_size, first_unsized_src, src_bit_size)) 728 else: 729 self.unify_bit_size(val, dst_type_bits, 730 lambda dst_bit_size, unused: 731 '{} must have {} bits, but as a destination of nir_op_{} ' \ 732 'it must have {} bits'.format( 733 val, dst_bit_size, nir_op.name, dst_type_bits)) 734 735 def validate_replace(self, val, search): 736 bit_size = val.get_bit_size() 737 assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ 738 bit_size == search.get_bit_size(), \ 739 'Ambiguous bit size for replacement value {}: ' \ 740 'it cannot be deduced from a variable, a fixed bit size ' \ 741 'somewhere, or the search expression.'.format(val) 742 743 if isinstance(val, Expression): 744 for src in val.sources: 745 self.validate_replace(src, search) 746 elif isinstance(val, Variable): 747 # These catch problems when someone copies and pastes the search 748 # into the replacement. 749 assert not val.is_constant, \ 750 'Replacement variables must not be marked constant.' 751 752 assert val.cond_index == -1, \ 753 'Replacement variables must not have a condition.' 754 755 assert not val.required_type, \ 756 'Replacement variables must not have a required type.' 757 758 def validate(self, search, replace): 759 self.is_search = True 760 self.merge_variables(search) 761 self.merge_variables(replace) 762 self.validate_value(search) 763 764 self.is_search = False 765 self.validate_value(replace) 766 767 # Check that search is always more specialized than replace. Note that 768 # we're doing this in replace mode, disallowing merging variables. 
769 search_bit_size = search.get_bit_size() 770 replace_bit_size = replace.get_bit_size() 771 cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) 772 773 assert cmp_result is not None and cmp_result <= 0, \ 774 'The search expression bit size {} and replace expression ' \ 775 'bit size {} may not be the same'.format( 776 search_bit_size, replace_bit_size) 777 778 replace.set_bit_size(search) 779 780 self.validate_replace(replace, search) 781 782_optimization_ids = itertools.count() 783 784condition_list = ['true'] 785 786class SearchAndReplace(object): 787 def __init__(self, transform, algebraic_pass): 788 self.id = next(_optimization_ids) 789 790 search = transform[0] 791 replace = transform[1] 792 if len(transform) > 2: 793 self.condition = transform[2] 794 else: 795 self.condition = 'true' 796 797 if self.condition not in condition_list: 798 condition_list.append(self.condition) 799 self.condition_index = condition_list.index(self.condition) 800 801 varset = VarSet() 802 if isinstance(search, Expression): 803 self.search = search 804 else: 805 self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass) 806 807 varset.lock() 808 809 if isinstance(replace, Value): 810 self.replace = replace 811 else: 812 self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass) 813 814 BitSizeValidator(varset).validate(self.search, self.replace) 815 816class TreeAutomaton(object): 817 """This class calculates a bottom-up tree automaton to quickly search for 818 the left-hand sides of tranforms. Tree automatons are a generalization of 819 classical NFA's and DFA's, where the transition function determines the 820 state of the parent node based on the state of its children. We construct a 821 deterministic automaton to match patterns, using a similar algorithm to the 822 classical NFA to DFA construction. 
At the moment, it only matches opcodes 823 and constants (without checking the actual value), leaving more detailed 824 checking to the search function which actually checks the leaves. The 825 automaton acts as a quick filter for the search function, requiring only n 826 + 1 table lookups for each n-source operation. The implementation is based 827 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." 828 In the language of that reference, this is a frontier-to-root deterministic 829 automaton using only symbol filtering. The filtering is crucial to reduce 830 both the time taken to generate the tables and the size of the tables. 831 """ 832 def __init__(self, transforms): 833 self.patterns = [t.search for t in transforms] 834 self._compute_items() 835 self._build_table() 836 #print('num items: {}'.format(len(set(self.items.values())))) 837 #print('num states: {}'.format(len(self.states))) 838 #for state, patterns in zip(self.states, self.patterns): 839 # print('{}: num patterns: {}'.format(state, len(patterns))) 840 841 class IndexMap(object): 842 """An indexed list of objects, where one can either lookup an object by 843 index or find the index associated to an object quickly using a hash 844 table. Compared to a list, it has a constant time index(). Compared to a 845 set, it provides a stable iteration order. 
846 """ 847 def __init__(self, iterable=()): 848 self.objects = [] 849 self.map = {} 850 for obj in iterable: 851 self.add(obj) 852 853 def __getitem__(self, i): 854 return self.objects[i] 855 856 def __contains__(self, obj): 857 return obj in self.map 858 859 def __len__(self): 860 return len(self.objects) 861 862 def __iter__(self): 863 return iter(self.objects) 864 865 def clear(self): 866 self.objects = [] 867 self.map.clear() 868 869 def index(self, obj): 870 return self.map[obj] 871 872 def add(self, obj): 873 if obj in self.map: 874 return self.map[obj] 875 else: 876 index = len(self.objects) 877 self.objects.append(obj) 878 self.map[obj] = index 879 return index 880 881 def __repr__(self): 882 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 883 884 class Item(object): 885 """This represents an "item" in the language of "Tree Automatons." This 886 is just a subtree of some pattern, which represents a potential partial 887 match at runtime. We deduplicate them, so that identical subtrees of 888 different patterns share the same object, and store some extra 889 information needed for the main algorithm as well. 890 """ 891 def __init__(self, opcode, children): 892 self.opcode = opcode 893 self.children = children 894 # These are the indices of patterns for which this item is the root node. 895 self.patterns = [] 896 # This the set of opcodes for parents of this item. Used to speed up 897 # filtering. 898 self.parent_ops = set() 899 900 def __str__(self): 901 return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 902 903 def __repr__(self): 904 return str(self) 905 906 def _compute_items(self): 907 """Build a set of all possible items, deduplicating them.""" 908 # This is a map from (opcode, sources) to item. 909 self.items = {} 910 911 # The set of all opcodes used by the patterns. Used later to avoid 912 # building and emitting all the tables for opcodes that aren't used. 
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         # An opcode is commutative for our purposes when its first two
         # sources may be swapped.
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         if commutative:
            # Alias the swapped-source key to the same Item so that either
            # source order reaches the same state.
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2i1, are tricky. Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2i1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2i, or we
               # construct it to match the search opcode, in which case we
               # need to map f2i1 to f2i when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is
      # part of the worklist of newly created states. There is also a
      # worklist of newly filtered states for each opcode, for which
      # worklist_indices serves a similar purpose. worklist_index corresponds
      # to p in the original algorithm, while worklist_indices is p_{a,j}
      # (although since we only filter by opcode/symbol, it's really just
      # p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))

            if patterns:
               # Add our patterns to the global table.
               self.state_pattern_offsets.append(len(self.state_patterns))
               self.state_patterns.extend(patterns)
               self.state_patterns.append(None)
            else:
               # Point to the initial sentinel in the global table.
               self.state_pattern_offsets.append(0)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               # Filt_{op}: keep only the items that can appear somewhere as
               # a source of 'op' in a pattern.
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            # Conversion search opcodes (f2i etc.) always take one source.
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()

_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *   ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union ${pass_name}_values[] = {
% for xform in xforms:
   /* ${xform.search} => ${xform.replace} */
${xform.search.render(cache)}
${xform.replace.render(cache)}
% endfor
};

% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
   ${cond[0]},
% endfor
};
% endif

% if variable_cond:
static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
   ${cond[0]},
% endfor
};
% endif

static const struct transform ${pass_name}_transforms[] = {
% for i in automaton.state_patterns:
% if i is not None:
   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
% else:
   { ~0, ~0, ~0 }, /* Sentinel */

% endif
% endfor
};

static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
% if all(e == 0 for e in automaton.filter[op]):
      .filter = NULL,
% else:
      .filter = (const uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
% endif
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (const uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

/* Mapping from state index to offset in transforms (0 being no transforms) */
static const uint16_t ${pass_name}_transform_offsets[] = {
% for offset in automaton.state_pattern_offsets:
   ${offset},
% endfor
};

static const nir_algebraic_table ${pass_name}_table = {
   .transforms = ${pass_name}_transforms,
   .transform_offsets = ${pass_name}_transform_offsets,
   .pass_op_table = ${pass_name}_pass_op_table,
   .values = ${pass_name}_values,
   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
};

bool
${pass_name}(
   nir_shader *shader
% for type, name in params:
   , ${type} ${name}
% endfor
) {
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function_impl(impl, shader) {
      progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table);
   }

   return progress;
}
""")


class AlgebraicPass(object):
   # params is a list of `("type", "name")` tuples
   def __init__(self, pass_name, transforms, params=[]):
      # NOTE(review): the mutable default for 'params' is only ever read
      # (assigned to self.params below), so the shared-default pitfall does
      # not bite here.
      self.xforms = []
      self.opcode_xforms = defaultdict(lambda : [])
      self.pass_name = pass_name
self.expression_cond = {} 1212 self.variable_cond = {} 1213 self.params = params 1214 1215 error = False 1216 1217 for xform in transforms: 1218 if not isinstance(xform, SearchAndReplace): 1219 try: 1220 xform = SearchAndReplace(xform, self) 1221 except: 1222 print("Failed to parse transformation:", file=sys.stderr) 1223 print(" " + str(xform), file=sys.stderr) 1224 traceback.print_exc(file=sys.stderr) 1225 print('', file=sys.stderr) 1226 error = True 1227 continue 1228 1229 self.xforms.append(xform) 1230 if xform.search.opcode in conv_opcode_types: 1231 dst_type = conv_opcode_types[xform.search.opcode] 1232 for size in type_sizes(dst_type): 1233 sized_opcode = xform.search.opcode + str(size) 1234 self.opcode_xforms[sized_opcode].append(xform) 1235 else: 1236 self.opcode_xforms[xform.search.opcode].append(xform) 1237 1238 # Check to make sure the search pattern does not unexpectedly contain 1239 # more commutative expressions than match_expression (nir_search.c) 1240 # can handle. 1241 comm_exprs = xform.search.comm_exprs 1242 1243 if xform.search.many_commutative_expressions: 1244 if comm_exprs <= nir_search_max_comm_ops: 1245 print("Transform expected to have too many commutative " \ 1246 "expression but did not " \ 1247 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1248 file=sys.stderr) 1249 print(" " + str(xform), file=sys.stderr) 1250 traceback.print_exc(file=sys.stderr) 1251 print('', file=sys.stderr) 1252 error = True 1253 else: 1254 if comm_exprs > nir_search_max_comm_ops: 1255 print("Transformation with too many commutative expressions " \ 1256 "({} > {}). 
Modify pattern or annotate with " \ 1257 "\"many-comm-expr\".".format(comm_exprs, 1258 nir_search_max_comm_ops), 1259 file=sys.stderr) 1260 print(" " + str(xform.search), file=sys.stderr) 1261 print("{}".format(xform.search.cond), file=sys.stderr) 1262 error = True 1263 1264 self.automaton = TreeAutomaton(self.xforms) 1265 1266 if error: 1267 sys.exit(1) 1268 1269 1270 def render(self): 1271 return _algebraic_pass_template.render(pass_name=self.pass_name, 1272 xforms=self.xforms, 1273 opcode_xforms=self.opcode_xforms, 1274 condition_list=condition_list, 1275 automaton=self.automaton, 1276 expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1277 variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1278 get_c_opcode=get_c_opcode, 1279 itertools=itertools, 1280 params=self.params) 1281 1282# The replacement expression isn't necessarily exact if the search expression is exact. 1283def ignore_exact(*expr): 1284 expr = SearchExpression.create(expr) 1285 expr.ignore_exact = True 1286 return expr 1287