xref: /aosp_15_r20/external/emboss/compiler/front_end/expression_bounds.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Functions for proving mathematical properties of expressions."""
16
17import math
18import fractions
19import operator
20
21from compiler.util import ir_data
22from compiler.util import ir_data_utils
23from compiler.util import ir_util
24from compiler.util import traverse_ir
25
26
27# Create a local alias for math.gcd with a fallback to fractions.gcd if it is
28# not available. This can be dropped if pre-3.5 Python support is dropped.
29if hasattr(math, 'gcd'):
30  _math_gcd = math.gcd
31else:
32  _math_gcd = fractions.gcd
33
34
def compute_constraints_of_expression(expression, ir):
  """Adds appropriate bounding constraints to the given expression.

  Nothing is done when `ir_util.is_constant_type()` already reports the
  expression's type as constant.  Otherwise the populated "expression" variant
  selects the helper that fills in `expression.type`, and integer-typed results
  are sanity-checked with `_assert_integer_constraints()`.

  Arguments:
      expression: the expression to annotate; its `type` is updated in place.
      ir: the IR containing `expression`; forwarded to the helpers that need to
          look up other objects.

  Returns:
      None
  """
  if ir_util.is_constant_type(expression.type):
    return
  variety = expression.WhichOneof("expression")
  # Helpers that have to resolve names against the rest of the IR.
  handlers_needing_ir = {
      "constant_reference": _compute_constant_value_of_constant_reference,
      "function": _compute_constraints_of_function,
      "field_reference": _compute_constraints_of_field_reference,
  }
  # Helpers that only look at the expression itself.
  standalone_handlers = {
      "constant": _compute_constant_value_of_constant,
      "builtin_reference": _compute_constraints_of_builtin_value,
      "boolean_constant": _compute_constant_value_of_boolean_constant,
  }
  if variety in handlers_needing_ir:
    handlers_needing_ir[variety](expression, ir)
  elif variety in standalone_handlers:
    standalone_handlers[variety](expression)
  else:
    assert False, "Unknown expression variety {!r}".format(variety)
  if expression.type.WhichOneof("type") == "integer":
    _assert_integer_constraints(expression)
56
57
58def _compute_constant_value_of_constant(expression):
59  value = expression.constant.value
60  expression.type.integer.modular_value = value
61  expression.type.integer.minimum_value = value
62  expression.type.integer.maximum_value = value
63  expression.type.integer.modulus = "infinity"
64
65
def _compute_constant_value_of_constant_reference(expression, ir):
  """Computes the type information of a reference to a named constant.

  The reference is resolved with `ir_util.find_object()`.  For an enum value,
  the value's own expression is evaluated and its constant value is stored, as
  a string, in `expression.type.enumeration.value`.  For a (virtual) field, the
  field's `read_transform` is evaluated and its type is copied onto
  `expression`.

  Arguments:
      expression: the constant_reference expression to annotate in place.
      ir: the IR in which the referenced object is looked up.
  """
  target = ir_util.find_object(
      expression.constant_reference.canonical_name, ir)
  writable = ir_data_utils.builder(expression)
  if isinstance(target, ir_data.EnumValue):
    enum_expression = target.value
    compute_constraints_of_expression(enum_expression, ir)
    assert ir_util.is_constant(enum_expression)
    writable.type.enumeration.value = str(
        ir_util.constant_value(enum_expression))
    return
  if isinstance(target, ir_data.Field):
    assert ir_util.field_is_virtual(target), (
        "Non-virtual non-enum-value constant reference should have been caught "
        "in type_check.py")
    transform = target.read_transform
    compute_constraints_of_expression(transform, ir)
    writable.type.CopyFrom(transform.type)
    return
  assert False, "Unexpected constant reference type."
83
84
def _compute_constraints_of_function(expression, ir):
  """Computes the known constraints of the result of a function.

  The constraints of every argument are computed first; the function's own
  operator then selects the helper that derives the result's constraints.

  Arguments:
      expression: a function expression; its `type` is updated in place.
      ir: the IR containing `expression`.
  """
  for argument in expression.function.args:
    compute_constraints_of_expression(argument, ir)
  mapping = ir_data.FunctionMapping
  operation = expression.function.function
  if operation == mapping.PRESENCE:
    # $has() is the only function whose helper needs to consult the IR.
    _compute_constraints_of_existence_function(expression, ir)
    return
  dispatch = (
      ((mapping.ADDITION, mapping.SUBTRACTION),
       _compute_constraints_of_additive_operator),
      ((mapping.MULTIPLICATION,),
       _compute_constraints_of_multiplicative_operator),
      ((mapping.EQUALITY, mapping.INEQUALITY, mapping.LESS,
        mapping.LESS_OR_EQUAL, mapping.GREATER, mapping.GREATER_OR_EQUAL,
        mapping.AND, mapping.OR),
       _compute_constant_value_of_comparison_operator),
      ((mapping.CHOICE,), _compute_constraints_of_choice_operator),
      ((mapping.MAXIMUM,), _compute_constraints_of_maximum_function),
      ((mapping.UPPER_BOUND, mapping.LOWER_BOUND),
       _compute_constraints_of_bound_function),
  )
  for operations, handler in dispatch:
    if operation in operations:
      handler(expression)
      return
  assert False, "Unknown operator {!r}".format(operation)
109
110
def _compute_constraints_of_existence_function(expression, ir):
  """Computes the constraints of a $has(field) expression.

  The result takes the type (and therefore any known constant value) of the
  referenced field's `existence_condition`.

  Arguments:
      expression: the $has() expression; its `type` is overwritten in place.
      ir: the IR in which the referenced field is looked up.
  """
  queried_reference = expression.function.args[0].field_reference
  queried_field = ir_util.find_object(queried_reference.path[-1], ir)
  condition = queried_field.existence_condition
  compute_constraints_of_expression(condition, ir)
  ir_data_utils.builder(expression).type.CopyFrom(condition.type)
117
118
def _compute_constraints_of_field_reference(expression, ir):
  """Computes the constraints of a reference to a structure's field.

  A reference to a virtual field copies the type of the field's
  `read_transform`.  Any other integer-typed reference gets modulus 1 and
  modular value 0, and its range is derived from the physical type and size of
  the referenced object.  Other references are left untouched.

  Arguments:
      expression: the field_reference expression to annotate in place.
      ir: the IR in which the referenced object is looked up.
  """
  target_name = expression.field_reference.path[-1]
  target = ir_util.find_object(target_name, ir)
  target_is_field = isinstance(target, ir_data.Field)
  if target_is_field and ir_util.field_is_virtual(target):
    # A virtual field is defined by its read_transform, so the reference
    # inherits that expression's constraints wholesale.
    compute_constraints_of_expression(target.read_transform, ir)
    ir_data_utils.builder(expression).type.CopyFrom(target.read_transform.type)
    return
  # Non-virtual non-integer fields do not (yet) have constraints.
  if expression.type.WhichOneof("type") != "integer":
    return
  # TODO(bolms): These lines will need to change when support is added for
  # fixed-point types.
  expression.type.integer.modulus = "1"
  expression.type.integer.modular_value = "0"
  enclosing_type = ir_util.find_parent_object(target_name, ir)
  # Anything that is not a Field is expected to carry its physical type in
  # `physical_type_alias` instead of `type`.
  referent_type = target.type if target_is_field else target.physical_type_alias
  if referent_type.HasField("size_in_bits"):
    bit_count = ir_util.constant_value(referent_type.size_in_bits)
  else:
    # The location size is scaled by the enclosing type's addressable unit;
    # it stays None when the size is not a compile-time constant.
    unit_count = ir_util.constant_value(target.location.size)
    if unit_count is None:
      bit_count = None
    else:
      bit_count = unit_count * enclosing_type.addressable_unit
  assert referent_type.HasField("atomic_type"), target
  assert not referent_type.atomic_type.reference.canonical_name.module_file
  _set_integer_constraints_from_physical_type(
      expression, referent_type, bit_count)
152
153
154def _set_integer_constraints_from_physical_type(
155    expression, physical_type, type_size):
156  """Copies the integer constraints of an expression from a physical type."""
157  # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes
158  # the ranges for all of the Emboss Prelude integer types.   This would break
159  # any user-defined `external` integer types, but that feature isn't fully
160  # implemented in the C++ backend, so it doesn't matter for now.
161  #
162  # Adding the attribute(s) for integer bounds will require new operators:
163  # integer/flooring division, remainder, and exponentiation (2**N, 10**N).
164  #
165  # (Technically, there are a few sets of operators that would work: for
166  # example, just the choice operator `?:` is sufficient, but very ugly.
167  # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits
168  # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) *
169  # 2**($bits % 4) - 1`, in my (bolms@) opinion.)
170  #
171  # TODO(bolms): Add a scheme for defining integer bounds on user-defined
172  # external types.
173  if type_size is None:
174    # If the type_size is unknown, then we can't actually say anything about the
175    # minimum and maximum values of the type.  For UInt, Int, and Bcd, an error
176    # will be thrown during the constraints check stage.
177    expression.type.integer.minimum_value = "-infinity"
178    expression.type.integer.maximum_value = "infinity"
179    return
180  name = tuple(physical_type.atomic_type.reference.canonical_name.object_path)
181  if name == ("UInt",):
182    expression.type.integer.minimum_value = "0"
183    expression.type.integer.maximum_value = str(2**type_size - 1)
184  elif name == ("Int",):
185    expression.type.integer.minimum_value = str(-(2**(type_size - 1)))
186    expression.type.integer.maximum_value = str(2**(type_size - 1) - 1)
187  elif name == ("Bcd",):
188    expression.type.integer.minimum_value = "0"
189    expression.type.integer.maximum_value = str(
190        10**(type_size // 4) * 2**(type_size % 4) - 1)
191  else:
192    assert False, "Unknown integral type " + ".".join(name)
193
194
def _compute_constraints_of_parameter(parameter):
  """Sets the integer range of an integer-typed parameter.

  The range is taken from the parameter's `physical_type_alias` and the
  constant value of that type's `size_in_bits`.  Parameters of any other type
  are left untouched.

  Arguments:
      parameter: the parameter whose `type.integer` bounds are set in place.
  """
  if parameter.type.WhichOneof("type") != "integer":
    return
  bit_count = ir_util.constant_value(
      parameter.physical_type_alias.size_in_bits)
  _set_integer_constraints_from_physical_type(
      parameter, parameter.physical_type_alias, bit_count)
201
202
203def _compute_constraints_of_builtin_value(expression):
204  """Computes the constraints of a builtin (like $static_size_in_bits)."""
205  name = expression.builtin_reference.canonical_name.object_path[0]
206  if name == "$static_size_in_bits":
207    expression.type.integer.modulus = "1"
208    expression.type.integer.modular_value = "0"
209    expression.type.integer.minimum_value = "0"
210    # The maximum theoretically-supported size of something is 2**64 bytes,
211    # which is 2**64 * 8 bits.
212    #
213    # Really, $static_size_in_bits is only valid in expressions that have to be
214    # evaluated at compile time anyway, so it doesn't really matter if the
215    # bounds are excessive.
216    expression.type.integer.maximum_value = "infinity"
217  elif name == "$is_statically_sized":
218    # No bounds on a boolean variable.
219    pass
220  elif name == "$logical_value":
221    # $logical_value is the placeholder used in inferred write-through
222    # transformations.
223    #
224    # Only integers (currently) have "real" write-through transformations, but
225    # fields that would otherwise be straight aliases, but which have a
226    # [requires] attribute, are elevated to write-through fields, so that the
227    # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
228    # Read, and Ok.
229    if expression.type.WhichOneof("type") == "integer":
230      assert expression.type.integer.modulus
231      assert expression.type.integer.modular_value
232      assert expression.type.integer.minimum_value
233      assert expression.type.integer.maximum_value
234    elif expression.type.WhichOneof("type") == "enumeration":
235      assert expression.type.enumeration.name
236    elif expression.type.WhichOneof("type") == "boolean":
237      pass
238    else:
239      assert False, "Unexpected type for $logical_value"
240  else:
241    assert False, "Unknown builtin " + name
242
243
244def _compute_constant_value_of_boolean_constant(expression):
245  expression.type.boolean.value = expression.boolean_constant.value
246
247
248def _add(a, b):
249  """Adds a and b, where a and b are ints, "infinity", or "-infinity"."""
250  if a in ("infinity", "-infinity"):
251    a, b = b, a
252  if b == "infinity":
253    assert a != "-infinity"
254    return "infinity"
255  if b == "-infinity":
256    assert a != "infinity"
257    return "-infinity"
258  return int(a) + int(b)
259
260
def _sub(a, b):
  """Subtracts b from a, where a and b are ints, "infinity", or "-infinity".

  Implemented as the addition of `a` and the negation of `b`, so the rules and
  assertions of `_add()` apply.
  """
  if b == "infinity":
    negated_b = "-infinity"
  elif b == "-infinity":
    negated_b = "infinity"
  else:
    negated_b = -int(b)
  return _add(a, negated_b)
268
269
270def _sign(a):
271  """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0."""
272  if a == "infinity":
273    return 1
274  if a == "-infinity":
275    return -1
276  if int(a) > 0:
277    return 1
278  if int(a) < 0:
279    return -1
280  return 0
281
282
def _mul(a, b):
  """Multiplies a and b, where a and b are ints, "infinity", or "-infinity".

  When either operand is infinite the result is determined purely by the
  operands' signs: a positive product is "infinity", a negative one is
  "-infinity", and an infinity multiplied by zero is 0.

  Returns:
      An int, "infinity", or "-infinity".
  """
  if not (_is_infinite(a) or _is_infinite(b)):
    return int(a) * int(b)
  product_sign = _sign(a) * _sign(b)
  if product_sign == 0:
    return 0
  return "infinity" if product_sign > 0 else "-infinity"
295
296
297def _is_infinite(a):
298  return a in ("infinity", "-infinity")
299
300
301def _max(a):
302  """Returns max of a, where elements are ints, "infinity", or "-infinity"."""
303  if any(n == "infinity" for n in a):
304    return "infinity"
305  if all(n == "-infinity" for n in a):
306    return "-infinity"
307  return max(int(n) for n in a if not _is_infinite(n))
308
309
310def _min(a):
311  """Returns min of a, where elements are ints, "infinity", or "-infinity"."""
312  if any(n == "-infinity" for n in a):
313    return "-infinity"
314  if all(n == "infinity" for n in a):
315    return "infinity"
316  return min(int(n) for n in a if not _is_infinite(n))
317
318
def _compute_constraints_of_additive_operator(expression):
  """Computes the bounds and modular value of an additive expression.

  The result's modulus is the greatest common divisor of the operands' moduli,
  and its modular value is the sum (or difference) of the operands' modular
  values, reduced by that modulus unless the modulus is "infinity".

  Arguments:
      expression: an ADDITION or SUBTRACTION expression whose two arguments
          already have integer constraints; `type.integer` is set in place.
  """
  mapping = ir_data.FunctionMapping
  operation = expression.function.function
  combine = {
      mapping.ADDITION: _add,
      mapping.SUBTRACTION: _sub,
  }[operation]
  operands = expression.function.args
  for operand in operands:
    assert operand.type.integer.modular_value, str(expression)
  lhs, rhs = (operand.type.integer for operand in operands)

  raw_modular_value = combine(lhs.modular_value, rhs.modular_value)
  shared_modulus = _greatest_common_divisor(lhs.modulus, rhs.modulus)
  expression.type.integer.modulus = str(shared_modulus)
  if shared_modulus == "infinity":
    # Both operands are constant, so the exact value is known.
    expression.type.integer.modular_value = str(raw_modular_value)
  else:
    expression.type.integer.modular_value = str(
        raw_modular_value % shared_modulus)

  if operation == mapping.SUBTRACTION:
    # Subtracting the largest right-hand value gives the smallest result, and
    # vice versa.
    rhs_for_minimum = rhs.maximum_value
    rhs_for_maximum = rhs.minimum_value
  else:
    rhs_for_minimum = rhs.minimum_value
    rhs_for_maximum = rhs.maximum_value
  expression.type.integer.minimum_value = str(
      combine(lhs.minimum_value, rhs_for_minimum))
  expression.type.integer.maximum_value = str(
      combine(lhs.maximum_value, rhs_for_maximum))
350
351
def _compute_constraints_of_multiplicative_operator(expression):
  """Computes the bounds and modular value of a multiplicative expression.

  Arguments:
      expression: a MULTIPLICATION expression whose two arguments already have
          integer constraints; `type.integer` is set in place.
  """
  operands = [arg.type.integer for arg in expression.function.args]
  left, right = operands[0], operands[1]

  # Depending on the signs and magnitudes of the operands' minima and maxima,
  # the extreme values of the product can come from any of the four pairings
  # of (left min, left max) with (right min, right max).  E.g.:
  #
  # max = left max * right max: [ 2,  3] * [ 2,  3]
  # max = left min * right min: [-3, -2] * [-3, -2]
  # max = left max * right min: [-3, -2] * [ 2,  3]
  # max = left min * right max: [ 2,  3] * [-3, -2]
  # max = left max * right max: [-2,  3] * [-2,  3]
  # max = left min * right min: [-3,  2] * [-3,  2]
  #
  # For uncorrelated multiplication, the minimum and maximum will always come
  # from multiplying one extreme by another: if x is nonzero, then
  #
  #     (y + e) * x > y * x  ||  (y - e) * x > y * x
  #
  # for arbitrary nonzero e, so the extrema can only occur when we either cannot
  # add or cannot subtract e.
  #
  # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but
  # Emboss is not currently trying to be that smart.
  corner_products = [
      _mul(left_extreme, right_extreme)
      for right_extreme in (right.maximum_value, right.minimum_value)
      for left_extreme in (left.maximum_value, left.minimum_value)
  ]
  expression.type.integer.minimum_value = str(_min(corner_products))
  expression.type.integer.maximum_value = str(_max(corner_products))

  # An operand with modulus "infinity" has exactly one possible value.
  is_constant = [operand.modulus == "infinity" for operand in operands]

  if all(is_constant):
    # If both sides are constant, the result is constant.
    expression.type.integer.modulus = "infinity"
    expression.type.integer.modular_value = str(
        int(left.modular_value) * int(right.modular_value))
    return

  if any(is_constant):
    # If one side is constant and the other is not, then the non-constant
    # modulus and modular_value can both be multiplied by the constant.  E.g.,
    # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod
    # 20:
    #
    #   a = ...   |  4 * a = ...  |  4 * a mod 20 = ...
    #   3         |  12           |  12
    #   8         |  32           |  12
    #   13        |  52           |  12
    #   18        |  72           |  12
    #   23        |  92           |  12
    #   28        |  112          |  12
    #   33        |  132          |  12
    #
    # This follows from the fact that consecutive possible values of `4 * a`
    # always differ by 20.
    if is_constant[0]:
      constant, variable = operands
    else:
      variable, constant = operands
    factor = int(constant.modular_value)
    if factor == 0:
      # If the constant is 0, the result is 0, no matter what the variable side
      # is.
      expression.type.integer.modulus = "infinity"
      expression.type.integer.modular_value = "0"
      return
    scaled_modulus = int(variable.modulus) * abs(factor)
    expression.type.integer.modulus = str(scaled_modulus)
    # The `% scaled_modulus` will force the `modular_value` to be positive,
    # even when the constant factor is negative.
    expression.type.integer.modular_value = str(
        int(variable.modular_value) * factor % scaled_modulus)
    return

  # If neither side is constant, then the result is more complex.  Full proof is
  # available in g3doc/modular_congruence_multiplication_proof.md
  #
  # Essentially, if:
  #
  # l == _ * l_mod + l_mv
  # r == _ * r_mod + r_mv
  #
  # Then we find l_mod0 and r_mod0 in:
  #
  # l == (_ * l_mod_nz + l_mv_nz) * l_mod0
  # r == (_ * r_mod_nz + r_mv_nz) * r_mod0
  #
  # And finally conclude:
  #
  # l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv
  zero_moduli_product = 1
  modular_values_product = 1
  nonzero_moduli = []
  for operand in operands:
    # The "zero congruence modulus" is the largest factor that divides every
    # possible value of the operand.
    zero_modulus = _greatest_common_divisor(operand.modulus,
                                            operand.modular_value)
    assert int(operand.modulus) % zero_modulus == 0
    zero_moduli_product *= zero_modulus
    modular_values_product *= int(operand.modular_value)
    nonzero_moduli.append(int(operand.modulus) // zero_modulus)
  shared_nonzero_modulus = _greatest_common_divisor(
      nonzero_moduli[0], nonzero_moduli[1])
  result_modulus = shared_nonzero_modulus * zero_moduli_product
  expression.type.integer.modulus = str(result_modulus)
  expression.type.integer.modular_value = str(
      modular_values_product % result_modulus)
460
461
462def _assert_integer_constraints(expression):
463  """Asserts that the integer bounds of expression are self-consistent.
464
465  Asserts that `minimum_value` and `maximum_value` are congruent to
466  `modular_value` modulo `modulus`.
467
468  If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and
469  `modular_value` are all equal.
470
471  If `minimum_value` is equal to `maximum_value`, asserts that `modular_value`
472  is equal to both, and that `modulus` is "infinity".
473
474  Arguments:
475      expression: an expression with type.integer
476
477  Returns:
478      None
479  """
480  bounds = expression.type.integer
481  if bounds.modulus == "infinity":
482    assert bounds.minimum_value == bounds.modular_value
483    assert bounds.maximum_value == bounds.modular_value
484    return
485  modulus = int(bounds.modulus)
486  assert modulus > 0
487  if bounds.minimum_value != "-infinity":
488    assert int(bounds.minimum_value) % modulus == int(bounds.modular_value)
489  if bounds.maximum_value != "infinity":
490    assert int(bounds.maximum_value) % modulus == int(bounds.modular_value)
491  if bounds.minimum_value == bounds.maximum_value:
492    # TODO(bolms): I believe there are situations using the not-yet-implemented
493    # integer division operator that would trigger these asserts, so they should
494    # be turned into assignments (with corresponding tests) when implementing
495    # division.
496    assert bounds.modular_value == bounds.minimum_value
497    assert bounds.modulus == "infinity"
498  if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity":
499    assert int(bounds.minimum_value) <= int(bounds.maximum_value)
500
501
def _compute_constant_value_of_comparison_operator(expression):
  """Computes the constant value, if any, of a comparison operator.

  This also handles the logical AND and OR operators.  A value is only
  recorded when every argument is constant; otherwise the expression is left
  untouched.

  Arguments:
      expression: a comparison or logical expression; `type.boolean.value` is
          set in place when the result is known.
  """
  operands = expression.function.args
  if not all(ir_util.is_constant(operand) for operand in operands):
    return
  mapping = ir_data.FunctionMapping
  evaluators = {
      mapping.EQUALITY: operator.eq,
      mapping.INEQUALITY: operator.ne,
      mapping.LESS: operator.lt,
      mapping.LESS_OR_EQUAL: operator.le,
      mapping.GREATER: operator.gt,
      mapping.GREATER_OR_EQUAL: operator.ge,
      mapping.AND: operator.and_,
      mapping.OR: operator.or_,
  }
  evaluate = evaluators[expression.function.function]
  values = [ir_util.constant_value(operand) for operand in operands]
  expression.type.boolean.value = evaluate(*values)
519
520
def _compute_constraints_of_bound_function(expression):
  """Computes the constraints of $upper_bound or $lower_bound.

  The result is a constant: the argument's maximum (for $upper_bound) or
  minimum (for $lower_bound) is copied verbatim into the result's minimum,
  maximum and modular value, and the modulus is set to "infinity".

  Arguments:
      expression: an UPPER_BOUND or LOWER_BOUND expression whose argument
          already has integer constraints; `type.integer` is set in place.
  """
  mapping = ir_data.FunctionMapping
  operation = expression.function.function
  if operation == mapping.UPPER_BOUND:
    bound = expression.function.args[0].type.integer.maximum_value
  elif operation == mapping.LOWER_BOUND:
    bound = expression.function.args[0].type.integer.minimum_value
  else:
    assert False, "Non-bound function"
  for attribute in ("minimum_value", "maximum_value", "modular_value"):
    setattr(expression.type.integer, attribute, bound)
  expression.type.integer.modulus = "infinity"
533
534
def _compute_constraints_of_maximum_function(expression):
  """Computes the constraints of the $max function.

  Arguments:
      expression: an integer-typed MAXIMUM expression whose arguments already
          have integer constraints; `type.integer` is set in place.
  """
  assert expression.type.WhichOneof("type") == "integer"
  operands = expression.function.args
  assert operands[0].type.WhichOneof("type") == "integer"
  # The result is smallest when every argument takes its minimum value, so the
  # minimum result is the maximum-of-minimums.
  lowest = str(_max(
      [operand.type.integer.minimum_value for operand in operands]))
  expression.type.integer.minimum_value = lowest
  # The maximum result is the maximum-of-maximums.
  highest = str(_max(
      [operand.type.integer.maximum_value for operand in operands]))
  expression.type.integer.maximum_value = highest
  if lowest == highest:
    # If the expression is dominated by a constant factor, then the result is
    # constant.  I (bolms@) believe this is the only case where
    # _compute_constraints_of_maximum_function might violate the assertions in
    # _assert_integer_constraints.
    expression.type.integer.modular_value = lowest
    expression.type.integer.modulus = "infinity"
    return
  # The result of $max(a, b) could be either a or b, which means that the result
  # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
  # choice operator '?:'.
  #
  # This also takes advantage of the fact that $max(a, b, c, d, ...) is
  # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
  # call _shared_modular_value() in a loop.
  first = operands[0].type.integer
  shared = (first.modulus, first.modular_value)
  for operand in operands[1:]:
    # TODO(bolms): I think the bounds could be tighter in some cases where
    # arg.maximum_value is less than the new expression.minimum_value, and
    # in some very specific cases where arg.maximum_value is greater than the
    # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
    # than expression.minimum_value.
    shared = _shared_modular_value(
        shared,
        (operand.type.integer.modulus, operand.type.integer.modular_value))
  shared_modulus, shared_modular_value = shared
  expression.type.integer.modulus = str(shared_modulus)
  expression.type.integer.modular_value = str(shared_modular_value)
577
578
def _shared_modular_value(left, right):
  """Returns the shared modulus and modular value of left and right.

  Arguments:
    left: A tuple of (modulus, modular value)
    right: A tuple of (modulus, modular value)

  Returns:
    A tuple of (modulus, modular_value) such that:

    left.modulus % result.modulus == 0
    right.modulus % result.modulus == 0
    left.modular_value % result.modulus == result.modular_value
    right.modular_value % result.modulus == result.modular_value

    That is, the result.modulus and result.modular_value will be compatible
    with, but (possibly) less restrictive than both left.(modulus,
    modular_value) and right.(modulus, modular_value).

    When the resulting modulus is "infinity", the modular value is returned
    exactly as it was passed in `left`; otherwise it is an int.
  """
  left_modulus, left_residue = left
  right_modulus, right_residue = right
  # The combined modulus is gcd(gcd(left_modulus, right_modulus),
  # left_modular_value - right_modular_value).
  #
  # The inner gcd normalizes the left_modulus and right_modulus, but can leave
  # incompatible modular_values.  The outer gcd finds a modulus to which both
  # modular_values are congruent.  Some examples:
  #
  #     left          |  right         |  res
  #     --------------+----------------+--------------------
  #     l % 12 == 7   |  r % 20 == 15  |  res % 4 == 3
  #     l == 35       |  r % 20 == 15  |  res % 20 == 15
  #     l % 24 == 15  |  r % 12 == 7   |  res % 4 == 3
  #     l % 20 == 15  |  r % 20 == 10  |  res % 5 == 0
  #     l % 20 == 16  |  r % 20 == 11  |  res % 5 == 1
  #     l == 10       |  r == 7        |  res % 3 == 1
  #     l == 4        |  r == 4        |  res == 4
  #
  # The cases where one side or the other is constant are handled
  # automatically by the fact that _greatest_common_divisor("infinity", x)
  # is x.
  residue_gap = abs(int(left_residue) - int(right_residue))
  shared_modulus = _greatest_common_divisor(
      _greatest_common_divisor(left_modulus, right_modulus), residue_gap)
  if shared_modulus != "infinity":
    reduced = int(left_residue) % shared_modulus
    assert reduced == int(right_residue) % shared_modulus
    return shared_modulus, reduced
  # The only way for the shared modulus to come out as "infinity" *should* be
  # if both sides have the same constant value.
  assert left_residue == right_residue
  assert left_modulus == right_modulus == "infinity"
  return shared_modulus, left_residue
633
634
def _compute_constraints_of_choice_operator(expression):
  """Computes the constraints of a choice operation '?:'.

  When the condition has a known constant value, the result simply takes the
  type of the selected branch.  Otherwise an integer result gets the union of
  both branches' ranges and their shared modular value; boolean and
  enumeration results get no constraints.

  Arguments:
      expression: a CHOICE expression with arguments (condition, if_true,
          if_false) whose constraints have been computed; `type` is updated in
          place.
  """
  condition, when_true, when_false = (
      ir_data_utils.reader(expression).function.args)
  result = ir_data_utils.builder(expression)
  if condition.type.boolean.HasField("value"):
    # The generated expressions for $size_in_bits and $size_in_bytes look like
    #
    #     $max((field1_existence_condition ? field1_start + field1_size : 0),
    #          (field2_existence_condition ? field2_start + field2_size : 0),
    #          (field3_existence_condition ? field3_start + field3_size : 0),
    #          ...)
    #
    # Since most existence_conditions are just "true", it is important to select
    # the tighter bounds in those cases -- otherwise, only zero-length
    # structures could have a constant $size_in_bits or $size_in_bytes.
    if condition.type.boolean.value:
      chosen = when_true
    else:
      chosen = when_false
    result.type.CopyFrom(chosen.type)
    return
  kind = when_true.type.WhichOneof("type")
  if kind != "integer":
    assert kind in ("boolean", "enumeration"), (
        "Unknown type {} for expression".format(kind))
    return
  # The type.integer minimum_value/maximum_value bounding code is needed since
  # constraints.check_constraints() will complain if minimum and maximum are not
  # set correctly.  I'm (bolms@) not sure if the modulus/modular_value pulls its
  # weight, but for completeness I've left it in.
  true_bounds = when_true.type.integer
  false_bounds = when_false.type.integer
  # The minimum value of the choice is the minimum value of either side, and
  # the maximum is the maximum value of either side.
  result.type.integer.minimum_value = str(_min(
      [true_bounds.minimum_value, false_bounds.minimum_value]))
  result.type.integer.maximum_value = str(_max(
      [true_bounds.maximum_value, false_bounds.maximum_value]))
  shared_modulus, shared_modular_value = _shared_modular_value(
      (true_bounds.modulus, true_bounds.modular_value),
      (false_bounds.modulus, false_bounds.modular_value))
  result.type.integer.modulus = str(shared_modulus)
  result.type.integer.modular_value = str(shared_modular_value)
675
676
677def _greatest_common_divisor(a, b):
678  """Returns the greatest common divisor of a and b.
679
680  Arguments:
681    a: an integer, a stringified integer, or the string "infinity"
682    b: an integer, a stringified integer, or the string "infinity"
683
684  Returns:
685    Conceptually, "infinity" is treated as the product of all integers.
686
687    If both a and b are 0, returns "infinity".
688
689    Otherwise, if either a or b are "infinity", and the other is 0, returns
690    "infinity".
691
692    Otherwise, if either a or b are "infinity", returns the other.
693
694    Otherwise, returns the greatest common divisor of a and b.
695  """
696  if a != "infinity": a = int(a)
697  if b != "infinity": b = int(b)
698  assert a == "infinity" or a >= 0
699  assert b == "infinity" or b >= 0
700  if a == b == 0: return "infinity"
701  # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0.
702  if a == 0: return b
703  if b == 0: return a
704  if a == "infinity": return b
705  if b == "infinity": return a
706  return _math_gcd(a, b)
707
708
def compute_constants(ir):
  """Computes constant values for all expressions in ir.

  compute_constants calculates all constant values and bounds and adds them to
  the type information of each expression (including subexpressions) and of
  each runtime parameter.

  Arguments:
      ir: an IR on which to compute constants

  Returns:
      A (possibly empty) list of errors.  This implementation always returns an
      empty list.
  """
  passes = (
      (ir_data.Expression, compute_constraints_of_expression),
      (ir_data.RuntimeParameter, _compute_constraints_of_parameter),
  )
  for node_type, action in passes:
    # The traversal does not descend into Expressions.
    traverse_ir.fast_traverse_ir_top_down(
        ir, [node_type], action,
        skip_descendants_of={ir_data.Expression})
  return []
728