# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for proving mathematical properties of expressions."""

import fractions
import math
import operator

from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir


# Create a local alias for math.gcd with a fallback to fractions.gcd if it is
# not available.  This can be dropped if pre-3.5 Python support is dropped.
if hasattr(math, 'gcd'):
  _math_gcd = math.gcd
else:
  _math_gcd = fractions.gcd


def compute_constraints_of_expression(expression, ir):
  """Adds appropriate bounding constraints to the given expression.

  Dispatches on the kind of `expression` and fills in `expression.type` in
  place: integer bounds (minimum_value, maximum_value, modulus, modular_value)
  for integer expressions, and a constant value for boolean/enumeration
  expressions when one can be determined.  Returns immediately if
  `ir_util.is_constant_type(expression.type)` already holds (presumably the
  expression was annotated earlier).

  Arguments:
    expression: the expression to annotate; modified in place.
    ir: the complete IR, used to resolve references.
  """
  if ir_util.is_constant_type(expression.type):
    return
  expression_variety = expression.WhichOneof("expression")
  if expression_variety == "constant":
    _compute_constant_value_of_constant(expression)
  elif expression_variety == "constant_reference":
    _compute_constant_value_of_constant_reference(expression, ir)
  elif expression_variety == "function":
    _compute_constraints_of_function(expression, ir)
  elif expression_variety == "field_reference":
    _compute_constraints_of_field_reference(expression, ir)
  elif expression_variety == "builtin_reference":
    _compute_constraints_of_builtin_value(expression)
  elif expression_variety == "boolean_constant":
    _compute_constant_value_of_boolean_constant(expression)
  else:
    assert False, "Unknown expression variety {!r}".format(expression_variety)
  if expression.type.WhichOneof("type") == "integer":
    # Sanity-check that the bounds just computed are self-consistent.
    _assert_integer_constraints(expression)


def _compute_constant_value_of_constant(expression):
  """Pins the bounds of an integer literal to exactly its value."""
  value = expression.constant.value
  expression.type.integer.modular_value = value
  expression.type.integer.minimum_value = value
  expression.type.integer.maximum_value = value
  expression.type.integer.modulus = "infinity"


def _compute_constant_value_of_constant_reference(expression, ir):
  """Computes the type of a reference to an enum value or a virtual field.

  Enum value references get the enum's constant value; virtual field
  references get a copy of the field's read_transform type.
  """
  referred_object = ir_util.find_object(
      expression.constant_reference.canonical_name, ir)
  expression = ir_data_utils.builder(expression)
  if isinstance(referred_object, ir_data.EnumValue):
    compute_constraints_of_expression(referred_object.value, ir)
    assert ir_util.is_constant(referred_object.value)
    new_value = str(ir_util.constant_value(referred_object.value))
    expression.type.enumeration.value = new_value
  elif isinstance(referred_object, ir_data.Field):
    assert ir_util.field_is_virtual(referred_object), (
        "Non-virtual non-enum-value constant reference should have been caught "
        "in type_check.py")
    compute_constraints_of_expression(referred_object.read_transform, ir)
    expression.type.CopyFrom(referred_object.read_transform.type)
  else:
    assert False, "Unexpected constant reference type."


