# -*- coding: utf-8 -*-
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests for preservation of unknown fields and the unknown_fields API.

Most tests run against any protobuf implementation; checks that depend on
pure-Python internals return early when another implementation is active.
"""

__author__ = '[email protected] (Bohdan Koval)'

import sys
import unittest

from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor
from google.protobuf import unknown_fields
try:
  import tracemalloc  # pylint: disable=g-import-not-at-top
except ImportError:
  # Requires python 3.4+. Bind a sentinel so the leak test below is skipped
  # instead of failing with a NameError.
  tracemalloc = None


@testing_refleaks.TestCase
class UnknownFieldsTest(unittest.TestCase):
  """Round-trip tests for messages whose fields are all unknown.

  setUp serializes a fully populated TestAllTypes and parses the bytes into a
  TestEmptyMessage, so every field of the original is carried as unknown.
  """

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  def testSerialize(self):
    # Unknown fields must be re-serialized byte-for-byte.
    data = self.empty_message.SerializeToString()

    # Don't use assertEqual because we don't want to dump raw binary data to
    # stdout.
    self.assertTrue(data == self.all_fields_data)

  def testSerializeProto3(self):
    # Verify proto3 unknown fields behavior.
    message = unittest_proto3_arena_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.all_fields_data, message.SerializeToString())

  def testByteSize(self):
    # Unknown fields count towards ByteSize() just like the known originals.
    self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())

  def testListFields(self):
    # Make sure ListFields doesn't return unknown fields.
    self.assertEqual(0, len(self.empty_message.ListFields()))

  def testSerializeMessageSetWireFormatUnknownExtension(self):
    # Create a message using the message set wire format with an unknown
    # message.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add an unknown extension.
    item = raw.item.add()
    item.type_id = 98218603
    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
    message1.i = 12345
    item.message = message1.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = message_set_extensions_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    unknown_field_set = unknown_fields.UnknownFieldSet(proto)
    self.assertEqual(len(unknown_field_set), 1)
    # Unknown field should have wire format data which can be parsed back to
    # original message.
    self.assertEqual(unknown_field_set[0].field_number, item.type_id)
    self.assertEqual(unknown_field_set[0].wire_type,
                     wire_format.WIRETYPE_LENGTH_DELIMITED)
    d = unknown_field_set[0].data
    message_new = message_set_extensions_pb2.TestMessageSetExtension1()
    message_new.ParseFromString(d)
    self.assertEqual(message1, message_new)

    # Verify that the unknown extension is serialized unchanged
    reserialized = proto.SerializeToString()
    new_raw = unittest_mset_pb2.RawMessageSet()
    new_raw.MergeFromString(reserialized)
    self.assertEqual(raw, new_raw)

  def testEquals(self):
    # Message equality takes unknown fields into account.
    message = unittest_pb2.TestEmptyMessage()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(self.empty_message, message)

    self.all_fields.ClearField('optional_string')
    message.ParseFromString(self.all_fields.SerializeToString())
    self.assertNotEqual(self.empty_message, message)

  def testDiscardUnknownFields(self):
    self.empty_message.DiscardUnknownFields()
    self.assertEqual(b'', self.empty_message.SerializeToString())
    # Test message field and repeated message field. A serialized
    # TestAllTypes is parsed into nested messages; the test relies on its
    # optional_string landing in their unknown fields.
    message = unittest_pb2.TestAllTypes()
    other_message = unittest_pb2.TestAllTypes()
    other_message.optional_string = 'discard'
    message.optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    message.repeated_nested_message.add().ParseFromString(
        other_message.SerializeToString())
    self.assertNotEqual(
        b'', message.optional_nested_message.SerializeToString())
    self.assertNotEqual(
        b'', message.repeated_nested_message[0].SerializeToString())
    message.DiscardUnknownFields()
    self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    self.assertEqual(
        b'', message.repeated_nested_message[0].SerializeToString())

    # DiscardUnknownFields must also reach into map values.
    msg = map_unittest_pb2.TestMap()
    msg.map_int32_all_types[1].optional_nested_message.ParseFromString(
        other_message.SerializeToString())
    msg.map_string_string['1'] = 'test'
    self.assertNotEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())
    msg.DiscardUnknownFields()
    self.assertEqual(
        b'',
        msg.map_int32_all_types[1].optional_nested_message.SerializeToString())


@testing_refleaks.TestCase
class UnknownFieldsAccessorsTest(unittest.TestCase):
  """Tests for reading unknown fields through unknown_fields.UnknownFieldSet."""

  def setUp(self):
    self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    self.all_fields = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.all_fields)
    self.all_fields_data = self.all_fields.SerializeToString()
    self.empty_message = unittest_pb2.TestEmptyMessage()
    self.empty_message.ParseFromString(self.all_fields_data)

  # InternalCheckUnknownField() is an additional Pure Python check which checks
  # a detail of unknown fields. It cannot be used by the C++
  # implementation because some protected members are accessed.
  # The test is added for historical reasons. It is not necessary as
  # serialized string is checked.
  # TODO(jieluo): Remove message._unknown_fields.
  def InternalCheckUnknownField(self, name, expected_value):
    """Decodes field `name` from the private `_unknown_fields` storage.

    Returns immediately unless the pure-Python implementation is active.

    Args:
      name: Name of a TestAllTypes field.
      expected_value: Value the decoded field must equal.
    """
    if api_implementation.Type() != 'python':
      return
    field_descriptor = self.descriptor.fields_by_name[name]
    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    result_dict = {}
    for tag_bytes, value in self.empty_message._unknown_fields:
      if tag_bytes == field_tag:
        # Reuse the TestAllTypes decoder for this tag to turn raw bytes back
        # into a Python value keyed by the field descriptor.
        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
    self.assertEqual(expected_value, result_dict[field_descriptor])

  def CheckUnknownField(self, name, unknown_field_set, expected_value):
    """Asserts that unknown_field_set holds expected_value for field `name`.

    Args:
      name: Name of a TestAllTypes field.
      unknown_field_set: UnknownFieldSet to search.
      expected_value: For groups, a (field_number, wire_type, data) tuple
        describing the group's first sub-field; for repeated fields, a
        container that every matching entry must belong to; otherwise the
        exact value of the single matching entry.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
        field_descriptor.type]
    matched = False
    for unknown_field in unknown_field_set:
      if unknown_field.field_number == field_descriptor.number:
        matched = True
        self.assertEqual(expected_type, unknown_field.wire_type)
        if expected_type == wire_format.WIRETYPE_START_GROUP:
          # Check group: its data is itself a set of unknown fields.
          self.assertEqual(expected_value[0],
                           unknown_field.data[0].field_number)
          self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
          self.assertEqual(expected_value[2], unknown_field.data[0].data)
          continue
        if expected_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
          self.assertIn(type(unknown_field.data), (str, bytes))
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(unknown_field.data, expected_value)
        else:
          self.assertEqual(expected_value, unknown_field.data)
    # Without this the helper would pass vacuously for an absent field.
    self.assertTrue(
        matched, 'no unknown field numbered %d' % field_descriptor.number)

  def testCheckUnknownFieldValue(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    # Test enum.
    self.CheckUnknownField('optional_nested_enum',
                           unknown_field_set,
                           self.all_fields.optional_nested_enum)
    self.InternalCheckUnknownField('optional_nested_enum',
                                   self.all_fields.optional_nested_enum)

    # Test repeated enum.
    self.CheckUnknownField('repeated_nested_enum',
                           unknown_field_set,
                           self.all_fields.repeated_nested_enum)
    self.InternalCheckUnknownField('repeated_nested_enum',
                                   self.all_fields.repeated_nested_enum)

    # Test varint.
    self.CheckUnknownField('optional_int32',
                           unknown_field_set,
                           self.all_fields.optional_int32)
    self.InternalCheckUnknownField('optional_int32',
                                   self.all_fields.optional_int32)

    # Test fixed32.
    self.CheckUnknownField('optional_fixed32',
                           unknown_field_set,
                           self.all_fields.optional_fixed32)
    self.InternalCheckUnknownField('optional_fixed32',
                                   self.all_fields.optional_fixed32)

    # Test fixed64.
    self.CheckUnknownField('optional_fixed64',
                           unknown_field_set,
                           self.all_fields.optional_fixed64)
    self.InternalCheckUnknownField('optional_fixed64',
                                   self.all_fields.optional_fixed64)

    # Test length delimited.
    self.CheckUnknownField('optional_string',
                           unknown_field_set,
                           self.all_fields.optional_string.encode('utf-8'))
    self.InternalCheckUnknownField('optional_string',
                                   self.all_fields.optional_string)

    # Test group.
    self.CheckUnknownField('optionalgroup',
                           unknown_field_set,
                           (17, 0, 117))
    self.InternalCheckUnknownField('optionalgroup',
                                   self.all_fields.optionalgroup)

    self.assertEqual(98, len(unknown_field_set))

  def testCopyFrom(self):
    # CopyFrom carries unknown fields over unchanged.
    message = unittest_pb2.TestEmptyMessage()
    message.CopyFrom(self.empty_message)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)

  def testMergeFrom(self):
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 1
    message.optional_uint32 = 2
    source = unittest_pb2.TestEmptyMessage()
    source.ParseFromString(message.SerializeToString())

    message.ClearField('optional_int32')
    message.optional_int64 = 3
    message.optional_uint32 = 4
    destination = unittest_pb2.TestEmptyMessage()
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(0, len(unknown_field_set))
    destination.ParseFromString(message.SerializeToString())
    # An UnknownFieldSet is a snapshot: parsing does not update it.
    self.assertEqual(0, len(unknown_field_set))
    unknown_field_set = unknown_fields.UnknownFieldSet(destination)
    self.assertEqual(2, len(unknown_field_set))
    destination.MergeFrom(source)
    self.assertEqual(2, len(unknown_field_set))
    # Check that the fields were correctly merged, even when stored in the
    # unknown field set.
    message.ParseFromString(destination.SerializeToString())
    self.assertEqual(message.optional_int32, 1)
    self.assertEqual(message.optional_uint32, 2)
    self.assertEqual(message.optional_int64, 3)

  def testClear(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.empty_message)
    self.empty_message.Clear()
    # All cleared, even unknown fields.
    self.assertEqual(self.empty_message.SerializeToString(), b'')
    # The set obtained before Clear() keeps its entries.
    self.assertEqual(len(unknown_field_set), 98)

  @unittest.skipIf(
      tracemalloc is None or
      (sys.version_info.major, sys.version_info.minor) < (3, 4),
      'tracemalloc requires python 3.4+')
  def testUnknownFieldsNoMemoryLeak(self):
    # Call to UnknownFields must not leak memory
    nb_leaks = 1234

    def leaking_function():
      for _ in range(nb_leaks):
        unknown_fields.UnknownFieldSet(self.empty_message)

    tracemalloc.start()
    try:
      snapshot1 = tracemalloc.take_snapshot()
      leaking_function()
      snapshot2 = tracemalloc.take_snapshot()
      top_stats = snapshot2.compare_to(snapshot1, 'lineno')
    finally:
      # Always stop tracing so a failure above does not leave tracemalloc
      # running for the remaining tests.
      tracemalloc.stop()
    # There's no easy way to look for a precise leak source.
    # Rely on a "marker" count value while checking allocated memory.
    self.assertEqual([], [x for x in top_stats if x.count_diff == nb_leaks])

  def testSubUnknownFields(self):
    # Group data is exposed as a nested set that outlives Clear() of the
    # owning message.
    message = unittest_pb2.TestAllTypes()
    message.optionalgroup.a = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    sub_unknown_fields = unknown_fields.UnknownFieldSet(destination)[0].data
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    destination.Clear()
    self.assertEqual(1, len(sub_unknown_fields))
    self.assertEqual(sub_unknown_fields[0].data, 123)
    message.Clear()
    message.optional_uint32 = 456
    nested_message = unittest_pb2.NestedTestAllTypes()
    nested_message.payload.optional_nested_message.ParseFromString(
        message.SerializeToString())
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(unknown_field_set[0].data, 456)
    nested_message.ClearField('payload')
    self.assertEqual(unknown_field_set[0].data, 456)
    unknown_field_set = unknown_fields.UnknownFieldSet(
        nested_message.payload.optional_nested_message)
    self.assertEqual(0, len(unknown_field_set))

  def testUnknownField(self):
    # A single UnknownField stays readable after its message is cleared.
    message = unittest_pb2.TestAllTypes()
    message.optional_int32 = 123
    destination = unittest_pb2.TestEmptyMessage()
    destination.ParseFromString(message.SerializeToString())
    unknown_field = unknown_fields.UnknownFieldSet(destination)[0]
    destination.Clear()
    self.assertEqual(unknown_field.data, 123)

  def testUnknownExtensions(self):
    # Unknown data is kept on a message that declares extension ranges.
    message = unittest_pb2.TestEmptyMessageWithExtensions()
    message.ParseFromString(self.all_fields_data)
    self.assertEqual(len(unknown_fields.UnknownFieldSet(message)), 98)
    self.assertEqual(message.SerializeToString(), self.all_fields_data)


