xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/protobuf/compare.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utility functions for comparing proto2 messages in Python.
17
18ProtoEq() compares two proto2 messages for equality.
19
20ClearDefaultValuedFields() recursively clears the fields that are set to their
21default values. This is useful for comparing protocol buffers where the
22semantics of unset fields and default valued fields are the same.
23
24assertProtoEqual() is useful for unit tests.  It produces much more helpful
25output than assertEqual() for proto2 messages, e.g. this:
26
27  outer {
28    inner {
29-     strings: "x"
30?               ^
31+     strings: "y"
32?               ^
33    }
34  }
35
36...compared to the default output from assertEqual() that looks like this:
37
38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>
39
40Call it inside your unit test's googletest.TestCase subclasses like this:
41
42  from tensorflow.python.util.protobuf import compare
43
44  class MyTest(googletest.TestCase):
45    ...
46    def testXXX(self):
47      ...
48      compare.assertProtoEqual(self, a, b)
49
50Alternatively:
51
52  from tensorflow.python.util.protobuf import compare
53
54  class MyTest(compare.ProtoAssertions, googletest.TestCase):
55    ...
56    def testXXX(self):
57      ...
58      self.assertProtoEqual(a, b)
59"""
60
61import difflib
62
63import six
64
65from google.protobuf import descriptor
66from google.protobuf import descriptor_pool
67from google.protobuf import message
68from google.protobuf import text_format
69
70from ..compat import collections_abc
71
72
def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
                     normalize_numbers=False, msg=None):
  """Fails with a useful error if a and b aren't equal.

  Both messages are rendered to text format and the two strings are compared,
  so a failure shows a line-oriented diff of the fields that differ.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), i.e. order and duplicate elements matter.

  Note that when `normalize_numbers` is true, NormalizeNumberFields() is
  applied to the message objects themselves, so messages passed in by the
  caller are modified in place.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one. A string is parsed
      into a fresh instance of b's class.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison.
    msg: if specified, is used as the error message on failure.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions would perform regular diff instead of multi-line
  # diff if string is longer than 2**16. We substitute this behavior
  # with a call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
    return

  diff = ''.join(
      difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
  if diff:
    # Only prepend the caller's message when one was given; otherwise the
    # failure text would start with a misleading literal "None :".
    self.fail(diff if msg is None else '%s :\n%s' % (msg, diff))
116
117
def NormalizeNumberFields(pb):
  """Canonicalizes the numeric fields of a protocol buffer, in place.

  Values read back from a message can differ in Python type and precision
  depending on whether they were assigned directly or came from parsing. To
  make two such messages comparable, every set field is rewritten as follows:

    * 32/64-bit signed, unsigned and zig-zag integers, and enums: `int(x)`.
    * floats: `round(x, 6)`.
    * doubles: `round(float(x), 7)`.

  Nested messages and groups are normalized recursively, as are the values of
  map fields whose value type is a message.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  for field, payload in pb.ListFields():
    kinds = descriptor.FieldDescriptor
    is_repeated = field.label == kinds.LABEL_REPEATED
    # Treat a singular field as a one-element sequence so that the code below
    # can handle both cardinalities uniformly.
    items = payload if is_repeated else [payload]

    # Both integer widths map onto Python's int; the distinction only matters
    # to alternate implementations (e.g. the C++ one).
    if field.type in (kinds.TYPE_INT64, kinds.TYPE_UINT64, kinds.TYPE_SINT64,
                      kinds.TYPE_INT32, kinds.TYPE_UINT32, kinds.TYPE_SINT32,
                      kinds.TYPE_ENUM):
      cleaned = [int(x) for x in items]
    elif field.type == kinds.TYPE_FLOAT:
      cleaned = [round(x, 6) for x in items]
    elif field.type == kinds.TYPE_DOUBLE:
      cleaned = [round(float(x), 7) for x in items]
    else:
      cleaned = None

    if cleaned is not None:
      if is_repeated:
        pb.ClearField(field.name)
        getattr(pb, field.name).extend(cleaned)
      else:
        setattr(pb, field.name, cleaned[0])
      # A numeric field is never a message or group, so nothing to recurse into.
      continue

    if field.type not in (kinds.TYPE_MESSAGE, kinds.TYPE_GROUP):
      continue

    entry_type = field.message_type
    is_map = (field.type == kinds.TYPE_MESSAGE and
              entry_type.has_options and
              entry_type.GetOptions().map_entry)
    if not is_map:
      for child in items:
        # recursive step
        NormalizeNumberFields(child)
    elif entry_type.fields_by_number[2].type == kinds.TYPE_MESSAGE:
      # This is a map whose values are messages; scalar-valued maps are left
      # untouched.
      for child in six.itervalues(items):
        NormalizeNumberFields(child)

  return pb
184
185
def _IsMap(value):
  """Returns True if `value` is an instance of the Mapping ABC."""
  mapping_type = collections_abc.Mapping
  return isinstance(value, mapping_type)
188
189
def _IsRepeatedContainer(value):
  """Returns True if `value` is iterable and is not a string.

  Strings are excluded explicitly because they are iterable but represent a
  single scalar value here.
  """
  if isinstance(value, six.string_types):
    return False
  try:
    iter(value)
  except TypeError:
    # iter() rejects the object, so it is a plain scalar.
    return False
  else:
    return True
198
199
def ProtoEq(a, b):
  """Compares two proto2 objects for equality.

  Recurses into nested messages. Uses list (not set) semantics for comparing
  repeated fields, i.e. duplicates and order matter.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal, `False` otherwise.
  """
  def Format(pb):
    """Returns a dictionary or unchanged pb based on its type.

    Specifically, this function returns a dictionary that maps tag
    number (for messages), map key (for map fields) or element index (for
    repeated fields) to value, or just pb unchanged if it's none of those.

    Args:
      pb: A proto2 message or a primitive.
    Returns:
      A dict or unchanged pb.
    """
    if isinstance(pb, message.Message):
      return dict((desc.number, value) for desc, value in pb.ListFields())
    elif _IsMap(pb):
      return dict(pb.items())
    elif _IsRepeatedContainer(pb):
      return dict(enumerate(list(pb)))
    else:
      return pb

  a, b = Format(a), Format(b)

  # Base case
  if not isinstance(a, dict) or not isinstance(b, dict):
    return a == b

  # From here on both sides are dicts, so the same code compares two messages
  # by tag number, two maps by key, or two repeated fields by element index;
  # Format() is what brings them into this common shape.
  #
  # The key sets are compared directly rather than sorted: sorting is not
  # needed to decide equality, and it raises TypeError on Python 3 when the
  # keys are not mutually orderable (e.g. integer indices vs. string map keys).
  if set(a) != set(b):
    return False

  # Recursive step: every shared key must hold equal values.
  return all(ProtoEq(a[key], b[key]) for key in a)
253
254
class ProtoAssertions(object):
  """Mixin that exposes the proto2 assertions as TestCase methods.

  Combine it with googletest.TestCase:

  class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProtoEqual(a, b)

  Each method forwards to the module-level function of the same name; see
  those functions for the full documentation.
  """

  def assertProtoEqual(self, *args, **kwargs):  # pylint: disable=invalid-name
    """Forwards to the module-level assertProtoEqual, passing self first."""
    outcome = assertProtoEqual(self, *args, **kwargs)
    return outcome
272