def _compute_constraints_of_function(expression, ir):
  """Computes the known constraints of the result of a function.

  All arguments are annotated first (recursively), then the result is derived
  from the argument constraints according to the operator.
  """
  for arg in expression.function.args:
    compute_constraints_of_expression(arg, ir)
  op = expression.function.function
  if op in (ir_data.FunctionMapping.ADDITION,
            ir_data.FunctionMapping.SUBTRACTION):
    _compute_constraints_of_additive_operator(expression)
  elif op == ir_data.FunctionMapping.MULTIPLICATION:
    _compute_constraints_of_multiplicative_operator(expression)
  elif op in (ir_data.FunctionMapping.EQUALITY,
              ir_data.FunctionMapping.INEQUALITY,
              ir_data.FunctionMapping.LESS,
              ir_data.FunctionMapping.LESS_OR_EQUAL,
              ir_data.FunctionMapping.GREATER,
              ir_data.FunctionMapping.GREATER_OR_EQUAL,
              ir_data.FunctionMapping.AND,
              ir_data.FunctionMapping.OR):
    _compute_constant_value_of_comparison_operator(expression)
  elif op == ir_data.FunctionMapping.CHOICE:
    _compute_constraints_of_choice_operator(expression)
  elif op == ir_data.FunctionMapping.MAXIMUM:
    _compute_constraints_of_maximum_function(expression)
  elif op == ir_data.FunctionMapping.PRESENCE:
    _compute_constraints_of_existence_function(expression, ir)
  elif op in (ir_data.FunctionMapping.UPPER_BOUND,
              ir_data.FunctionMapping.LOWER_BOUND):
    _compute_constraints_of_bound_function(expression)
  else:
    assert False, "Unknown operator {!r}".format(op)


def _compute_constraints_of_existence_function(expression, ir):
  """Computes the constraints of a $has(field) expression.

  The result takes the type of the referenced field's existence_condition.
  """
  field_path = expression.function.args[0].field_reference.path[-1]
  field = ir_util.find_object(field_path, ir)
  compute_constraints_of_expression(field.existence_condition, ir)
  ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type)


def _compute_constraints_of_field_reference(expression, ir):
  """Computes the constraints of a reference to a structure's field."""
  field_path = expression.field_reference.path[-1]
  field = ir_util.find_object(field_path, ir)
  if isinstance(field, ir_data.Field) and ir_util.field_is_virtual(field):
    # References to virtual fields should have the virtual field's constraints
    # copied over.
    compute_constraints_of_expression(field.read_transform, ir)
    ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
    return
  # Non-virtual non-integer fields do not (yet) have constraints.
  if expression.type.WhichOneof("type") == "integer":
    # TODO(bolms): These lines will need to change when support is added for
    # fixed-point types.
    expression.type.integer.modulus = "1"
    expression.type.integer.modular_value = "0"
    type_definition = ir_util.find_parent_object(field_path, ir)
    if isinstance(field, ir_data.Field):
      referent_type = field.type
    else:
      # Not an ir_data.Field: read the type from physical_type_alias instead
      # (e.g. a runtime parameter).
      referent_type = field.physical_type_alias
    if referent_type.HasField("size_in_bits"):
      type_size = ir_util.constant_value(referent_type.size_in_bits)
    else:
      # Fall back to the field's location size, scaled by the enclosing
      # type's addressable unit.
      field_size = ir_util.constant_value(field.location.size)
      if field_size is None:
        type_size = None
      else:
        type_size = field_size * type_definition.addressable_unit
    assert referent_type.HasField("atomic_type"), field
    assert not referent_type.atomic_type.reference.canonical_name.module_file
    _set_integer_constraints_from_physical_type(
        expression, referent_type, type_size)


