# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for comparing proto2 messages in Python.

ProtoEq() compares two proto2 messages for equality.

NormalizeNumberFields() recursively normalizes the types and precision of
numeric fields in place. This is useful when comparing messages whose values
were set directly against messages that were deserialized, since the two can
otherwise differ in type or in float precision.

assertProtoEqual() is useful for unit tests. It produces much more helpful
output than assertEqual() for proto2 messages, e.g. this:

  outer {
    inner {
-     strings: "x"
?               ^
+     strings: "y"
?               ^
    }
  }

...compared to the default output from assertEqual() that looks like this:

AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc>

Call it inside your unit test's googletest.TestCase subclasses like this:

  from tensorflow.python.util.protobuf import compare

  class MyTest(googletest.TestCase):
    ...
    def testXXX(self):
      ...
      compare.assertProtoEqual(self, a, b)

Alternatively:

  from tensorflow.python.util.protobuf import compare

  class MyTest(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testXXX(self):
      ...
      self.assertProtoEqual(a, b)
"""

import difflib

import six

from google.protobuf import descriptor
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import text_format

from ..compat import collections_abc


def assertProtoEqual(self, a, b, check_initialized=True,  # pylint: disable=invalid-name
                     normalize_numbers=False, msg=None):
  """Fails with a useful error if a and b aren't equal.

  Comparison of repeated fields matches the semantics of
  unittest.TestCase.assertEqual(), i.e. order and extra duplicate elements
  matter. Both messages are rendered with text_format and the resulting text
  is compared, so the failure message is a line-by-line diff.

  Args:
    self: googletest.TestCase
    a: proto2 PB instance, or text string representing one. A string is
      parsed (via text_format.Merge) into a new instance of b's class.
    b: proto2 PB instance -- message.Message or subclass thereof.
    check_initialized: boolean, whether to fail if either a or b isn't
      initialized.
    normalize_numbers: boolean, whether to normalize types and precision of
      numbers before comparison. NOTE: this calls NormalizeNumberFields on a
      and b, which modifies them in place.
    msg: if specified, is used as the error message on failure.
  """
  pool = descriptor_pool.Default()
  if isinstance(a, six.string_types):
    a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)

  for pb in a, b:
    if check_initialized:
      errors = pb.FindInitializationErrors()
      if errors:
        self.fail('Initialization errors: %s\n%s' % (errors, pb))
    if normalize_numbers:
      NormalizeNumberFields(pb)

  a_str = text_format.MessageToString(a, descriptor_pool=pool)
  b_str = text_format.MessageToString(b, descriptor_pool=pool)

  # Some Python versions perform a regular diff instead of a multi-line diff
  # if either string is longer than 2**16. We substitute this behavior with a
  # call to unified_diff instead to have easier-to-read diffs.
  # For context, see: https://bugs.python.org/issue11763.
  if len(a_str) < 2**16 and len(b_str) < 2**16:
    self.assertMultiLineEqual(a_str, b_str, msg=msg)
  else:
    diff = ''.join(
        difflib.unified_diff(a_str.splitlines(True), b_str.splitlines(True)))
    if diff:
      # Only prefix the diff with msg when the caller supplied one; otherwise
      # the failure text would start with a spurious "None :".
      if msg is not None:
        self.fail('%s :\n%s' % (msg, diff))
      self.fail(diff)


def NormalizeNumberFields(pb):
  """Normalizes types and precisions of number fields in a protocol buffer.

  Due to subtleties in the python protocol buffer implementation, it is possible
  for values to have different types and precision depending on whether they
  were set and retrieved directly or deserialized from a protobuf. This
  function:
    * converts integer-typed and enum values to Python ints,
    * rounds 32-bit float values to 6 decimal places, to account for Python
      always storing them as 64-bit doubles,
    * converts double values to float (in case they were set to integers) and
      rounds them to 7 decimal places.

  Modifies pb in place, including numeric extension fields. Recurses into
  nested messages, groups, and the values of map fields whose value type is a
  message. Scalar-valued map fields are not normalized.

  Args:
    pb: proto2 message.

  Returns:
    the given pb, modified in place.
  """
  fd = descriptor.FieldDescriptor
  integer_types = (fd.TYPE_INT64, fd.TYPE_UINT64, fd.TYPE_SINT64,
                   fd.TYPE_FIXED64, fd.TYPE_SFIXED64,
                   fd.TYPE_INT32, fd.TYPE_UINT32, fd.TYPE_SINT32,
                   fd.TYPE_FIXED32, fd.TYPE_SFIXED32,
                   fd.TYPE_ENUM)

  for desc, values in pb.ListFields():
    is_repeated = desc.label == fd.LABEL_REPEATED
    if not is_repeated:
      # Treat singular fields as one-element lists so both cases share code.
      values = [values]

    normalized_values = None

    # Force integer-like values to plain ints so that alternate
    # implementations where the distinction matters (e.g. the C++
    # implementation) compare equal to values set from Python.
    if desc.type in integer_types:
      normalized_values = [int(x) for x in values]
    elif desc.type == fd.TYPE_FLOAT:
      normalized_values = [round(x, 6) for x in values]
    elif desc.type == fd.TYPE_DOUBLE:
      normalized_values = [round(float(x), 7) for x in values]

    if normalized_values is not None:
      if desc.is_extension:
        # Extensions are not attributes of pb: ClearField/setattr reject
        # them, so they must be rewritten through the Extensions map.
        if is_repeated:
          pb.ClearExtension(desc)
          pb.Extensions[desc].extend(normalized_values)
        else:
          pb.Extensions[desc] = normalized_values[0]
      elif is_repeated:
        pb.ClearField(desc.name)
        getattr(pb, desc.name).extend(normalized_values)
      else:
        setattr(pb, desc.name, normalized_values[0])

    if desc.type in (fd.TYPE_MESSAGE, fd.TYPE_GROUP):
      if (desc.type == fd.TYPE_MESSAGE and
          desc.message_type.has_options and
          desc.message_type.GetOptions().map_entry):
        # This is a map (values is the map container). Map entries store the
        # value in field number 2; only recurse if it is a message type.
        if desc.message_type.fields_by_number[2].type == fd.TYPE_MESSAGE:
          for v in six.itervalues(values):
            NormalizeNumberFields(v)
      else:
        for v in values:
          # Recursive step.
          NormalizeNumberFields(v)

  return pb


def _IsMap(value):
  """Returns True if value is a Mapping (e.g. a proto map field container)."""
  return isinstance(value, collections_abc.Mapping)


def _IsRepeatedContainer(value):
  """Returns True if value is an iterable other than a text or bytes string.

  Text and bytes strings are iterable but are scalar proto values, so they are
  excluded and compared as a whole rather than element by element.
  """
  if isinstance(value, (six.string_types, six.binary_type)):
    return False
  try:
    iter(value)
    return True
  except TypeError:
    return False


def ProtoEq(a, b):
  """Compares two proto2 objects for equality.

  Recurses into nested messages. Uses list (not set) semantics for comparing
  repeated fields, i.e. duplicates and order matter. Messages are compared by
  the fields reported by ListFields(), keyed by field number.

  Args:
    a: A proto2 message or a primitive.
    b: A proto2 message or a primitive.

  Returns:
    `True` if the messages are equal, `False` otherwise.
  """
  def Format(pb):
    """Returns a dictionary or unchanged pb based on its type.

    Specifically, this function returns a dictionary that maps tag
    number (for messages), key (for maps) or element index (for repeated
    fields) to value, or just pb unchanged if it's none of these.

    Args:
      pb: A proto2 message or a primitive.
    Returns:
      A dict or unchanged pb.
    """
    if isinstance(pb, message.Message):
      return {desc.number: value for desc, value in pb.ListFields()}
    elif _IsMap(pb):
      return dict(pb.items())
    elif _IsRepeatedContainer(pb):
      return dict(enumerate(pb))
    else:
      return pb

  a, b = Format(a), Format(b)

  # Base case.
  if not isinstance(a, dict) or not isinstance(b, dict):
    return a == b

  # This check performs double duty: it compares two messages by tag value
  # *or* two repeated fields / maps by element. The magic is in the Format()
  # function, which converts them all to the same easily comparable form.
  # Comparing key sets directly (rather than sorting their union) avoids a
  # TypeError when the keys are not mutually orderable.
  if set(a) != set(b):
    return False

  # Recursive step: every shared key must map to equal values.
  return all(ProtoEq(a[tag], b[tag]) for tag in a)


class ProtoAssertions(object):
  """Mix this into a googletest.TestCase class to get proto2 assertions.

  Usage:

  class SomeTestCase(compare.ProtoAssertions, googletest.TestCase):
    ...
    def testSomething(self):
      ...
      self.assertProtoEqual(a, b)

  See module-level definitions for method documentation.
  """

  # pylint: disable=invalid-name
  def assertProtoEqual(self, *args, **kwargs):
    return assertProtoEqual(self, *args, **kwargs)