@testing_refleaks.TestCase
class UnknownEnumValuesTest(unittest.TestCase):
  """Tests for enum values that the parsing message's enum does not define.

  setUp serializes a TestEnumValues and parses it as TestMissingEnumValues,
  whose NestedEnum lacks some of the values used.
  """

  def setUp(self):
    self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR

    self.message = missing_enum_values_pb2.TestEnumValues()
    # TestEnumValues.ZERO = 0, but does not exist in the other NestedEnum.
    self.message.optional_nested_enum = (
        missing_enum_values_pb2.TestEnumValues.ZERO)
    self.message.repeated_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message.packed_nested_enum.extend([
        missing_enum_values_pb2.TestEnumValues.ZERO,
        missing_enum_values_pb2.TestEnumValues.ONE,
        ])
    self.message_data = self.message.SerializeToString()
    self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    self.missing_message.ParseFromString(self.message_data)

  def CheckUnknownField(self, name, expected_value):
    """Asserts how field `name` appears in missing_message's unknown fields.

    Uses only the public UnknownFieldSet API, so it works with every
    implementation.

    Args:
      name: Name of a TestEnumValues field.
      expected_value: For repeated fields, the container of expected values;
        exactly one unknown entry per element is required. Otherwise the
        single expected value, which must appear exactly once.
    """
    field_descriptor = self.descriptor.fields_by_name[name]
    unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertIsInstance(unknown_field_set, unknown_fields.UnknownFieldSet)
    count = 0
    for field in unknown_field_set:
      if field.field_number == field_descriptor.number:
        count += 1
        if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
          self.assertIn(field.data, expected_value)
        else:
          self.assertEqual(expected_value, field.data)
    if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
      self.assertEqual(count, len(expected_value))
    else:
      self.assertEqual(count, 1)

  def testUnknownParseMismatchEnumValue(self):
    just_string = missing_enum_values_pb2.JustString()
    just_string.dummy = 'blah'

    missing = missing_enum_values_pb2.TestEnumValues()
    # The parse is invalid, storing the string proto into the set of
    # unknown fields.
    missing.ParseFromString(just_string.SerializeToString())

    # Fetching the enum field shouldn't crash, instead returning the
    # default value.
    self.assertEqual(missing.optional_nested_enum, 0)

  def testUnknownEnumValue(self):
    self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    self.assertEqual(self.missing_message.optional_nested_enum, 2)
    # ClearField does not touch the unknown value, so serialization is
    # unchanged.
    serialized = self.missing_message.SerializeToString()
    self.missing_message.ClearField('optional_nested_enum')
    self.assertEqual(self.missing_message.SerializeToString(), serialized)

  def testUnknownRepeatedEnumValue(self):
    self.assertEqual([], self.missing_message.repeated_nested_enum)

  def testUnknownPackedEnumValue(self):
    self.assertEqual([], self.missing_message.packed_nested_enum)

  def testCheckUnknownFieldValueForEnum(self):
    unknown_field_set = unknown_fields.UnknownFieldSet(self.missing_message)
    self.assertEqual(len(unknown_field_set), 5)
    self.CheckUnknownField('optional_nested_enum',
                           self.message.optional_nested_enum)
    self.CheckUnknownField('repeated_nested_enum',
                           self.message.repeated_nested_enum)
    self.CheckUnknownField('packed_nested_enum',
                           self.message.packed_nested_enum)

  def testRoundTrip(self):
    # Unknown enum values survive a full serialize/parse round trip.
    new_message = missing_enum_values_pb2.TestEnumValues()
    new_message.ParseFromString(self.missing_message.SerializeToString())
    self.assertEqual(self.message, new_message)


if __name__ == '__main__':
  unittest.main()