def _set_integer_constraints_from_physical_type(
    expression, physical_type, type_size):
  """Copies the integer constraints of an expression from a physical type.

  Arguments:
    expression: an object with `type.integer`; its minimum_value and
      maximum_value are set in place.
    physical_type: the type whose atomic_type names the Prelude integer type
      (UInt, Int, or Bcd).
    type_size: the size of the type in bits, or None if unknown.
  """
  # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes
  # the ranges for all of the Emboss Prelude integer types.  This would break
  # any user-defined `external` integer types, but that feature isn't fully
  # implemented in the C++ backend, so it doesn't matter for now.
  #
  # Adding the attribute(s) for integer bounds will require new operators:
  # integer/flooring division, remainder, and exponentiation (2**N, 10**N).
  #
  # (Technically, there are a few sets of operators that would work: for
  # example, just the choice operator `?:` is sufficient, but very ugly.
  # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits
  # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) *
  # 2**($bits % 4) - 1`, in my (bolms@) opinion.)
  #
  # TODO(bolms): Add a scheme for defining integer bounds on user-defined
  # external types.
  if type_size is None:
    # If the type_size is unknown, then we can't actually say anything about the
    # minimum and maximum values of the type.  For UInt, Int, and Bcd, an error
    # will be thrown during the constraints check stage.
    expression.type.integer.minimum_value = "-infinity"
    expression.type.integer.maximum_value = "infinity"
    return
  name = tuple(physical_type.atomic_type.reference.canonical_name.object_path)
  if name == ("UInt",):
    expression.type.integer.minimum_value = "0"
    expression.type.integer.maximum_value = str(2**type_size - 1)
  elif name == ("Int",):
    expression.type.integer.minimum_value = str(-(2**(type_size - 1)))
    expression.type.integer.maximum_value = str(2**(type_size - 1) - 1)
  elif name == ("Bcd",):
    # Each full nibble holds one decimal digit; leftover bits (type_size % 4)
    # hold a partial digit.
    expression.type.integer.minimum_value = "0"
    expression.type.integer.maximum_value = str(
        10**(type_size // 4) * 2**(type_size % 4) - 1)
  else:
    assert False, "Unknown integral type " + ".".join(name)


def _compute_constraints_of_parameter(parameter):
  """Sets the integer bounds of a runtime parameter from its physical type."""
  if parameter.type.WhichOneof("type") == "integer":
    type_size = ir_util.constant_value(
        parameter.physical_type_alias.size_in_bits)
    _set_integer_constraints_from_physical_type(
        parameter, parameter.physical_type_alias, type_size)


def _compute_constraints_of_builtin_value(expression):
  """Computes the constraints of a builtin (like $static_size_in_bits)."""
  name = expression.builtin_reference.canonical_name.object_path[0]
  if name == "$static_size_in_bits":
    expression.type.integer.modulus = "1"
    expression.type.integer.modular_value = "0"
    expression.type.integer.minimum_value = "0"
    # The maximum theoretically-supported size of something is 2**64 bytes,
    # which is 2**64 * 8 bits.
    #
    # Really, $static_size_in_bits is only valid in expressions that have to be
    # evaluated at compile time anyway, so it doesn't really matter if the
    # bounds are excessive.
    expression.type.integer.maximum_value = "infinity"
  elif name == "$is_statically_sized":
    # No bounds on a boolean variable.
    pass
  elif name == "$logical_value":
    # $logical_value is the placeholder used in inferred write-through
    # transformations.
    #
    # Only integers (currently) have "real" write-through transformations, but
    # fields that would otherwise be straight aliases, but which have a
    # [requires] attribute, are elevated to write-through fields, so that the
    # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
    # Read, and Ok.
    #
    # The type is expected to have been filled in already; only check it here.
    if expression.type.WhichOneof("type") == "integer":
      assert expression.type.integer.modulus
      assert expression.type.integer.modular_value
      assert expression.type.integer.minimum_value
      assert expression.type.integer.maximum_value
    elif expression.type.WhichOneof("type") == "enumeration":
      assert expression.type.enumeration.name
    elif expression.type.WhichOneof("type") == "boolean":
      pass
    else:
      assert False, "Unexpected type for $logical_value"
  else:
    assert False, "Unknown builtin " + name


def _compute_constant_value_of_boolean_constant(expression):
  """Copies a boolean literal's value into its type."""
  expression.type.boolean.value = expression.boolean_constant.value


def _add(a, b):
  """Adds a and b, where a and b are ints, "infinity", or "-infinity"."""
  if a in ("infinity", "-infinity"):
    a, b = b, a
  if b == "infinity":
    # infinity + -infinity is undefined.
    assert a != "-infinity"
    return "infinity"
  if b == "-infinity":
    assert a != "infinity"
    return "-infinity"
  return int(a) + int(b)


def _sub(a, b):
  """Subtracts b from a, where a and b are ints, "infinity", or "-infinity"."""
  if b == "infinity":
    return _add(a, "-infinity")
  if b == "-infinity":
    return _add(a, "infinity")
  return _add(a, -int(b))


def _sign(a):
  """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0."""
  if a == "infinity":
    return 1
  if a == "-infinity":
    return -1
  if int(a) > 0:
    return 1
  if int(a) < 0:
    return -1
  return 0


def _mul(a, b):
  """Multiplies a and b, where a and b are ints, "infinity", or "-infinity".

  An infinite factor times zero yields 0.
  """
  if _is_infinite(a):
    a, b = b, a
  if _is_infinite(b):
    sign = _sign(a) * _sign(b)
    if sign > 0:
      return "infinity"
    if sign < 0:
      return "-infinity"
    return 0
  return int(a) * int(b)


def _is_infinite(a):
  """Returns True if a is "infinity" or "-infinity"."""
  return a in ("infinity", "-infinity")


def _max(a):
  """Returns max of a, where elements are ints, "infinity", or "-infinity"."""
  if any(n == "infinity" for n in a):
    return "infinity"
  if all(n == "-infinity" for n in a):
    return "-infinity"
  return max(int(n) for n in a if not _is_infinite(n))


def _min(a):
  """Returns min of a, where elements are ints, "infinity", or "-infinity"."""
  if any(n == "-infinity" for n in a):
    return "-infinity"
  if all(n == "infinity" for n in a):
    return "infinity"
  return min(int(n) for n in a if not _is_infinite(n))


def _compute_constraints_of_additive_operator(expression):
  """Computes the bounds and modular value of an addition or subtraction."""
  funcs = {
      ir_data.FunctionMapping.ADDITION: _add,
      ir_data.FunctionMapping.SUBTRACTION: _sub,
  }
  func = funcs[expression.function.function]
  args = expression.function.args
  for arg in args:
    assert arg.type.integer.modular_value, str(expression)
  left, right = args
  unadjusted_modular_value = func(left.type.integer.modular_value,
                                  right.type.integer.modular_value)
  new_modulus = _greatest_common_divisor(left.type.integer.modulus,
                                         right.type.integer.modulus)
  expression.type.integer.modulus = str(new_modulus)
  if new_modulus == "infinity":
    expression.type.integer.modular_value = str(unadjusted_modular_value)
  else:
    expression.type.integer.modular_value = str(unadjusted_modular_value %
                                                new_modulus)
  lmax = left.type.integer.maximum_value
  lmin = left.type.integer.minimum_value
  if expression.function.function == ir_data.FunctionMapping.SUBTRACTION:
    # `l - r` is smallest when r is largest and largest when r is smallest, so
    # the right-hand bounds swap roles.
    rmax = right.type.integer.minimum_value
    rmin = right.type.integer.maximum_value
  else:
    rmax = right.type.integer.maximum_value
    rmin = right.type.integer.minimum_value
  expression.type.integer.minimum_value = str(func(lmin, rmin))
  expression.type.integer.maximum_value = str(func(lmax, rmax))


def _compute_constraints_of_multiplicative_operator(expression):
  """Computes the bounds and modular value of a multiplicative expression."""
  bounds = [arg.type.integer for arg in expression.function.args]

  # The minimum and maximum values can come from any of the four pairings of
  # (left min, left max) with (right min, right max), depending on the signs and
  # magnitudes of the minima and maxima.  E.g.:
  #
  #   max = left max * right max: [ 2,  3] * [ 2,  3]
  #   max = left min * right min: [-3, -2] * [-3, -2]
  #   max = left max * right min: [-3, -2] * [ 2,  3]
  #   max = left min * right max: [ 2,  3] * [-3, -2]
  #   max = left max * right max: [-2,  3] * [-2,  3]
  #   max = left min * right min: [-3,  2] * [-3,  2]
  #
  # For uncorrelated multiplication, the minimum and maximum will always come
  # from multiplying one extreme by another: if x is nonzero, then
  #
  #   (y + e) * x > y * x || (y - e) * x > y * x
  #
  # for arbitrary nonzero e, so the extrema can only occur when we either cannot
  # add or cannot subtract e.
  #
  # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but
  # Emboss is not currently trying to be that smart.
  lmin, lmax = bounds[0].minimum_value, bounds[0].maximum_value
  rmin, rmax = bounds[1].minimum_value, bounds[1].maximum_value
  extrema = [_mul(lmax, rmax), _mul(lmin, rmax),
             _mul(lmax, rmin), _mul(lmin, rmin)]
  expression.type.integer.minimum_value = str(_min(extrema))
  expression.type.integer.maximum_value = str(_max(extrema))

  if all(bound.modulus == "infinity" for bound in bounds):
    # If both sides are constant, the result is constant.
    expression.type.integer.modulus = "infinity"
    expression.type.integer.modular_value = str(int(bounds[0].modular_value) *
                                                int(bounds[1].modular_value))
    return

  if any(bound.modulus == "infinity" for bound in bounds):
    # If one side is constant and the other is not, then the non-constant
    # modulus and modular_value can both be multiplied by the constant.  E.g.,
    # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod
    # 20:
    #
    #   a = ...  |  4 * a = ...  |  4 * a mod 20 = ...
    #   3        |  12           |  12
    #   8        |  32           |  12
    #   13       |  52           |  12
    #   18       |  72           |  12
    #   23       |  92           |  12
    #   28       |  112          |  12
    #   33       |  132          |  12
    #
    # This is trivially shown by noting that the difference between consecutive
    # possible values for `4 * a` always differ by 20.
    if bounds[0].modulus == "infinity":
      constant, variable = bounds
    else:
      variable, constant = bounds
    if int(constant.modular_value) == 0:
      # If the constant is 0, the result is 0, no matter what the variable side
      # is.
      expression.type.integer.modulus = "infinity"
      expression.type.integer.modular_value = "0"
      return
    new_modulus = int(variable.modulus) * abs(int(constant.modular_value))
    expression.type.integer.modulus = str(new_modulus)
    # The `% new_modulus` will force the `modular_value` to be positive, even
    # when `constant.modular_value` is negative.
    expression.type.integer.modular_value = str(
        int(variable.modular_value) * int(constant.modular_value) % new_modulus)
    return

  # If neither side is constant, then the result is more complex.  Full proof is
  # available in g3doc/modular_congruence_multiplication_proof.md
  #
  # Essentially, if:
  #
  #   l == _ * l_mod + l_mv
  #   r == _ * r_mod + r_mv
  #
  # Then we find l_mod0 and r_mod0 in:
  #
  #   l == (_ * l_mod_nz + l_mv_nz) * l_mod0
  #   r == (_ * r_mod_nz + r_mv_nz) * r_mod0
  #
  # And finally conclude:
  #
  #   l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv
  product_of_zero_congruence_moduli = 1
  product_of_modular_values = 1
  nonzero_congruence_moduli = []
  for bound in bounds:
    zero_congruence_modulus = _greatest_common_divisor(bound.modulus,
                                                       bound.modular_value)
    assert int(bound.modulus) % zero_congruence_modulus == 0
    product_of_zero_congruence_moduli *= zero_congruence_modulus
    product_of_modular_values *= int(bound.modular_value)
    nonzero_congruence_moduli.append(int(bound.modulus) //
                                     zero_congruence_modulus)
  shared_nonzero_congruence_modulus = _greatest_common_divisor(
      nonzero_congruence_moduli[0], nonzero_congruence_moduli[1])
  final_modulus = (shared_nonzero_congruence_modulus *
                   product_of_zero_congruence_moduli)
  expression.type.integer.modulus = str(final_modulus)
  expression.type.integer.modular_value = str(product_of_modular_values %
                                              final_modulus)


def _assert_integer_constraints(expression):
  """Asserts that the integer bounds of expression are self-consistent.

  Asserts that `minimum_value` and `maximum_value` are congruent to
  `modular_value` modulo `modulus`.

  If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and
  `modular_value` are all equal.

  If `minimum_value` is equal to `maximum_value`, asserts that `modular_value`
  is equal to both, and that `modulus` is "infinity".

  Arguments:
    expression: an expression with type.integer

  Returns:
    None
  """
  bounds = expression.type.integer
  if bounds.modulus == "infinity":
    assert bounds.minimum_value == bounds.modular_value
    assert bounds.maximum_value == bounds.modular_value
    return
  modulus = int(bounds.modulus)
  assert modulus > 0
  if bounds.minimum_value != "-infinity":
    assert int(bounds.minimum_value) % modulus == int(bounds.modular_value)
  if bounds.maximum_value != "infinity":
    assert int(bounds.maximum_value) % modulus == int(bounds.modular_value)
  if bounds.minimum_value == bounds.maximum_value:
    # TODO(bolms): I believe there are situations using the not-yet-implemented
    # integer division operator that would trigger these asserts, so they should
    # be turned into assignments (with corresponding tests) when implementing
    # division.
    assert bounds.modular_value == bounds.minimum_value
    assert bounds.modulus == "infinity"
  if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity":
    assert int(bounds.minimum_value) <= int(bounds.maximum_value)


def _compute_constant_value_of_comparison_operator(expression):
  """Computes the constant value, if any, of a comparison operator.

  A value is set only when every argument is constant.  Otherwise the type is
  left without a value.
  """
  args = expression.function.args
  if all(ir_util.is_constant(arg) for arg in args):
    functions = {
        ir_data.FunctionMapping.EQUALITY: operator.eq,
        ir_data.FunctionMapping.INEQUALITY: operator.ne,
        ir_data.FunctionMapping.LESS: operator.lt,
        ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le,
        ir_data.FunctionMapping.GREATER: operator.gt,
        ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge,
        ir_data.FunctionMapping.AND: operator.and_,
        ir_data.FunctionMapping.OR: operator.or_,
    }
    func = functions[expression.function.function]
    expression.type.boolean.value = func(
        *[ir_util.constant_value(arg) for arg in args])


def _compute_constraints_of_bound_function(expression):
  """Computes the constraints of $upper_bound or $lower_bound.

  The result is a constant equal to the argument's maximum (for $upper_bound)
  or minimum (for $lower_bound) value.
  """
  if expression.function.function == ir_data.FunctionMapping.UPPER_BOUND:
    value = expression.function.args[0].type.integer.maximum_value
  elif expression.function.function == ir_data.FunctionMapping.LOWER_BOUND:
    value = expression.function.args[0].type.integer.minimum_value
  else:
    assert False, "Non-bound function"
  expression.type.integer.minimum_value = value
  expression.type.integer.maximum_value = value
  expression.type.integer.modular_value = value
  expression.type.integer.modulus = "infinity"


def _compute_constraints_of_maximum_function(expression):
  """Computes the constraints of the $max function."""
  assert expression.type.WhichOneof("type") == "integer"
  args = expression.function.args
  assert args[0].type.WhichOneof("type") == "integer"
  # The minimum value of the result occurs when every argument takes its minimum
  # value, which means that the minimum result is the maximum-of-minimums.
  expression.type.integer.minimum_value = str(_max(
      [arg.type.integer.minimum_value for arg in args]))
  # The maximum result is the maximum-of-maximums.
  expression.type.integer.maximum_value = str(_max(
      [arg.type.integer.maximum_value for arg in args]))
  # If the expression is dominated by a constant factor, then the result is
  # constant.  I (bolms@) believe this is the only case where
  # _compute_constraints_of_maximum_function might violate the assertions in
  # _assert_integer_constraints.
  if (expression.type.integer.minimum_value ==
      expression.type.integer.maximum_value):
    expression.type.integer.modular_value = (
        expression.type.integer.minimum_value)
    expression.type.integer.modulus = "infinity"
    return
  result_modulus = args[0].type.integer.modulus
  result_modular_value = args[0].type.integer.modular_value
  # The result of $max(a, b) could be either a or b, which means that the result
  # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
  # choice operator '?:'.
  #
  # This also takes advantage of the fact that $max(a, b, c, d, ...) is
  # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
  # call _shared_modular_value() in a loop.
  for arg in args[1:]:
    # TODO(bolms): I think the bounds could be tighter in some cases where
    # arg.maximum_value is less than the new expression.minimum_value, and
    # in some very specific cases where arg.maximum_value is greater than the
    # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
    # than expression.minimum_value.
    result_modulus, result_modular_value = _shared_modular_value(
        (result_modulus, result_modular_value),
        (arg.type.integer.modulus, arg.type.integer.modular_value))
  expression.type.integer.modulus = str(result_modulus)
  expression.type.integer.modular_value = str(result_modular_value)


def _shared_modular_value(left, right):
  """Returns the shared modulus and modular value of left and right.

  Arguments:
    left: A tuple of (modulus, modular value)
    right: A tuple of (modulus, modular value)

  Returns:
    A tuple of (modulus, modular_value) such that:

    left.modulus % result.modulus == 0
    right.modulus % result.modulus == 0
    left.modular_value % result.modulus == result.modular_value
    right.modular_value % result.modulus == result.modular_value

    That is, the result.modulus and result.modular_value will be compatible
    with, but (possibly) less restrictive than both left.(modulus,
    modular_value) and right.(modulus, modular_value).
  """
  left_modulus, left_modular_value = left
  right_modulus, right_modular_value = right
  # The combined modulus is gcd(gcd(left_modulus, right_modulus),
  # left_modular_value - right_modular_value).
  #
  # The inner gcd normalizes the left_modulus and right_modulus, but can leave
  # incompatible modular_values.  The outer gcd finds a modulus to which both
  # modular_values are congruent.  Some examples:
  #
  #   left          | right          | res
  #   --------------+----------------+--------------------
  #   l % 12 == 7   | r % 20 == 15   | res % 4 == 3
  #   l == 35       | r % 20 == 15   | res % 20 == 15
  #   l % 24 == 15  | r % 12 == 7    | res % 4 == 3
  #   l % 20 == 15  | r % 20 == 10   | res % 5 == 0
  #   l % 20 == 16  | r % 20 == 11   | res % 5 == 1
  #   l == 10       | r == 7         | res % 3 == 1
  #   l == 4        | r == 4         | res == 4
  #
  # The cases where one side or the other are constant are handled
  # automatically by the fact that _greatest_common_divisor("infinity", x)
  # is x.
  common_modulus = _greatest_common_divisor(left_modulus, right_modulus)
  new_modulus = _greatest_common_divisor(
      common_modulus, abs(int(left_modular_value) - int(right_modular_value)))
  if new_modulus == "infinity":
    # The only way for the new_modulus to come out as "infinity" *should* be
    # if both left and right have the same constant value.
    assert left_modular_value == right_modular_value
    assert left_modulus == right_modulus == "infinity"
    return new_modulus, left_modular_value
  else:
    assert (int(left_modular_value) % new_modulus ==
            int(right_modular_value) % new_modulus)
    return new_modulus, int(left_modular_value) % new_modulus


def _compute_constraints_of_choice_operator(expression):
  """Computes the constraints of a choice operation '?:'."""
  condition, if_true, if_false = ir_data_utils.reader(expression).function.args
  expression = ir_data_utils.builder(expression)
  if condition.type.boolean.HasField("value"):
    # The generated expressions for $size_in_bits and $size_in_bytes look like
    #
    #     $max((field1_existence_condition ? field1_start + field1_size : 0),
    #          (field2_existence_condition ? field2_start + field2_size : 0),
    #          (field3_existence_condition ? field3_start + field3_size : 0),
    #          ...)
    #
    # Since most existence_conditions are just "true", it is important to select
    # the tighter bounds in those cases -- otherwise, only zero-length
    # structures could have a constant $size_in_bits or $size_in_bytes.
    side = if_true if condition.type.boolean.value else if_false
    expression.type.CopyFrom(side.type)
    return
  # The type.integer minimum_value/maximum_value bounding code is needed since
  # constraints.check_constraints() will complain if minimum and maximum are not
  # set correctly.  I'm (bolms@) not sure if the modulus/modular_value pulls its
  # weight, but for completeness I've left it in.
  if if_true.type.WhichOneof("type") == "integer":
    # The minimum value of the choice is the minimum value of either side, and
    # the maximum is the maximum value of either side.
    expression.type.integer.minimum_value = str(_min([
        if_true.type.integer.minimum_value,
        if_false.type.integer.minimum_value]))
    expression.type.integer.maximum_value = str(_max([
        if_true.type.integer.maximum_value,
        if_false.type.integer.maximum_value]))
    new_modulus, new_modular_value = _shared_modular_value(
        (if_true.type.integer.modulus, if_true.type.integer.modular_value),
        (if_false.type.integer.modulus, if_false.type.integer.modular_value))
    expression.type.integer.modulus = str(new_modulus)
    expression.type.integer.modular_value = str(new_modular_value)
  else:
    assert if_true.type.WhichOneof("type") in ("boolean", "enumeration"), (
        "Unknown type {} for expression".format(
            if_true.type.WhichOneof("type")))


def _greatest_common_divisor(a, b):
  """Returns the greatest common divisor of a and b.

  Arguments:
    a: an integer, a stringified integer, or the string "infinity"
    b: an integer, a stringified integer, or the string "infinity"

  Returns:
    Conceptually, "infinity" is treated as the product of all integers.

    If both a and b are 0, returns "infinity".

    Otherwise, if either a or b are "infinity", and the other is 0, returns
    "infinity".

    Otherwise, if either a or b are "infinity", returns the other.

    Otherwise, returns the greatest common divisor of a and b.
  """
  if a != "infinity":
    a = int(a)
  if b != "infinity":
    b = int(b)
  assert a == "infinity" or a >= 0
  assert b == "infinity" or b >= 0
  if a == b == 0:
    return "infinity"
  # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0.
  if a == 0:
    return b
  if b == 0:
    return a
  if a == "infinity":
    return b
  if b == "infinity":
    return a
  return _math_gcd(a, b)


def compute_constants(ir):
  """Computes constant values for all expressions in ir.

  compute_constants calculates all constant values and adds them to the type
  information for each expression and subexpression.  It also fills in the
  integer bounds of runtime parameters.

  Arguments:
    ir: an IR on which to compute constants; modified in place.

  Returns:
    A (possibly empty) list of errors.  The current implementation always
    returns an empty list; inconsistencies surface as assertion failures.
  """
  # compute_constraints_of_expression recurses into function arguments itself,
  # so nested expressions are not visited separately.
  traverse_ir.fast_traverse_ir_top_down(
      ir, [ir_data.Expression], compute_constraints_of_expression,
      skip_descendants_of={ir_data.Expression})
  traverse_ir.fast_traverse_ir_top_down(
      ir, [ir_data.RuntimeParameter], _compute_constraints_of_parameter,
      skip_descendants_of={ir_data.Expression})
  return []