1# -*- coding: utf-8 -*- 2# Protocol Buffers - Google's data interchange format 3# Copyright 2008 Google Inc. All rights reserved. 4# https://developers.google.com/protocol-buffers/ 5# 6# Redistribution and use in source and binary forms, with or without 7# modification, are permitted provided that the following conditions are 8# met: 9# 10# * Redistributions of source code must retain the above copyright 11# notice, this list of conditions and the following disclaimer. 12# * Redistributions in binary form must reproduce the above 13# copyright notice, this list of conditions and the following disclaimer 14# in the documentation and/or other materials provided with the 15# distribution. 16# * Neither the name of Google Inc. nor the names of its 17# contributors may be used to endorse or promote products derived from 18# this software without specific prior written permission. 19# 20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32"""Unittest for reflection.py, which also indirectly tests the output of the 33pure-Python protocol compiler. 
34""" 35 36import copy 37import gc 38import operator 39import struct 40import sys 41import warnings 42import unittest 43 44from google.protobuf import unittest_import_pb2 45from google.protobuf import unittest_mset_pb2 46from google.protobuf import unittest_pb2 47from google.protobuf import unittest_proto3_arena_pb2 48from google.protobuf import descriptor_pb2 49from google.protobuf import descriptor 50from google.protobuf import message 51from google.protobuf import reflection 52from google.protobuf import text_format 53from google.protobuf.internal import api_implementation 54from google.protobuf.internal import more_extensions_pb2 55from google.protobuf.internal import more_messages_pb2 56from google.protobuf.internal import message_set_extensions_pb2 57from google.protobuf.internal import wire_format 58from google.protobuf.internal import test_util 59from google.protobuf.internal import testing_refleaks 60from google.protobuf.internal import decoder 61from google.protobuf.internal import _parameterized 62 63 64warnings.simplefilter('error', DeprecationWarning) 65 66 67class _MiniDecoder(object): 68 """Decodes a stream of values from a string. 69 70 Once upon a time we actually had a class called decoder.Decoder. Then we 71 got rid of it during a redesign that made decoding much, much faster overall. 72 But a couple tests in this file used it to check that the serialized form of 73 a message was correct. So, this class implements just the methods that were 74 used by said tests, so that we don't have to rewrite the tests. 
75 """ 76 77 def __init__(self, bytes): 78 self._bytes = bytes 79 self._pos = 0 80 81 def ReadVarint(self): 82 result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) 83 return result 84 85 ReadInt32 = ReadVarint 86 ReadInt64 = ReadVarint 87 ReadUInt32 = ReadVarint 88 ReadUInt64 = ReadVarint 89 90 def ReadSInt64(self): 91 return wire_format.ZigZagDecode(self.ReadVarint()) 92 93 ReadSInt32 = ReadSInt64 94 95 def ReadFieldNumberAndWireType(self): 96 return wire_format.UnpackTag(self.ReadVarint()) 97 98 def ReadFloat(self): 99 result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0] 100 self._pos += 4 101 return result 102 103 def ReadDouble(self): 104 result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0] 105 self._pos += 8 106 return result 107 108 def EndOfStream(self): 109 return self._pos == len(self._bytes) 110 111 112@_parameterized.named_parameters( 113 ('_proto2', unittest_pb2), 114 ('_proto3', unittest_proto3_arena_pb2)) 115@testing_refleaks.TestCase 116class ReflectionTest(unittest.TestCase): 117 118 def assertListsEqual(self, values, others): 119 self.assertEqual(len(values), len(others)) 120 for i in range(len(values)): 121 self.assertEqual(values[i], others[i]) 122 123 def testScalarConstructor(self, message_module): 124 # Constructor with only scalar types should succeed. 125 proto = message_module.TestAllTypes( 126 optional_int32=24, 127 optional_double=54.321, 128 optional_string='optional_string', 129 optional_float=None) 130 131 self.assertEqual(24, proto.optional_int32) 132 self.assertEqual(54.321, proto.optional_double) 133 self.assertEqual('optional_string', proto.optional_string) 134 if message_module is unittest_pb2: 135 self.assertFalse(proto.HasField("optional_float")) 136 137 def testRepeatedScalarConstructor(self, message_module): 138 # Constructor with only repeated scalar types should succeed. 
139 proto = message_module.TestAllTypes( 140 repeated_int32=[1, 2, 3, 4], 141 repeated_double=[1.23, 54.321], 142 repeated_bool=[True, False, False], 143 repeated_string=["optional_string"], 144 repeated_float=None) 145 146 self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) 147 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 148 self.assertEqual([True, False, False], list(proto.repeated_bool)) 149 self.assertEqual(["optional_string"], list(proto.repeated_string)) 150 self.assertEqual([], list(proto.repeated_float)) 151 152 def testMixedConstructor(self, message_module): 153 # Constructor with only mixed types should succeed. 154 proto = message_module.TestAllTypes( 155 optional_int32=24, 156 optional_string='optional_string', 157 repeated_double=[1.23, 54.321], 158 repeated_bool=[True, False, False], 159 repeated_nested_message=[ 160 message_module.TestAllTypes.NestedMessage( 161 bb=message_module.TestAllTypes.FOO), 162 message_module.TestAllTypes.NestedMessage( 163 bb=message_module.TestAllTypes.BAR)], 164 repeated_foreign_message=[ 165 message_module.ForeignMessage(c=-43), 166 message_module.ForeignMessage(c=45324), 167 message_module.ForeignMessage(c=12)], 168 optional_nested_message=None) 169 170 self.assertEqual(24, proto.optional_int32) 171 self.assertEqual('optional_string', proto.optional_string) 172 self.assertEqual([1.23, 54.321], list(proto.repeated_double)) 173 self.assertEqual([True, False, False], list(proto.repeated_bool)) 174 self.assertEqual( 175 [message_module.TestAllTypes.NestedMessage( 176 bb=message_module.TestAllTypes.FOO), 177 message_module.TestAllTypes.NestedMessage( 178 bb=message_module.TestAllTypes.BAR)], 179 list(proto.repeated_nested_message)) 180 self.assertEqual( 181 [message_module.ForeignMessage(c=-43), 182 message_module.ForeignMessage(c=45324), 183 message_module.ForeignMessage(c=12)], 184 list(proto.repeated_foreign_message)) 185 self.assertFalse(proto.HasField("optional_nested_message")) 186 187 def 
testConstructorTypeError(self, message_module): 188 self.assertRaises( 189 TypeError, message_module.TestAllTypes, optional_int32='foo') 190 self.assertRaises( 191 TypeError, message_module.TestAllTypes, optional_string=1234) 192 self.assertRaises( 193 TypeError, message_module.TestAllTypes, optional_nested_message=1234) 194 self.assertRaises( 195 TypeError, message_module.TestAllTypes, repeated_int32=1234) 196 self.assertRaises( 197 TypeError, message_module.TestAllTypes, repeated_int32=['foo']) 198 self.assertRaises( 199 TypeError, message_module.TestAllTypes, repeated_string=1234) 200 self.assertRaises( 201 TypeError, message_module.TestAllTypes, repeated_string=[1234]) 202 self.assertRaises( 203 TypeError, message_module.TestAllTypes, repeated_nested_message=1234) 204 self.assertRaises( 205 TypeError, message_module.TestAllTypes, repeated_nested_message=[1234]) 206 207 def testConstructorInvalidatesCachedByteSize(self, message_module): 208 message = message_module.TestAllTypes(optional_int32=12) 209 self.assertEqual(2, message.ByteSize()) 210 211 message = message_module.TestAllTypes( 212 optional_nested_message=message_module.TestAllTypes.NestedMessage()) 213 self.assertEqual(3, message.ByteSize()) 214 215 message = message_module.TestAllTypes(repeated_int32=[12]) 216 # TODO(jieluo): Add this test back for proto3 217 if message_module is unittest_pb2: 218 self.assertEqual(3, message.ByteSize()) 219 220 message = message_module.TestAllTypes( 221 repeated_nested_message=[message_module.TestAllTypes.NestedMessage()]) 222 self.assertEqual(3, message.ByteSize()) 223 224 def testReferencesToNestedMessage(self, message_module): 225 proto = message_module.TestAllTypes() 226 nested = proto.optional_nested_message 227 del proto 228 # A previous version had a bug where this would raise an exception when 229 # hitting a now-dead weak reference. 
230 nested.bb = 23 231 232 def testOneOf(self, message_module): 233 proto = message_module.TestAllTypes() 234 proto.oneof_uint32 = 10 235 proto.oneof_nested_message.bb = 11 236 self.assertEqual(11, proto.oneof_nested_message.bb) 237 self.assertFalse(proto.HasField('oneof_uint32')) 238 nested = proto.oneof_nested_message 239 proto.oneof_string = 'abc' 240 self.assertEqual('abc', proto.oneof_string) 241 self.assertEqual(11, nested.bb) 242 self.assertFalse(proto.HasField('oneof_nested_message')) 243 244 def testGetDefaultMessageAfterDisconnectingDefaultMessage( 245 self, message_module): 246 proto = message_module.TestAllTypes() 247 nested = proto.optional_nested_message 248 proto.ClearField('optional_nested_message') 249 del proto 250 del nested 251 # Force a garbage collect so that the underlying CMessages are freed along 252 # with the Messages they point to. This is to make sure we're not deleting 253 # default message instances. 254 gc.collect() 255 proto = message_module.TestAllTypes() 256 nested = proto.optional_nested_message 257 258 def testDisconnectingNestedMessageAfterSettingField(self, message_module): 259 proto = message_module.TestAllTypes() 260 nested = proto.optional_nested_message 261 nested.bb = 5 262 self.assertTrue(proto.HasField('optional_nested_message')) 263 proto.ClearField('optional_nested_message') # Should disconnect from parent 264 self.assertEqual(5, nested.bb) 265 self.assertEqual(0, proto.optional_nested_message.bb) 266 self.assertIsNot(nested, proto.optional_nested_message) 267 nested.bb = 23 268 self.assertFalse(proto.HasField('optional_nested_message')) 269 self.assertEqual(0, proto.optional_nested_message.bb) 270 271 def testDisconnectingNestedMessageBeforeGettingField(self, message_module): 272 proto = message_module.TestAllTypes() 273 self.assertFalse(proto.HasField('optional_nested_message')) 274 proto.ClearField('optional_nested_message') 275 self.assertFalse(proto.HasField('optional_nested_message')) 276 277 def 
testDisconnectingNestedMessageAfterMerge(self, message_module): 278 # This test exercises the code path that does not use ReleaseMessage(). 279 # The underlying fear is that if we use ReleaseMessage() incorrectly, 280 # we will have memory leaks. It's hard to check that that doesn't happen, 281 # but at least we can exercise that code path to make sure it works. 282 proto1 = message_module.TestAllTypes() 283 proto2 = message_module.TestAllTypes() 284 proto2.optional_nested_message.bb = 5 285 proto1.MergeFrom(proto2) 286 self.assertTrue(proto1.HasField('optional_nested_message')) 287 proto1.ClearField('optional_nested_message') 288 self.assertFalse(proto1.HasField('optional_nested_message')) 289 290 def testDisconnectingLazyNestedMessage(self, message_module): 291 # This test exercises releasing a nested message that is lazy. This test 292 # only exercises real code in the C++ implementation as Python does not 293 # support lazy parsing, but the current C++ implementation results in 294 # memory corruption and a crash. 295 if api_implementation.Type() != 'python': 296 return 297 proto = message_module.TestAllTypes() 298 proto.optional_lazy_message.bb = 5 299 proto.ClearField('optional_lazy_message') 300 del proto 301 gc.collect() 302 303 def testSingularListFields(self, message_module): 304 proto = message_module.TestAllTypes() 305 proto.optional_fixed32 = 1 306 proto.optional_int32 = 5 307 proto.optional_string = 'foo' 308 # Access sub-message but don't set it yet. 
309 nested_message = proto.optional_nested_message 310 self.assertEqual( 311 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 312 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 313 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], 314 proto.ListFields()) 315 316 proto.optional_nested_message.bb = 123 317 self.assertEqual( 318 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), 319 (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), 320 (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), 321 (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], 322 nested_message) ], 323 proto.ListFields()) 324 325 def testRepeatedListFields(self, message_module): 326 proto = message_module.TestAllTypes() 327 proto.repeated_fixed32.append(1) 328 proto.repeated_int32.append(5) 329 proto.repeated_int32.append(11) 330 proto.repeated_string.extend(['foo', 'bar']) 331 proto.repeated_string.extend([]) 332 proto.repeated_string.append('baz') 333 proto.repeated_string.extend(str(x) for x in range(2)) 334 proto.optional_int32 = 21 335 proto.repeated_bool # Access but don't set anything; should not be listed. 336 self.assertEqual( 337 [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), 338 (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), 339 (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), 340 (proto.DESCRIPTOR.fields_by_name['repeated_string' ], 341 ['foo', 'bar', 'baz', '0', '1']) ], 342 proto.ListFields()) 343 344 def testClearFieldWithUnknownFieldName(self, message_module): 345 proto = message_module.TestAllTypes() 346 self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') 347 self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field') 348 349 def testDisallowedAssignments(self, message_module): 350 # It's illegal to assign values directly to repeated fields 351 # or to nonrepeated composite fields. Ensure that this fails. 
352 proto = message_module.TestAllTypes() 353 # Repeated fields. 354 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) 355 # Lists shouldn't work, either. 356 self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) 357 # Composite fields. 358 self.assertRaises(AttributeError, setattr, proto, 359 'optional_nested_message', 23) 360 # Assignment to a repeated nested message field without specifying 361 # the index in the array of nested messages. 362 self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, 363 'bb', 34) 364 # Assignment to an attribute of a repeated field. 365 self.assertRaises(AttributeError, setattr, proto.repeated_float, 366 'some_attribute', 34) 367 # proto.nonexistent_field = 23 should fail as well. 368 self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) 369 370 def testSingleScalarTypeSafety(self, message_module): 371 proto = message_module.TestAllTypes() 372 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) 373 self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') 374 self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) 375 self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) 376 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo') 377 self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo') 378 self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo') 379 # TODO(jieluo): Fix type checking difference for python and c extension 380 if (api_implementation.Type() == 'python' or 381 (sys.version_info.major, sys.version_info.minor) >= (3, 10)): 382 self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1) 383 else: 384 proto.optional_bool = 1.1 385 386 def assertIntegerTypes(self, integer_fn, message_module): 387 """Verifies setting of scalar integers. 388 389 Args: 390 integer_fn: A function to wrap the integers that will be assigned. 
391 message_module: unittest_pb2 or unittest_proto3_arena_pb2 392 """ 393 def TestGetAndDeserialize(field_name, value, expected_type): 394 proto = message_module.TestAllTypes() 395 value = integer_fn(value) 396 setattr(proto, field_name, value) 397 self.assertIsInstance(getattr(proto, field_name), expected_type) 398 proto2 = message_module.TestAllTypes() 399 proto2.ParseFromString(proto.SerializeToString()) 400 self.assertIsInstance(getattr(proto2, field_name), expected_type) 401 402 TestGetAndDeserialize('optional_int32', 1, int) 403 TestGetAndDeserialize('optional_int32', 1 << 30, int) 404 TestGetAndDeserialize('optional_uint32', 1 << 30, int) 405 integer_64 = int 406 if struct.calcsize('L') == 4: 407 # Python only has signed ints, so 32-bit python can't fit an uint32 408 # in an int. 409 TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64) 410 else: 411 # 64-bit python can fit uint32 inside an int 412 TestGetAndDeserialize('optional_uint32', 1 << 31, int) 413 TestGetAndDeserialize('optional_int64', 1 << 30, integer_64) 414 TestGetAndDeserialize('optional_int64', 1 << 60, integer_64) 415 TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64) 416 TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64) 417 418 def testIntegerTypes(self, message_module): 419 self.assertIntegerTypes(lambda x: x, message_module) 420 421 def testNonStandardIntegerTypes(self, message_module): 422 self.assertIntegerTypes(test_util.NonStandardInteger, message_module) 423 424 def testIllegalValuesForIntegers(self, message_module): 425 pb = message_module.TestAllTypes() 426 427 # Strings are illegal, even when the represent an integer. 428 with self.assertRaises(TypeError): 429 pb.optional_uint64 = '2' 430 431 # The exact error should propagate with a poorly written custom integer. 
432 with self.assertRaisesRegex(RuntimeError, 'my_error'): 433 pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error') 434 435 def assetIntegerBoundsChecking(self, integer_fn, message_module): 436 """Verifies bounds checking for scalar integer fields. 437 438 Args: 439 integer_fn: A function to wrap the integers that will be assigned. 440 message_module: unittest_pb2 or unittest_proto3_arena_pb2 441 """ 442 def TestMinAndMaxIntegers(field_name, expected_min, expected_max): 443 pb = message_module.TestAllTypes() 444 expected_min = integer_fn(expected_min) 445 expected_max = integer_fn(expected_max) 446 setattr(pb, field_name, expected_min) 447 self.assertEqual(expected_min, getattr(pb, field_name)) 448 setattr(pb, field_name, expected_max) 449 self.assertEqual(expected_max, getattr(pb, field_name)) 450 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 451 expected_min - 1) 452 self.assertRaises((ValueError, TypeError), setattr, pb, field_name, 453 expected_max + 1) 454 455 TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) 456 TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) 457 TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) 458 TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) 459 # A bit of white-box testing since -1 is an int and not a long in C++ and 460 # so goes down a different path. 
461 pb = message_module.TestAllTypes() 462 with self.assertRaises((ValueError, TypeError)): 463 pb.optional_uint64 = integer_fn(-(1 << 63)) 464 465 pb = message_module.TestAllTypes() 466 pb.optional_nested_enum = integer_fn(1) 467 self.assertEqual(1, pb.optional_nested_enum) 468 469 def testSingleScalarBoundsChecking(self, message_module): 470 self.assetIntegerBoundsChecking(lambda x: x, message_module) 471 472 def testNonStandardSingleScalarBoundsChecking(self, message_module): 473 self.assetIntegerBoundsChecking( 474 test_util.NonStandardInteger, message_module) 475 476 def testRepeatedScalarTypeSafety(self, message_module): 477 proto = message_module.TestAllTypes() 478 self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) 479 self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') 480 self.assertRaises(TypeError, proto.repeated_string, 10) 481 self.assertRaises(TypeError, proto.repeated_bytes, 10) 482 483 proto.repeated_int32.append(10) 484 proto.repeated_int32[0] = 23 485 self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) 486 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') 487 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, []) 488 self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 489 'index', 23) 490 491 proto.repeated_string.append('2') 492 self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10) 493 494 # Repeated enums tests. 
495 #proto.repeated_nested_enum.append(0) 496 497 def testSingleScalarGettersAndSetters(self, message_module): 498 proto = message_module.TestAllTypes() 499 self.assertEqual(0, proto.optional_int32) 500 proto.optional_int32 = 1 501 self.assertEqual(1, proto.optional_int32) 502 503 proto.optional_uint64 = 0xffffffffffff 504 self.assertEqual(0xffffffffffff, proto.optional_uint64) 505 proto.optional_uint64 = 0xffffffffffffffff 506 self.assertEqual(0xffffffffffffffff, proto.optional_uint64) 507 # TODO(robinson): Test all other scalar field types. 508 509 def testEnums(self, message_module): 510 proto = message_module.TestAllTypes() 511 self.assertEqual(1, proto.FOO) 512 self.assertEqual(1, message_module.TestAllTypes.FOO) 513 self.assertEqual(2, proto.BAR) 514 self.assertEqual(2, message_module.TestAllTypes.BAR) 515 self.assertEqual(3, proto.BAZ) 516 self.assertEqual(3, message_module.TestAllTypes.BAZ) 517 518 def testEnum_Name(self, message_module): 519 self.assertEqual( 520 'FOREIGN_FOO', 521 message_module.ForeignEnum.Name(message_module.FOREIGN_FOO)) 522 self.assertEqual( 523 'FOREIGN_BAR', 524 message_module.ForeignEnum.Name(message_module.FOREIGN_BAR)) 525 self.assertEqual( 526 'FOREIGN_BAZ', 527 message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ)) 528 self.assertRaises(ValueError, 529 message_module.ForeignEnum.Name, 11312) 530 531 proto = message_module.TestAllTypes() 532 self.assertEqual('FOO', 533 proto.NestedEnum.Name(proto.FOO)) 534 self.assertEqual('FOO', 535 message_module.TestAllTypes.NestedEnum.Name(proto.FOO)) 536 self.assertEqual('BAR', 537 proto.NestedEnum.Name(proto.BAR)) 538 self.assertEqual('BAR', 539 message_module.TestAllTypes.NestedEnum.Name(proto.BAR)) 540 self.assertEqual('BAZ', 541 proto.NestedEnum.Name(proto.BAZ)) 542 self.assertEqual('BAZ', 543 message_module.TestAllTypes.NestedEnum.Name(proto.BAZ)) 544 self.assertRaises(ValueError, 545 proto.NestedEnum.Name, 11312) 546 self.assertRaises(ValueError, 547 
message_module.TestAllTypes.NestedEnum.Name, 11312) 548 549 # Check some coercion cases. 550 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 551 11312.0) 552 self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name, 553 None) 554 self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True)) 555 556 def testEnum_Value(self, message_module): 557 self.assertEqual(message_module.FOREIGN_FOO, 558 message_module.ForeignEnum.Value('FOREIGN_FOO')) 559 self.assertEqual(message_module.FOREIGN_FOO, 560 message_module.ForeignEnum.FOREIGN_FOO) 561 562 self.assertEqual(message_module.FOREIGN_BAR, 563 message_module.ForeignEnum.Value('FOREIGN_BAR')) 564 self.assertEqual(message_module.FOREIGN_BAR, 565 message_module.ForeignEnum.FOREIGN_BAR) 566 567 self.assertEqual(message_module.FOREIGN_BAZ, 568 message_module.ForeignEnum.Value('FOREIGN_BAZ')) 569 self.assertEqual(message_module.FOREIGN_BAZ, 570 message_module.ForeignEnum.FOREIGN_BAZ) 571 572 self.assertRaises(ValueError, 573 message_module.ForeignEnum.Value, 'FO') 574 with self.assertRaises(AttributeError): 575 message_module.ForeignEnum.FO 576 577 proto = message_module.TestAllTypes() 578 self.assertEqual(proto.FOO, 579 proto.NestedEnum.Value('FOO')) 580 self.assertEqual(proto.FOO, 581 proto.NestedEnum.FOO) 582 583 self.assertEqual(proto.FOO, 584 message_module.TestAllTypes.NestedEnum.Value('FOO')) 585 self.assertEqual(proto.FOO, 586 message_module.TestAllTypes.NestedEnum.FOO) 587 588 self.assertEqual(proto.BAR, 589 proto.NestedEnum.Value('BAR')) 590 self.assertEqual(proto.BAR, 591 proto.NestedEnum.BAR) 592 593 self.assertEqual(proto.BAR, 594 message_module.TestAllTypes.NestedEnum.Value('BAR')) 595 self.assertEqual(proto.BAR, 596 message_module.TestAllTypes.NestedEnum.BAR) 597 598 self.assertEqual(proto.BAZ, 599 proto.NestedEnum.Value('BAZ')) 600 self.assertEqual(proto.BAZ, 601 proto.NestedEnum.BAZ) 602 603 self.assertEqual(proto.BAZ, 604 
message_module.TestAllTypes.NestedEnum.Value('BAZ')) 605 self.assertEqual(proto.BAZ, 606 message_module.TestAllTypes.NestedEnum.BAZ) 607 608 self.assertRaises(ValueError, 609 proto.NestedEnum.Value, 'Foo') 610 with self.assertRaises(AttributeError): 611 proto.NestedEnum.Value.Foo 612 613 self.assertRaises(ValueError, 614 message_module.TestAllTypes.NestedEnum.Value, 'Foo') 615 with self.assertRaises(AttributeError): 616 message_module.TestAllTypes.NestedEnum.Value.Foo 617 618 def testEnum_KeysAndValues(self, message_module): 619 if message_module == unittest_pb2: 620 keys = ['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'] 621 values = [4, 5, 6] 622 items = [('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)] 623 else: 624 keys = ['FOREIGN_ZERO', 'FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'] 625 values = [0, 4, 5, 6] 626 items = [('FOREIGN_ZERO', 0), ('FOREIGN_FOO', 4), 627 ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)] 628 self.assertEqual(keys, 629 list(message_module.ForeignEnum.keys())) 630 self.assertEqual(values, 631 list(message_module.ForeignEnum.values())) 632 self.assertEqual(items, 633 list(message_module.ForeignEnum.items())) 634 635 proto = message_module.TestAllTypes() 636 if message_module == unittest_pb2: 637 keys = ['FOO', 'BAR', 'BAZ', 'NEG'] 638 values = [1, 2, 3, -1] 639 items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 640 else: 641 keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG'] 642 values = [0, 1, 2, 3, -1] 643 items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)] 644 self.assertEqual(keys, list(proto.NestedEnum.keys())) 645 self.assertEqual(values, list(proto.NestedEnum.values())) 646 self.assertEqual(items, 647 list(proto.NestedEnum.items())) 648 649 def testStaticParseFrom(self, message_module): 650 proto1 = message_module.TestAllTypes() 651 test_util.SetAllFields(proto1) 652 653 string1 = proto1.SerializeToString() 654 proto2 = message_module.TestAllTypes.FromString(string1) 655 656 # Messages should be equal. 
657 self.assertEqual(proto2, proto1) 658 659 def testMergeFromSingularField(self, message_module): 660 # Test merge with just a singular field. 661 proto1 = message_module.TestAllTypes() 662 proto1.optional_int32 = 1 663 664 proto2 = message_module.TestAllTypes() 665 # This shouldn't get overwritten. 666 proto2.optional_string = 'value' 667 668 proto2.MergeFrom(proto1) 669 self.assertEqual(1, proto2.optional_int32) 670 self.assertEqual('value', proto2.optional_string) 671 672 def testMergeFromRepeatedField(self, message_module): 673 # Test merge with just a repeated field. 674 proto1 = message_module.TestAllTypes() 675 proto1.repeated_int32.append(1) 676 proto1.repeated_int32.append(2) 677 678 proto2 = message_module.TestAllTypes() 679 proto2.repeated_int32.append(0) 680 proto2.MergeFrom(proto1) 681 682 self.assertEqual(0, proto2.repeated_int32[0]) 683 self.assertEqual(1, proto2.repeated_int32[1]) 684 self.assertEqual(2, proto2.repeated_int32[2]) 685 686 def testMergeFromRepeatedNestedMessage(self, message_module): 687 # Test merge with a repeated nested message. 688 proto1 = message_module.TestAllTypes() 689 m = proto1.repeated_nested_message.add() 690 m.bb = 123 691 m = proto1.repeated_nested_message.add() 692 m.bb = 321 693 694 proto2 = message_module.TestAllTypes() 695 m = proto2.repeated_nested_message.add() 696 m.bb = 999 697 proto2.MergeFrom(proto1) 698 self.assertEqual(999, proto2.repeated_nested_message[0].bb) 699 self.assertEqual(123, proto2.repeated_nested_message[1].bb) 700 self.assertEqual(321, proto2.repeated_nested_message[2].bb) 701 702 proto3 = message_module.TestAllTypes() 703 proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) 704 self.assertEqual(999, proto3.repeated_nested_message[0].bb) 705 self.assertEqual(123, proto3.repeated_nested_message[1].bb) 706 self.assertEqual(321, proto3.repeated_nested_message[2].bb) 707 708 def testMergeFromAllFields(self, message_module): 709 # With all fields set. 
710 proto1 = message_module.TestAllTypes() 711 test_util.SetAllFields(proto1) 712 proto2 = message_module.TestAllTypes() 713 proto2.MergeFrom(proto1) 714 715 # Messages should be equal. 716 self.assertEqual(proto2, proto1) 717 718 # Serialized string should be equal too. 719 string1 = proto1.SerializeToString() 720 string2 = proto2.SerializeToString() 721 self.assertEqual(string1, string2) 722 723 def testMergeFromBug(self, message_module): 724 message1 = message_module.TestAllTypes() 725 message2 = message_module.TestAllTypes() 726 727 # Cause optional_nested_message to be instantiated within message1, even 728 # though it is not considered to be "present". 729 message1.optional_nested_message 730 self.assertFalse(message1.HasField('optional_nested_message')) 731 732 # Merge into message2. This should not instantiate the field is message2. 733 message2.MergeFrom(message1) 734 self.assertFalse(message2.HasField('optional_nested_message')) 735 736 def testCopyFromSingularField(self, message_module): 737 # Test copy with just a singular field. 738 proto1 = message_module.TestAllTypes() 739 proto1.optional_int32 = 1 740 proto1.optional_string = 'important-text' 741 742 proto2 = message_module.TestAllTypes() 743 proto2.optional_string = 'value' 744 745 proto2.CopyFrom(proto1) 746 self.assertEqual(1, proto2.optional_int32) 747 self.assertEqual('important-text', proto2.optional_string) 748 749 def testCopyFromRepeatedField(self, message_module): 750 # Test copy with a repeated field. 751 proto1 = message_module.TestAllTypes() 752 proto1.repeated_int32.append(1) 753 proto1.repeated_int32.append(2) 754 755 proto2 = message_module.TestAllTypes() 756 proto2.repeated_int32.append(0) 757 proto2.CopyFrom(proto1) 758 759 self.assertEqual(1, proto2.repeated_int32[0]) 760 self.assertEqual(2, proto2.repeated_int32[1]) 761 762 def testCopyFromAllFields(self, message_module): 763 # With all fields set. 
764 proto1 = message_module.TestAllTypes() 765 test_util.SetAllFields(proto1) 766 proto2 = message_module.TestAllTypes() 767 proto2.CopyFrom(proto1) 768 769 # Messages should be equal. 770 self.assertEqual(proto2, proto1) 771 772 # Serialized string should be equal too. 773 string1 = proto1.SerializeToString() 774 string2 = proto2.SerializeToString() 775 self.assertEqual(string1, string2) 776 777 def testCopyFromSelf(self, message_module): 778 proto1 = message_module.TestAllTypes() 779 proto1.repeated_int32.append(1) 780 proto1.optional_int32 = 2 781 proto1.optional_string = 'important-text' 782 783 proto1.CopyFrom(proto1) 784 self.assertEqual(1, proto1.repeated_int32[0]) 785 self.assertEqual(2, proto1.optional_int32) 786 self.assertEqual('important-text', proto1.optional_string) 787 788 def testDeepCopy(self, message_module): 789 proto1 = message_module.TestAllTypes() 790 proto1.optional_int32 = 1 791 proto2 = copy.deepcopy(proto1) 792 self.assertEqual(1, proto2.optional_int32) 793 794 proto1.repeated_int32.append(2) 795 proto1.repeated_int32.append(3) 796 container = copy.deepcopy(proto1.repeated_int32) 797 self.assertEqual([2, 3], container) 798 container.remove(container[0]) 799 self.assertEqual([3], container) 800 801 message1 = proto1.repeated_nested_message.add() 802 message1.bb = 1 803 messages = copy.deepcopy(proto1.repeated_nested_message) 804 self.assertEqual(proto1.repeated_nested_message, messages) 805 message1.bb = 2 806 self.assertNotEqual(proto1.repeated_nested_message, messages) 807 messages.remove(messages[0]) 808 self.assertEqual(len(messages), 0) 809 810 # TODO(anuraag): Implement deepcopy for extension dict 811 812 def testDisconnectingBeforeClear(self, message_module): 813 proto = message_module.TestAllTypes() 814 nested = proto.optional_nested_message 815 proto.Clear() 816 self.assertIsNot(nested, proto.optional_nested_message) 817 nested.bb = 23 818 self.assertFalse(proto.HasField('optional_nested_message')) 819 self.assertEqual(0, 
proto.optional_nested_message.bb) 820 821 proto = message_module.TestAllTypes() 822 nested = proto.optional_nested_message 823 nested.bb = 5 824 foreign = proto.optional_foreign_message 825 foreign.c = 6 826 proto.Clear() 827 self.assertIsNot(nested, proto.optional_nested_message) 828 self.assertIsNot(foreign, proto.optional_foreign_message) 829 self.assertEqual(5, nested.bb) 830 self.assertEqual(6, foreign.c) 831 nested.bb = 15 832 foreign.c = 16 833 self.assertFalse(proto.HasField('optional_nested_message')) 834 self.assertEqual(0, proto.optional_nested_message.bb) 835 self.assertFalse(proto.HasField('optional_foreign_message')) 836 self.assertEqual(0, proto.optional_foreign_message.c) 837 838 def testStringUTF8Encoding(self, message_module): 839 proto = message_module.TestAllTypes() 840 841 # Assignment of a unicode object to a field of type 'bytes' is not allowed. 842 self.assertRaises(TypeError, 843 setattr, proto, 'optional_bytes', u'unicode object') 844 845 # Check that the default value is of python's 'unicode' type. 846 self.assertEqual(type(proto.optional_string), str) 847 848 proto.optional_string = str('Testing') 849 self.assertEqual(proto.optional_string, str('Testing')) 850 851 # Assign a value of type 'str' which can be encoded in UTF-8. 852 proto.optional_string = str('Testing') 853 self.assertEqual(proto.optional_string, str('Testing')) 854 855 # Try to assign a 'bytes' object which contains non-UTF-8. 856 self.assertRaises(ValueError, 857 setattr, proto, 'optional_string', b'a\x80a') 858 # No exception: Assign already encoded UTF-8 bytes to a string field. 859 utf8_bytes = u'Тест'.encode('utf-8') 860 proto.optional_string = utf8_bytes 861 # No exception: Assign the a non-ascii unicode object. 862 proto.optional_string = u'Тест' 863 # No exception thrown (normal str assignment containing ASCII). 
864 proto.optional_string = 'abc' 865 866 def testBytesInTextFormat(self, message_module): 867 proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') 868 self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', str(proto)) 869 870 def testEmptyNestedMessage(self, message_module): 871 proto = message_module.TestAllTypes() 872 proto.optional_nested_message.MergeFrom( 873 message_module.TestAllTypes.NestedMessage()) 874 self.assertTrue(proto.HasField('optional_nested_message')) 875 876 proto = message_module.TestAllTypes() 877 proto.optional_nested_message.CopyFrom( 878 message_module.TestAllTypes.NestedMessage()) 879 self.assertTrue(proto.HasField('optional_nested_message')) 880 881 proto = message_module.TestAllTypes() 882 bytes_read = proto.optional_nested_message.MergeFromString(b'') 883 self.assertEqual(0, bytes_read) 884 self.assertTrue(proto.HasField('optional_nested_message')) 885 886 proto = message_module.TestAllTypes() 887 proto.optional_nested_message.ParseFromString(b'') 888 self.assertTrue(proto.HasField('optional_nested_message')) 889 890 serialized = proto.SerializeToString() 891 proto2 = message_module.TestAllTypes() 892 self.assertEqual( 893 len(serialized), 894 proto2.MergeFromString(serialized)) 895 self.assertTrue(proto2.HasField('optional_nested_message')) 896 897 898# Class to test proto2-only features (required, extensions, etc.) 899@testing_refleaks.TestCase 900class Proto2ReflectionTest(unittest.TestCase): 901 902 def testRepeatedCompositeConstructor(self): 903 # Constructor with only repeated composite types should succeed. 
904 proto = unittest_pb2.TestAllTypes( 905 repeated_nested_message=[ 906 unittest_pb2.TestAllTypes.NestedMessage( 907 bb=unittest_pb2.TestAllTypes.FOO), 908 unittest_pb2.TestAllTypes.NestedMessage( 909 bb=unittest_pb2.TestAllTypes.BAR)], 910 repeated_foreign_message=[ 911 unittest_pb2.ForeignMessage(c=-43), 912 unittest_pb2.ForeignMessage(c=45324), 913 unittest_pb2.ForeignMessage(c=12)], 914 repeatedgroup=[ 915 unittest_pb2.TestAllTypes.RepeatedGroup(), 916 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 917 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) 918 919 self.assertEqual( 920 [unittest_pb2.TestAllTypes.NestedMessage( 921 bb=unittest_pb2.TestAllTypes.FOO), 922 unittest_pb2.TestAllTypes.NestedMessage( 923 bb=unittest_pb2.TestAllTypes.BAR)], 924 list(proto.repeated_nested_message)) 925 self.assertEqual( 926 [unittest_pb2.ForeignMessage(c=-43), 927 unittest_pb2.ForeignMessage(c=45324), 928 unittest_pb2.ForeignMessage(c=12)], 929 list(proto.repeated_foreign_message)) 930 self.assertEqual( 931 [unittest_pb2.TestAllTypes.RepeatedGroup(), 932 unittest_pb2.TestAllTypes.RepeatedGroup(a=1), 933 unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], 934 list(proto.repeatedgroup)) 935 936 def assertListsEqual(self, values, others): 937 self.assertEqual(len(values), len(others)) 938 for i in range(len(values)): 939 self.assertEqual(values[i], others[i]) 940 941 def testSimpleHasBits(self): 942 # Test a scalar. 943 proto = unittest_pb2.TestAllTypes() 944 self.assertFalse(proto.HasField('optional_int32')) 945 self.assertEqual(0, proto.optional_int32) 946 # HasField() shouldn't be true if all we've done is 947 # read the default value. 948 self.assertFalse(proto.HasField('optional_int32')) 949 proto.optional_int32 = 1 950 # Setting a value however *should* set the "has" bit. 951 self.assertTrue(proto.HasField('optional_int32')) 952 proto.ClearField('optional_int32') 953 # And clearing that value should unset the "has" bit. 
954 self.assertFalse(proto.HasField('optional_int32')) 955 956 def testHasBitsWithSinglyNestedScalar(self): 957 # Helper used to test foreign messages and groups. 958 # 959 # composite_field_name should be the name of a non-repeated 960 # composite (i.e., foreign or group) field in TestAllTypes, 961 # and scalar_field_name should be the name of an integer-valued 962 # scalar field within that composite. 963 # 964 # I never thought I'd miss C++ macros and templates so much. :( 965 # This helper is semantically just: 966 # 967 # assert proto.composite_field.scalar_field == 0 968 # assert not proto.composite_field.HasField('scalar_field') 969 # assert not proto.HasField('composite_field') 970 # 971 # proto.composite_field.scalar_field = 10 972 # old_composite_field = proto.composite_field 973 # 974 # assert proto.composite_field.scalar_field == 10 975 # assert proto.composite_field.HasField('scalar_field') 976 # assert proto.HasField('composite_field') 977 # 978 # proto.ClearField('composite_field') 979 # 980 # assert not proto.composite_field.HasField('scalar_field') 981 # assert not proto.HasField('composite_field') 982 # assert proto.composite_field.scalar_field == 0 983 # 984 # # Now ensure that ClearField('composite_field') disconnected 985 # # the old field object from the object tree... 986 # assert old_composite_field is not proto.composite_field 987 # old_composite_field.scalar_field = 20 988 # assert not proto.composite_field.HasField('scalar_field') 989 # assert not proto.HasField('composite_field') 990 def TestCompositeHasBits(composite_field_name, scalar_field_name): 991 proto = unittest_pb2.TestAllTypes() 992 # First, check that we can get the scalar value, and see that it's the 993 # default (0), but that proto.HasField('omposite') and 994 # proto.composite.HasField('scalar') will still return False. 
995 composite_field = getattr(proto, composite_field_name) 996 original_scalar_value = getattr(composite_field, scalar_field_name) 997 self.assertEqual(0, original_scalar_value) 998 # Assert that the composite object does not "have" the scalar. 999 self.assertFalse(composite_field.HasField(scalar_field_name)) 1000 # Assert that proto does not "have" the composite field. 1001 self.assertFalse(proto.HasField(composite_field_name)) 1002 1003 # Now set the scalar within the composite field. Ensure that the setting 1004 # is reflected, and that proto.HasField('composite') and 1005 # proto.composite.HasField('scalar') now both return True. 1006 new_val = 20 1007 setattr(composite_field, scalar_field_name, new_val) 1008 self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) 1009 # Hold on to a reference to the current composite_field object. 1010 old_composite_field = composite_field 1011 # Assert that the has methods now return true. 1012 self.assertTrue(composite_field.HasField(scalar_field_name)) 1013 self.assertTrue(proto.HasField(composite_field_name)) 1014 1015 # Now call the clear method... 1016 proto.ClearField(composite_field_name) 1017 1018 # ...and ensure that the "has" bits are all back to False... 1019 composite_field = getattr(proto, composite_field_name) 1020 self.assertFalse(composite_field.HasField(scalar_field_name)) 1021 self.assertFalse(proto.HasField(composite_field_name)) 1022 # ...and ensure that the scalar field has returned to its default. 1023 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1024 1025 self.assertIsNot(old_composite_field, composite_field) 1026 setattr(old_composite_field, scalar_field_name, new_val) 1027 self.assertFalse(composite_field.HasField(scalar_field_name)) 1028 self.assertFalse(proto.HasField(composite_field_name)) 1029 self.assertEqual(0, getattr(composite_field, scalar_field_name)) 1030 1031 # Test simple, single-level nesting when we set a scalar. 
1032 TestCompositeHasBits('optionalgroup', 'a') 1033 TestCompositeHasBits('optional_nested_message', 'bb') 1034 TestCompositeHasBits('optional_foreign_message', 'c') 1035 TestCompositeHasBits('optional_import_message', 'd') 1036 1037 def testHasBitsWhenModifyingRepeatedFields(self): 1038 # Test nesting when we add an element to a repeated field in a submessage. 1039 proto = unittest_pb2.TestNestedMessageHasBits() 1040 proto.optional_nested_message.nestedmessage_repeated_int32.append(5) 1041 self.assertEqual( 1042 [5], proto.optional_nested_message.nestedmessage_repeated_int32) 1043 self.assertTrue(proto.HasField('optional_nested_message')) 1044 1045 # Do the same test, but with a repeated composite field within the 1046 # submessage. 1047 proto.ClearField('optional_nested_message') 1048 self.assertFalse(proto.HasField('optional_nested_message')) 1049 proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() 1050 self.assertTrue(proto.HasField('optional_nested_message')) 1051 1052 def testHasBitsForManyLevelsOfNesting(self): 1053 # Test nesting many levels deep. 
1054 recursive_proto = unittest_pb2.TestMutualRecursionA() 1055 self.assertFalse(recursive_proto.HasField('bb')) 1056 self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) 1057 self.assertFalse(recursive_proto.HasField('bb')) 1058 recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 1059 self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) 1060 self.assertTrue(recursive_proto.HasField('bb')) 1061 self.assertTrue(recursive_proto.bb.HasField('a')) 1062 self.assertTrue(recursive_proto.bb.a.HasField('bb')) 1063 self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) 1064 self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) 1065 self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a')) 1066 self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) 1067 1068 def testSingularListExtensions(self): 1069 proto = unittest_pb2.TestAllExtensions() 1070 proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 1071 proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 1072 proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' 1073 self.assertEqual( 1074 [ (unittest_pb2.optional_int32_extension , 5), 1075 (unittest_pb2.optional_fixed32_extension, 1), 1076 (unittest_pb2.optional_string_extension , 'foo') ], 1077 proto.ListFields()) 1078 del proto.Extensions[unittest_pb2.optional_fixed32_extension] 1079 self.assertEqual( 1080 [(unittest_pb2.optional_int32_extension, 5), 1081 (unittest_pb2.optional_string_extension, 'foo')], 1082 proto.ListFields()) 1083 1084 def testRepeatedListExtensions(self): 1085 proto = unittest_pb2.TestAllExtensions() 1086 proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) 1087 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) 1088 proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) 1089 proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') 1090 proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') 1091 
proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') 1092 proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 1093 self.assertEqual( 1094 [ (unittest_pb2.optional_int32_extension , 21), 1095 (unittest_pb2.repeated_int32_extension , [5, 11]), 1096 (unittest_pb2.repeated_fixed32_extension, [1]), 1097 (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], 1098 proto.ListFields()) 1099 del proto.Extensions[unittest_pb2.repeated_int32_extension] 1100 del proto.Extensions[unittest_pb2.repeated_string_extension] 1101 self.assertEqual( 1102 [(unittest_pb2.optional_int32_extension, 21), 1103 (unittest_pb2.repeated_fixed32_extension, [1])], 1104 proto.ListFields()) 1105 1106 def testListFieldsAndExtensions(self): 1107 proto = unittest_pb2.TestFieldOrderings() 1108 test_util.SetAllFieldsAndExtensions(proto) 1109 unittest_pb2.my_extension_int 1110 self.assertEqual( 1111 [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), 1112 (unittest_pb2.my_extension_int , 23), 1113 (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), 1114 (unittest_pb2.my_extension_string , 'bar'), 1115 (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], 1116 proto.ListFields()) 1117 1118 def testDefaultValues(self): 1119 proto = unittest_pb2.TestAllTypes() 1120 self.assertEqual(0, proto.optional_int32) 1121 self.assertEqual(0, proto.optional_int64) 1122 self.assertEqual(0, proto.optional_uint32) 1123 self.assertEqual(0, proto.optional_uint64) 1124 self.assertEqual(0, proto.optional_sint32) 1125 self.assertEqual(0, proto.optional_sint64) 1126 self.assertEqual(0, proto.optional_fixed32) 1127 self.assertEqual(0, proto.optional_fixed64) 1128 self.assertEqual(0, proto.optional_sfixed32) 1129 self.assertEqual(0, proto.optional_sfixed64) 1130 self.assertEqual(0.0, proto.optional_float) 1131 self.assertEqual(0.0, proto.optional_double) 1132 self.assertEqual(False, proto.optional_bool) 1133 self.assertEqual('', proto.optional_string) 1134 self.assertEqual(b'', 
proto.optional_bytes) 1135 1136 self.assertEqual(41, proto.default_int32) 1137 self.assertEqual(42, proto.default_int64) 1138 self.assertEqual(43, proto.default_uint32) 1139 self.assertEqual(44, proto.default_uint64) 1140 self.assertEqual(-45, proto.default_sint32) 1141 self.assertEqual(46, proto.default_sint64) 1142 self.assertEqual(47, proto.default_fixed32) 1143 self.assertEqual(48, proto.default_fixed64) 1144 self.assertEqual(49, proto.default_sfixed32) 1145 self.assertEqual(-50, proto.default_sfixed64) 1146 self.assertEqual(51.5, proto.default_float) 1147 self.assertEqual(52e3, proto.default_double) 1148 self.assertEqual(True, proto.default_bool) 1149 self.assertEqual('hello', proto.default_string) 1150 self.assertEqual(b'world', proto.default_bytes) 1151 self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) 1152 self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) 1153 self.assertEqual(unittest_import_pb2.IMPORT_BAR, 1154 proto.default_import_enum) 1155 1156 proto = unittest_pb2.TestExtremeDefaultValues() 1157 self.assertEqual(u'\u1234', proto.utf8_string) 1158 1159 def testHasFieldWithUnknownFieldName(self): 1160 proto = unittest_pb2.TestAllTypes() 1161 self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') 1162 1163 def testClearRemovesChildren(self): 1164 # Make sure there aren't any implementation bugs that are only partially 1165 # clearing the message (which can happen in the more complex C++ 1166 # implementation which has parallel message lists). 1167 proto = unittest_pb2.TestRequiredForeign() 1168 for i in range(10): 1169 proto.repeated_message.add() 1170 proto2 = unittest_pb2.TestRequiredForeign() 1171 proto.CopyFrom(proto2) 1172 self.assertRaises(IndexError, lambda: proto.repeated_message[5]) 1173 1174 def testSingleScalarClearField(self): 1175 proto = unittest_pb2.TestAllTypes() 1176 # Should be allowed to clear something that's not there (a no-op). 
1177 proto.ClearField('optional_int32') 1178 proto.optional_int32 = 1 1179 self.assertTrue(proto.HasField('optional_int32')) 1180 proto.ClearField('optional_int32') 1181 self.assertEqual(0, proto.optional_int32) 1182 self.assertFalse(proto.HasField('optional_int32')) 1183 # TODO(robinson): Test all other scalar field types. 1184 1185 def testRepeatedScalars(self): 1186 proto = unittest_pb2.TestAllTypes() 1187 1188 self.assertFalse(proto.repeated_int32) 1189 self.assertEqual(0, len(proto.repeated_int32)) 1190 proto.repeated_int32.append(5) 1191 proto.repeated_int32.append(10) 1192 proto.repeated_int32.append(15) 1193 self.assertTrue(proto.repeated_int32) 1194 self.assertEqual(3, len(proto.repeated_int32)) 1195 1196 self.assertEqual([5, 10, 15], proto.repeated_int32) 1197 1198 # Test single retrieval. 1199 self.assertEqual(5, proto.repeated_int32[0]) 1200 self.assertEqual(15, proto.repeated_int32[-1]) 1201 # Test out-of-bounds indices. 1202 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) 1203 self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) 1204 # Test incorrect types passed to __getitem__. 1205 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') 1206 self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) 1207 1208 # Test single assignment. 1209 proto.repeated_int32[1] = 20 1210 self.assertEqual([5, 20, 15], proto.repeated_int32) 1211 1212 # Test insertion. 1213 proto.repeated_int32.insert(1, 25) 1214 self.assertEqual([5, 25, 20, 15], proto.repeated_int32) 1215 1216 # Test slice retrieval. 1217 proto.repeated_int32.append(30) 1218 self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) 1219 self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) 1220 1221 # Test slice assignment with an iterator 1222 proto.repeated_int32[1:4] = (i for i in range(3)) 1223 self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) 1224 1225 # Test slice assignment. 
1226 proto.repeated_int32[1:4] = [35, 40, 45] 1227 self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) 1228 1229 # Test that we can use the field as an iterator. 1230 result = [] 1231 for i in proto.repeated_int32: 1232 result.append(i) 1233 self.assertEqual([5, 35, 40, 45, 30], result) 1234 1235 # Test single deletion. 1236 del proto.repeated_int32[2] 1237 self.assertEqual([5, 35, 45, 30], proto.repeated_int32) 1238 1239 # Test slice deletion. 1240 del proto.repeated_int32[2:] 1241 self.assertEqual([5, 35], proto.repeated_int32) 1242 1243 # Test extending. 1244 proto.repeated_int32.extend([3, 13]) 1245 self.assertEqual([5, 35, 3, 13], proto.repeated_int32) 1246 1247 # Test clearing. 1248 proto.ClearField('repeated_int32') 1249 self.assertFalse(proto.repeated_int32) 1250 self.assertEqual(0, len(proto.repeated_int32)) 1251 1252 proto.repeated_int32.append(1) 1253 self.assertEqual(1, proto.repeated_int32[-1]) 1254 # Test assignment to a negative index. 1255 proto.repeated_int32[-1] = 2 1256 self.assertEqual(2, proto.repeated_int32[-1]) 1257 1258 # Test deletion at negative indices. 
1259 proto.repeated_int32[:] = [0, 1, 2, 3] 1260 del proto.repeated_int32[-1] 1261 self.assertEqual([0, 1, 2], proto.repeated_int32) 1262 1263 del proto.repeated_int32[-2] 1264 self.assertEqual([0, 2], proto.repeated_int32) 1265 1266 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) 1267 self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) 1268 1269 del proto.repeated_int32[-2:-1] 1270 self.assertEqual([2], proto.repeated_int32) 1271 1272 del proto.repeated_int32[100:10000] 1273 self.assertEqual([2], proto.repeated_int32) 1274 1275 def testRepeatedScalarsRemove(self): 1276 proto = unittest_pb2.TestAllTypes() 1277 1278 self.assertFalse(proto.repeated_int32) 1279 self.assertEqual(0, len(proto.repeated_int32)) 1280 proto.repeated_int32.append(5) 1281 proto.repeated_int32.append(10) 1282 proto.repeated_int32.append(5) 1283 proto.repeated_int32.append(5) 1284 1285 self.assertEqual(4, len(proto.repeated_int32)) 1286 proto.repeated_int32.remove(5) 1287 self.assertEqual(3, len(proto.repeated_int32)) 1288 self.assertEqual(10, proto.repeated_int32[0]) 1289 self.assertEqual(5, proto.repeated_int32[1]) 1290 self.assertEqual(5, proto.repeated_int32[2]) 1291 1292 proto.repeated_int32.remove(5) 1293 self.assertEqual(2, len(proto.repeated_int32)) 1294 self.assertEqual(10, proto.repeated_int32[0]) 1295 self.assertEqual(5, proto.repeated_int32[1]) 1296 1297 proto.repeated_int32.remove(10) 1298 self.assertEqual(1, len(proto.repeated_int32)) 1299 self.assertEqual(5, proto.repeated_int32[0]) 1300 1301 # Remove a non-existent element. 
1302 self.assertRaises(ValueError, proto.repeated_int32.remove, 123) 1303 1304 def testRepeatedScalarsReverse_Empty(self): 1305 proto = unittest_pb2.TestAllTypes() 1306 1307 self.assertFalse(proto.repeated_int32) 1308 self.assertEqual(0, len(proto.repeated_int32)) 1309 1310 self.assertIsNone(proto.repeated_int32.reverse()) 1311 1312 self.assertFalse(proto.repeated_int32) 1313 self.assertEqual(0, len(proto.repeated_int32)) 1314 1315 def testRepeatedScalarsReverse_NonEmpty(self): 1316 proto = unittest_pb2.TestAllTypes() 1317 1318 self.assertFalse(proto.repeated_int32) 1319 self.assertEqual(0, len(proto.repeated_int32)) 1320 1321 proto.repeated_int32.append(1) 1322 proto.repeated_int32.append(2) 1323 proto.repeated_int32.append(3) 1324 proto.repeated_int32.append(4) 1325 1326 self.assertEqual(4, len(proto.repeated_int32)) 1327 1328 self.assertIsNone(proto.repeated_int32.reverse()) 1329 1330 self.assertEqual(4, len(proto.repeated_int32)) 1331 self.assertEqual(4, proto.repeated_int32[0]) 1332 self.assertEqual(3, proto.repeated_int32[1]) 1333 self.assertEqual(2, proto.repeated_int32[2]) 1334 self.assertEqual(1, proto.repeated_int32[3]) 1335 1336 def testRepeatedComposites(self): 1337 proto = unittest_pb2.TestAllTypes() 1338 self.assertFalse(proto.repeated_nested_message) 1339 self.assertEqual(0, len(proto.repeated_nested_message)) 1340 m0 = proto.repeated_nested_message.add() 1341 m1 = proto.repeated_nested_message.add() 1342 self.assertTrue(proto.repeated_nested_message) 1343 self.assertEqual(2, len(proto.repeated_nested_message)) 1344 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1345 self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage) 1346 1347 # Test out-of-bounds indices. 1348 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1349 1234) 1350 self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, 1351 -1234) 1352 1353 # Test incorrect types passed to __getitem__. 
1354 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1355 'foo') 1356 self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, 1357 None) 1358 1359 # Test slice retrieval. 1360 m2 = proto.repeated_nested_message.add() 1361 m3 = proto.repeated_nested_message.add() 1362 m4 = proto.repeated_nested_message.add() 1363 self.assertListsEqual( 1364 [m1, m2, m3], proto.repeated_nested_message[1:4]) 1365 self.assertListsEqual( 1366 [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) 1367 self.assertListsEqual( 1368 [m0, m1], proto.repeated_nested_message[:2]) 1369 self.assertListsEqual( 1370 [m2, m3, m4], proto.repeated_nested_message[2:]) 1371 self.assertEqual( 1372 m0, proto.repeated_nested_message[0]) 1373 self.assertListsEqual( 1374 [m0], proto.repeated_nested_message[:1]) 1375 1376 # Test that we can use the field as an iterator. 1377 result = [] 1378 for i in proto.repeated_nested_message: 1379 result.append(i) 1380 self.assertListsEqual([m0, m1, m2, m3, m4], result) 1381 1382 # Test single deletion. 1383 del proto.repeated_nested_message[2] 1384 self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) 1385 1386 # Test slice deletion. 1387 del proto.repeated_nested_message[2:] 1388 self.assertListsEqual([m0, m1], proto.repeated_nested_message) 1389 1390 # Test extending. 
1391 n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) 1392 n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) 1393 proto.repeated_nested_message.extend([n1,n2]) 1394 self.assertEqual(4, len(proto.repeated_nested_message)) 1395 self.assertEqual(n1, proto.repeated_nested_message[2]) 1396 self.assertEqual(n2, proto.repeated_nested_message[3]) 1397 self.assertRaises(TypeError, 1398 proto.repeated_nested_message.extend, n1) 1399 self.assertRaises(TypeError, 1400 proto.repeated_nested_message.extend, [0]) 1401 wrong_message_type = unittest_pb2.TestAllTypes() 1402 self.assertRaises(TypeError, 1403 proto.repeated_nested_message.extend, 1404 [wrong_message_type]) 1405 1406 # Test clearing. 1407 proto.ClearField('repeated_nested_message') 1408 self.assertFalse(proto.repeated_nested_message) 1409 self.assertEqual(0, len(proto.repeated_nested_message)) 1410 1411 # Test constructing an element while adding it. 1412 proto.repeated_nested_message.add(bb=23) 1413 self.assertEqual(1, len(proto.repeated_nested_message)) 1414 self.assertEqual(23, proto.repeated_nested_message[0].bb) 1415 self.assertRaises(TypeError, proto.repeated_nested_message.add, 23) 1416 with self.assertRaises(Exception): 1417 proto.repeated_nested_message[0] = 23 1418 1419 def testRepeatedCompositeRemove(self): 1420 proto = unittest_pb2.TestAllTypes() 1421 1422 self.assertEqual(0, len(proto.repeated_nested_message)) 1423 m0 = proto.repeated_nested_message.add() 1424 # Need to set some differentiating variable so m0 != m1 != m2: 1425 m0.bb = len(proto.repeated_nested_message) 1426 m1 = proto.repeated_nested_message.add() 1427 m1.bb = len(proto.repeated_nested_message) 1428 self.assertTrue(m0 != m1) 1429 m2 = proto.repeated_nested_message.add() 1430 m2.bb = len(proto.repeated_nested_message) 1431 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1432 1433 self.assertEqual(3, len(proto.repeated_nested_message)) 1434 proto.repeated_nested_message.remove(m0) 1435 self.assertEqual(2, 
len(proto.repeated_nested_message)) 1436 self.assertEqual(m1, proto.repeated_nested_message[0]) 1437 self.assertEqual(m2, proto.repeated_nested_message[1]) 1438 1439 # Removing m0 again or removing None should raise error 1440 self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) 1441 self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) 1442 self.assertEqual(2, len(proto.repeated_nested_message)) 1443 1444 proto.repeated_nested_message.remove(m2) 1445 self.assertEqual(1, len(proto.repeated_nested_message)) 1446 self.assertEqual(m1, proto.repeated_nested_message[0]) 1447 1448 def testRepeatedCompositeReverse_Empty(self): 1449 proto = unittest_pb2.TestAllTypes() 1450 1451 self.assertFalse(proto.repeated_nested_message) 1452 self.assertEqual(0, len(proto.repeated_nested_message)) 1453 1454 self.assertIsNone(proto.repeated_nested_message.reverse()) 1455 1456 self.assertFalse(proto.repeated_nested_message) 1457 self.assertEqual(0, len(proto.repeated_nested_message)) 1458 1459 def testRepeatedCompositeReverse_NonEmpty(self): 1460 proto = unittest_pb2.TestAllTypes() 1461 1462 self.assertFalse(proto.repeated_nested_message) 1463 self.assertEqual(0, len(proto.repeated_nested_message)) 1464 1465 m0 = proto.repeated_nested_message.add() 1466 m0.bb = len(proto.repeated_nested_message) 1467 m1 = proto.repeated_nested_message.add() 1468 m1.bb = len(proto.repeated_nested_message) 1469 m2 = proto.repeated_nested_message.add() 1470 m2.bb = len(proto.repeated_nested_message) 1471 self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) 1472 1473 self.assertIsNone(proto.repeated_nested_message.reverse()) 1474 1475 self.assertListsEqual([m2, m1, m0], proto.repeated_nested_message) 1476 1477 def testHandWrittenReflection(self): 1478 # Hand written extensions are only supported by the pure-Python 1479 # implementation of the API. 
1480 if api_implementation.Type() != 'python': 1481 return 1482 1483 FieldDescriptor = descriptor.FieldDescriptor 1484 foo_field_descriptor = FieldDescriptor( 1485 name='foo_field', full_name='MyProto.foo_field', 1486 index=0, number=1, type=FieldDescriptor.TYPE_INT64, 1487 cpp_type=FieldDescriptor.CPPTYPE_INT64, 1488 label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, 1489 containing_type=None, message_type=None, enum_type=None, 1490 is_extension=False, extension_scope=None, 1491 options=descriptor_pb2.FieldOptions(), 1492 # pylint: disable=protected-access 1493 create_key=descriptor._internal_create_key) 1494 mydescriptor = descriptor.Descriptor( 1495 name='MyProto', full_name='MyProto', filename='ignored', 1496 containing_type=None, nested_types=[], enum_types=[], 1497 fields=[foo_field_descriptor], extensions=[], 1498 options=descriptor_pb2.MessageOptions(), 1499 # pylint: disable=protected-access 1500 create_key=descriptor._internal_create_key) 1501 1502 class MyProtoClass( 1503 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 1504 DESCRIPTOR = mydescriptor 1505 myproto_instance = MyProtoClass() 1506 self.assertEqual(0, myproto_instance.foo_field) 1507 self.assertFalse(myproto_instance.HasField('foo_field')) 1508 myproto_instance.foo_field = 23 1509 self.assertEqual(23, myproto_instance.foo_field) 1510 self.assertTrue(myproto_instance.HasField('foo_field')) 1511 1512 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 1513 def testDescriptorProtoSupport(self): 1514 # Hand written descriptors/reflection are only supported by the pure-Python 1515 # implementation of the API. 
1516 if api_implementation.Type() != 'python': 1517 return 1518 1519 def AddDescriptorField(proto, field_name, field_type): 1520 AddDescriptorField.field_index += 1 1521 new_field = proto.field.add() 1522 new_field.name = field_name 1523 new_field.type = field_type 1524 new_field.number = AddDescriptorField.field_index 1525 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL 1526 1527 AddDescriptorField.field_index = 0 1528 1529 desc_proto = descriptor_pb2.DescriptorProto() 1530 desc_proto.name = 'Car' 1531 fdp = descriptor_pb2.FieldDescriptorProto 1532 AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) 1533 AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) 1534 AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) 1535 AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) 1536 # Add a repeated field 1537 AddDescriptorField.field_index += 1 1538 new_field = desc_proto.field.add() 1539 new_field.name = 'owners' 1540 new_field.type = fdp.TYPE_STRING 1541 new_field.number = AddDescriptorField.field_index 1542 new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED 1543 1544 desc = descriptor.MakeDescriptor(desc_proto) 1545 self.assertTrue('name' in desc.fields_by_name) 1546 self.assertTrue('year' in desc.fields_by_name) 1547 self.assertTrue('automatic' in desc.fields_by_name) 1548 self.assertTrue('price' in desc.fields_by_name) 1549 self.assertTrue('owners' in desc.fields_by_name) 1550 1551 class CarMessage( 1552 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 1553 DESCRIPTOR = desc 1554 1555 prius = CarMessage() 1556 prius.name = 'prius' 1557 prius.year = 2010 1558 prius.automatic = True 1559 prius.price = 25134.75 1560 prius.owners.extend(['bob', 'susan']) 1561 1562 serialized_prius = prius.SerializeToString() 1563 new_prius = reflection.ParseMessage(desc, serialized_prius) 1564 self.assertIsNot(new_prius, prius) 1565 self.assertEqual(prius, new_prius) 1566 1567 # these are unnecessary assuming 
# message equality works as advertised but
    # explicitly check to be safe since we're mucking about in metaclass foo
    self.assertEqual(prius.name, new_prius.name)
    self.assertEqual(prius.year, new_prius.year)
    self.assertEqual(prius.automatic, new_prius.automatic)
    self.assertEqual(prius.price, new_prius.price)
    self.assertEqual(prius.owners, new_prius.owners)

  def testExtensionDelete(self):
    """`del m.Extensions[ext]` removes set extensions and ignores absent ones."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()

    extension_int32 = more_extensions_pb2.optional_int_extension
    extendee_proto.Extensions[extension_int32] = 23

    extension_repeated = more_extensions_pb2.repeated_int_extension
    extendee_proto.Extensions[extension_repeated].append(11)

    extension_msg = more_extensions_pb2.optional_message_extension
    extendee_proto.Extensions[extension_msg].foreign_message_int = 56

    self.assertEqual(len(extendee_proto.Extensions), 3)
    del extendee_proto.Extensions[extension_msg]
    self.assertEqual(len(extendee_proto.Extensions), 2)
    del extendee_proto.Extensions[extension_repeated]
    self.assertEqual(len(extendee_proto.Extensions), 1)
    # Delete a non-existent extension. It is OK to "del m.Extensions[ext]"
    # even if the extension is not present in the message; we don't
    # raise KeyError. This is consistent with "m.Extensions[ext]"
    # returning a default value even if we did not set anything.
    del extendee_proto.Extensions[extension_repeated]
    self.assertEqual(len(extendee_proto.Extensions), 1)
    del extendee_proto.Extensions[extension_int32]
    self.assertEqual(len(extendee_proto.Extensions), 0)

  def testExtensionIter(self):
    """Iterating Extensions yields exactly the set extensions, in the expected order."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()

    extension_int32 = more_extensions_pb2.optional_int_extension
    extendee_proto.Extensions[extension_int32] = 23

    extension_repeated = more_extensions_pb2.repeated_int_extension
    extendee_proto.Extensions[extension_repeated].append(11)

    extension_msg = more_extensions_pb2.optional_message_extension
    extendee_proto.Extensions[extension_msg].foreign_message_int = 56

    # Set some normal fields; they must not show up when iterating Extensions
    # (the count below stays at 3).
    extendee_proto.optional_int32 = 1
    extendee_proto.repeated_string.append('hi')

    expected = (extension_int32, extension_msg, extension_repeated)
    count = 0
    for item in extendee_proto.Extensions:
      self.assertEqual(item.name, expected[count].name)
      self.assertIn(item, extendee_proto.Extensions)
      count += 1
    self.assertEqual(count, 3)

  def testExtensionContainsError(self):
    """`x in m.Extensions` raises KeyError when x is not an extension handle."""
    extendee_proto = more_extensions_pb2.ExtendedMessage()
    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0)

    # A regular (non-extension) field descriptor is also rejected.
    field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[
        'optional_int32']
    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field)

  def testTopLevelExtensionsForOptionalScalar(self):
    """Optional scalar extension: default on read, has-bit on write, clear."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_int32_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    # As with normal scalar fields, just doing a read doesn't actually set the
    # "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    # Actually set the thing.
    extendee_proto.Extensions[extension] = 23
    self.assertEqual(23, extendee_proto.Extensions[extension])
    self.assertTrue(extendee_proto.HasExtension(extension))
    self.assertIn(extension, extendee_proto.Extensions)
    # Ensure that clearing works as well.
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)

  def testTopLevelExtensionsForRepeatedScalar(self):
    """Repeated scalar extension: append, clear, and no direct assignment."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeated_string_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertNotIn(extension, extendee_proto.Extensions)
    extendee_proto.Extensions[extension].append('foo')
    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
    self.assertIn(extension, extendee_proto.Extensions)
    string_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertNotIn(extension, extendee_proto.Extensions)
    # Clearing hands out a fresh container; the old reference is detached.
    self.assertIsNot(string_list, extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForOptionalMessage(self):
    """Optional message extension: lazy read, has-bit on write, detach on clear."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_foreign_message_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    self.assertEqual(0, extendee_proto.Extensions[extension].c)
    # As with normal (non-extension) fields, merely reading from the
    # thing shouldn't set the "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    extendee_proto.Extensions[extension].c = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].c)
    self.assertTrue(extendee_proto.HasExtension(extension))
    self.assertIn(extension, extendee_proto.Extensions)
    # Save a reference here.
    foreign_message = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertIsNot(foreign_message, extendee_proto.Extensions[extension])
    # Setting a field on foreign_message now shouldn't set
    # any "has" bits on extendee_proto.
    foreign_message.c = 42
    self.assertEqual(42, foreign_message.c)
    self.assertTrue(foreign_message.HasField('c'))
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForRepeatedMessage(self):
    """Repeated group extension: add(), clear, and no direct assignment."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeatedgroup_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    group = extendee_proto.Extensions[extension].add()
    group.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
    # The element returned by add() is live: later writes are visible.
    group.a = 42
    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
    group_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertIsNot(group_list, extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testNestedExtensions(self):
    """An extension declared inside a message (TestRequired.single) works."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single

    # We just test the non-repeated case.
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    required = extendee_proto.Extensions[extension]
    self.assertEqual(0, required.a)
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)
    required.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].a)
    self.assertTrue(extendee_proto.HasExtension(extension))
    self.assertIn(extension, extendee_proto.Extensions)
    extendee_proto.ClearExtension(extension)
    self.assertIsNot(required, extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertNotIn(extension, extendee_proto.Extensions)

  def testRegisteredExtensions(self):
    """Generated extensions are findable in unittest_pb2's descriptor pool."""
    pool = unittest_pb2.DESCRIPTOR.pool
    self.assertTrue(
        pool.FindExtensionByNumber(
            unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
    self.assertIs(
        pool.FindExtensionByName(
            'protobuf_unittest.optional_int32_extension').containing_type,
        unittest_pb2.TestAllExtensions.DESCRIPTOR)
    # Make sure extensions haven't been registered into types that shouldn't
    # have any.
    self.assertEqual(0, len(
        pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))

  # If message A directly contains message B, and
  # a.HasField('b') is currently False, then mutating any
  # extension in B should change a.HasField('b') to True
  # (and so on up the object tree).
  def testHasBitsForAncestorsOfExtendedMessage(self):
    """Writing an extension of a submessage sets the parent's has-bit."""
    # Optional scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual([], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension].append(23)
    self.assertEqual([23], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Optional message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, len(toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension]))
    self.assertFalse(toplevel.HasField('submessage'))
    foreign = toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension].add()
    self.assertEqual(foreign, toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension][0])
    self.assertTrue(toplevel.HasField('submessage'))

  def testDisconnectionAfterClearingEmptyMessage(self):
    """After ClearExtension, a previously fetched message stays detached."""
    toplevel = more_extensions_pb2.TopLevelMessage()
    extendee_proto = toplevel.submessage
    extension = more_extensions_pb2.optional_message_extension
    extension_proto = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    # Writing to the stale reference must not reattach it to extendee_proto.
    extension_proto.foreign_message_int = 23

    self.assertIsNot(extension_proto, extendee_proto.Extensions[extension])

  def testExtensionFailureModes(self):
    """Invalid or foreign extension handles raise KeyError everywhere."""
    extendee_proto = unittest_pb2.TestAllExtensions()

    # Try non-extension-handle arguments to HasExtension,
    # ClearExtension(), and Extensions[]...
    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)

    # Try something that *is* an extension handle, just not for
    # this message...
    for unknown_handle in (more_extensions_pb2.optional_int_extension,
                           more_extensions_pb2.optional_message_extension,
                           more_extensions_pb2.repeated_int_extension,
                           more_extensions_pb2.repeated_message_extension):
      self.assertRaises(KeyError, extendee_proto.HasExtension,
                        unknown_handle)
      self.assertRaises(KeyError, extendee_proto.ClearExtension,
                        unknown_handle)
      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
                        unknown_handle)
      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
                        unknown_handle, 5)

    # Try calling HasExtension() with a valid handle, but for a
    # *repeated* field. (Just as with non-extension repeated
    # fields, Has*() isn't supported for extension repeated fields).
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unittest_pb2.repeated_string_extension)

  def testMergeFromOptionalGroup(self):
    """MergeFrom copies an optional group's contents."""
    # Test merge with an optional group.
    proto1 = unittest_pb2.TestAllTypes()
    proto1.optionalgroup.a = 12
    proto2 = unittest_pb2.TestAllTypes()
    proto2.MergeFrom(proto1)
    self.assertEqual(12, proto2.optionalgroup.a)

  def testMergeFromExtensionsSingular(self):
    """MergeFrom copies a set singular extension."""
    proto1 = unittest_pb2.TestAllExtensions()
    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1

    proto2 = unittest_pb2.TestAllExtensions()
    proto2.MergeFrom(proto1)
    self.assertEqual(
        1, proto2.Extensions[unittest_pb2.optional_int32_extension])

  def testMergeFromExtensionsRepeated(self):
    """MergeFrom appends source repeated-extension values after existing ones."""
    proto1 = unittest_pb2.TestAllExtensions()
    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)

    proto2 = unittest_pb2.TestAllExtensions()
    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
    proto2.MergeFrom(proto1)
    self.assertEqual(
        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
    self.assertEqual(
        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
    self.assertEqual(
        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
    self.assertEqual(
        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])

  def testMergeFromExtensionsNestedMessage(self):
    """MergeFrom appends repeated message extensions after existing ones."""
    proto1 = unittest_pb2.TestAllExtensions()
    ext1 = proto1.Extensions[
        unittest_pb2.repeated_nested_message_extension]
    m = ext1.add()
    m.bb = 222
    m = ext1.add()
    m.bb = 333

    proto2 = unittest_pb2.TestAllExtensions()
    ext2 = proto2.Extensions[
        unittest_pb2.repeated_nested_message_extension]
    m = ext2.add()
    m.bb = 111

    proto2.MergeFrom(proto1)
    # Re-read the extension container from proto2 after merging.
    ext2 = proto2.Extensions[
        unittest_pb2.repeated_nested_message_extension]
    self.assertEqual(3, len(ext2))
    self.assertEqual(111, ext2[0].bb)
    self.assertEqual(222, ext2[1].bb)
    self.assertEqual(333, ext2[2].bb)

  def testCopyFromBadType(self):
    """CopyFrom from a different message type raises TypeError (non-python impls)."""
    # The python implementation doesn't raise an exception in this
    # case. In theory it should.
    if api_implementation.Type() == 'python':
      return
    proto1 = unittest_pb2.TestAllTypes()
    proto2 = unittest_pb2.TestAllExtensions()
    self.assertRaises(TypeError, proto1.CopyFrom, proto2)

  def testClear(self):
    """Clear() resets all set fields and extensions to an empty message."""
    proto = unittest_pb2.TestAllTypes()
    # C++ implementation does not support lazy fields right now so leave it
    # out for now.
    if api_implementation.Type() == 'python':
      test_util.SetAllFields(proto)
    else:
      test_util.SetAllNonLazyFields(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(proto.ByteSize(), 0)
    empty_proto = unittest_pb2.TestAllTypes()
    self.assertEqual(proto, empty_proto)

    # Test if extensions which were set are cleared.
    proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(proto.ByteSize(), 0)
    empty_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(proto, empty_proto)

  def testDisconnectingInOneof(self):
    """Mutating another oneof member detaches the previously held submessage."""
    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
    m.foo_message.moo_int = 5
    sub_message = m.foo_message
    # Accessing another message's field does not clear the first one
    self.assertEqual(m.foo_lazy_message.moo_int, 0)
    self.assertEqual(m.foo_message.moo_int, 5)
    # But mutating another message in the oneof detaches the first one.
    m.foo_lazy_message.moo_int = 6
    self.assertEqual(m.foo_message.moo_int, 0)
    # The reference we got above was detached and is still valid.
    self.assertEqual(sub_message.moo_int, 5)
    # Writing to the detached message must still succeed (no exception).
    sub_message.moo_int = 7

  def assertInitialized(self, proto):
    """Assert `proto` is initialized and both serializers succeed on it."""
    self.assertTrue(proto.IsInitialized())
    # Neither method should raise an exception.
    proto.SerializeToString()
    proto.SerializePartialToString()

  def assertNotInitialized(self, proto, error_size=None):
    """Assert `proto` is missing required fields.

    Args:
      proto: the message under test.
      error_size: expected number of entries IsInitialized(errors) appends.
        Callers must pass it: the default None never equals len(errors), so
        omitting it makes the assertion fail.
    """
    errors = []
    self.assertFalse(proto.IsInitialized())
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(error_size, len(errors))
    self.assertRaises(message.EncodeError, proto.SerializeToString)
    # "Partial" serialization doesn't care if message is uninitialized.
    proto.SerializePartialToString()

  def testIsInitialized(self):
    """IsInitialized() tracks required fields in fields, submessages and extensions."""
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertInitialized(proto)
    proto = unittest_pb2.TestAllExtensions()
    self.assertInitialized(proto)

    # The case of uninitialized required fields.
    proto = unittest_pb2.TestRequired()
    self.assertNotInitialized(proto, 3)
    proto.a = proto.b = proto.c = 2
    self.assertInitialized(proto)

    # The case of uninitialized submessage.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertInitialized(proto)
    proto.optional_message.a = 1
    self.assertNotInitialized(proto, 2)
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated submessage.
    message1 = proto.repeated_message.add()
    self.assertNotInitialized(proto, 3)
    message1.a = message1.b = message1.c = 0
    self.assertInitialized(proto)

    # Uninitialized repeated group in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertNotInitialized(proto, 6)
    message1.a = 1
    message1.b = 1
    message1.c = 1
    self.assertNotInitialized(proto, 3)
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertInitialized(proto)

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertNotInitialized(proto, 2)
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertInitialized(proto)

    # Try passing an errors list.
    errors = []
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized(errors))
    self.assertEqual(errors, ['a', 'b', 'c'])
    self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3)

  @unittest.skipIf(
      api_implementation.Type() == 'python',
      'Errors are only available from the most recent C++ implementation.')
  def testFileDescriptorErrors(self):
    """Building a file that redefines an existing symbol raises TypeError."""
    file_name = 'test_file_descriptor_errors.proto'
    package_name = 'test_file_descriptor_errors.proto'
    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
    file_descriptor_proto.name = file_name
    file_descriptor_proto.package = package_name
    m1 = file_descriptor_proto.message_type.add()
    m1.name = 'msg1'
    # Compiles the proto into the C++ descriptor pool
    descriptor.FileDescriptor(
        file_name,
        package_name,
        serialized_pb=file_descriptor_proto.SerializeToString())
    # Add a FileDescriptorProto that has duplicate symbols: the same proto,
    # still containing msg1 in the same package, under a new file name.
    another_file_name = 'another_test_file_descriptor_errors.proto'
    file_descriptor_proto.name = another_file_name
    m2 = file_descriptor_proto.message_type.add()
    m2.name = 'msg2'
    with self.assertRaises(TypeError) as cm:
      descriptor.FileDescriptor(
          another_file_name,
          package_name,
          serialized_pb=file_descriptor_proto.SerializeToString())
      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
                      getattr(cm.expected, '__name__', cm.expected))
      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
      # Error message will say something about this definition being a
      # duplicate, though we don't check the message exactly to avoid a
      # dependency on the C++ logging code.
      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))

  def testStringUTF8Serialization(self):
    """Non-ASCII strings round-trip as UTF-8 through the MessageSet wire format."""
    proto = message_set_extensions_pb2.TestMessageSet()
    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
    extension = extension_message.message_set_extension

    # 'Test' in another language (Cyrillic), using UTF-8 charset.
    test_utf8 = u'Тест'
    test_utf8_bytes = test_utf8.encode('utf-8')

    proto.Extensions[extension].str = test_utf8

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Check byte size.
    self.assertEqual(proto.ByteSize(), len(serialized))

    raw = unittest_mset_pb2.RawMessageSet()
    bytes_read = raw.MergeFromString(serialized)
    self.assertEqual(len(serialized), bytes_read)

    message2 = message_set_extensions_pb2.TestMessageSetExtension2()

    self.assertEqual(1, len(raw.item))
    # Check that the type_id is the same as the tag ID in the .proto file.
    self.assertEqual(raw.item[0].type_id, 98418634)

    # Check the actual bytes on the wire.
    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
    bytes_read = message2.MergeFromString(raw.item[0].message)
    self.assertEqual(len(raw.item[0].message), bytes_read)

    self.assertEqual(type(message2.str), str)
    self.assertEqual(message2.str, test_utf8)

    # The pure Python API throws an exception on MergeFromString(),
    # if any of the string fields of the message can't be UTF-8 decoded.
    # The C++ implementation of the API has no way to check that on
    # MergeFromString and thus has no way to throw the exception.
    #
    # The pure Python API always returns objects of type 'unicode' (UTF-8
    # encoded), or 'bytes' (in 7 bit ASCII).
    # Replace the UTF-8 payload with the same number of 0xff bytes, which is
    # never valid UTF-8.
    badbytes = raw.item[0].message.replace(
        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')

    unicode_decode_failed = False
    try:
      message2.MergeFromString(badbytes)
    except UnicodeDecodeError:
      unicode_decode_failed = True
    string_field = message2.str
    # Either the parse rejected the bytes, or the field surfaces as raw bytes.
    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)

  def testSetInParent(self):
    """SetInParent() marks an empty optional group as present."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optionalgroup'))
    proto.optionalgroup.SetInParent()
    self.assertTrue(proto.HasField('optionalgroup'))

  def testPackageInitializationImport(self):
    """Test that we can import nested messages from their __init__.py.

    Such setup is not trivial since at the time of processing of __init__.py one
    can't refer to its submodules by name in code, so expressions like
    google.protobuf.internal.import_test_package.inner_pb2
    don't work. They do work in imports, so we have to assign an alias at
    import and then use that alias in generated code.
    """
    # We import here since it's the import that used to fail, and we want
    # the failure to have the right context.
    # pylint: disable=g-import-not-at-top
    from google.protobuf.internal import import_test_package
    # pylint: enable=g-import-not-at-top
    msg = import_test_package.myproto.Outer()
    # Just check the default value.
    self.assertEqual(57, msg.inner.value)

# Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes.
@testing_refleaks.TestCase
class TestAllTypesEqualityTest(unittest.TestCase):
  """Equality and hashing tests starting from two empty TestAllTypes protos."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()

  def testNotHashable(self):
    """Messages are mutable, so hash() must raise TypeError."""
    self.assertRaises(TypeError, hash, self.first_proto)

  def testSelfEquality(self):
    self.assertEqual(self.first_proto, self.first_proto)

  def testEmptyProtosEqual(self):
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests using completely-full protos as a starting point."""

  def setUp(self):
    self.first_proto = unittest_pb2.TestAllTypes()
    self.second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(self.first_proto)
    test_util.SetAllFields(self.second_proto)

  def testNotHashable(self):
    self.assertRaises(TypeError, hash, self.first_proto)

  def testNoneNotEqual(self):
    """A message never compares equal to None, in either operand order."""
    self.assertNotEqual(self.first_proto, None)
    self.assertNotEqual(None, self.second_proto)

  def testNotEqualToOtherMessage(self):
    """Messages of different types are unequal, in either operand order."""
    third_proto = unittest_pb2.TestRequired()
    self.assertNotEqual(self.first_proto, third_proto)
    self.assertNotEqual(third_proto, self.second_proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    # Nonrepeated scalar field change should cause inequality.
    self.first_proto.optional_int32 += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    # ...as should clearing a field.
    self.first_proto.ClearField('optional_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedComposite(self):
    # Change a nonrepeated composite field.
    self.first_proto.optional_nested_message.bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Clear a field in the nested message.
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.optional_nested_message.bb = (
        self.second_proto.optional_nested_message.bb)
    self.assertEqual(self.first_proto, self.second_proto)
    # Remove the nested message entirely.
    self.first_proto.ClearField('optional_nested_message')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedScalar(self):
    # Change a repeated scalar field.
    self.first_proto.repeated_int32.append(5)
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.ClearField('repeated_int32')
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testRepeatedComposite(self):
    # Change value within a repeated composite field.
    self.first_proto.repeated_nested_message[0].bb += 1
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.first_proto.repeated_nested_message[0].bb -= 1
    self.assertEqual(self.first_proto, self.second_proto)
    # Add a value to a repeated composite field.
    self.first_proto.repeated_nested_message.add()
    self.assertNotEqual(self.first_proto, self.second_proto)
    self.second_proto.repeated_nested_message.add()
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalarHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated scalar field: an explicitly-set 0 differs from unset.
    self.first_proto.ClearField('optional_int32')
    self.second_proto.optional_int32 = 0
    self.assertNotEqual(self.first_proto, self.second_proto)

  def testNonRepeatedCompositeHasBits(self):
    # Ensure that we test "has" bits as well as value for
    # nonrepeated composite field.
    self.first_proto.ClearField('optional_nested_message')
    self.second_proto.optional_nested_message.ClearField('bb')
    self.assertNotEqual(self.first_proto, self.second_proto)
    # Setting then clearing a subfield still marks the submessage present.
    self.first_proto.optional_nested_message.bb = 0
    self.first_proto.optional_nested_message.ClearField('bb')
    self.assertEqual(self.first_proto, self.second_proto)


@testing_refleaks.TestCase
class ExtensionEqualityTest(unittest.TestCase):
  """Equality semantics for messages carrying extensions."""

  def testExtensionEquality(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(first_proto, second_proto)
    test_util.SetAllExtensions(first_proto)
    self.assertNotEqual(first_proto, second_proto)
    test_util.SetAllExtensions(second_proto)
    self.assertEqual(first_proto, second_proto)

    # Ensure that we check value equality.
    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
    self.assertEqual(first_proto, second_proto)

    # Ensure that we also look at "has" bits.
    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertNotEqual(first_proto, second_proto)
    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
    self.assertEqual(first_proto, second_proto)

    # Ensure that differences in cached values
    # don't matter if "has" bits are both false.
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(
        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class MutualRecursionEqualityTest(unittest.TestCase):
  """Equality on mutually recursive message types (A -> B -> A ...)."""

  def testEqualityWithMutualRecursion(self):
    first_proto = unittest_pb2.TestMutualRecursionA()
    second_proto = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(first_proto, second_proto)
    first_proto.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(first_proto, second_proto)
    second_proto.bb.a.bb.optional_int32 = 23
    self.assertEqual(first_proto, second_proto)


@testing_refleaks.TestCase
class ByteSizeTest(unittest.TestCase):
  """Checks ByteSize() against hand-computed wire-format sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()
    self.packed_proto = unittest_pb2.TestPackedTypes()
    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()

  def Size(self):
    """Return the serialized size of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testSizedOnKwargs(self):
    # Use a separate message to ensure testing right after creation.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.ByteSize())
    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
    # One byte for the tag, one to encode varint 1.
    self.assertEqual(2, proto_kwargs.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one to the varint size for the (single-byte) tag.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # (1 << 7k) - 1 is the largest value that fits in k varint bytes.
    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
      Test((1 << i) - 1, num_bytes)
    # Negative int64 values always take the full 10 bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsExtend(self):
    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedScalarsRemove(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())
    self.proto.repeated_int32.remove(128)
    # Only the 1-byte value 10 and its 2-byte tag remain.
    self.assertEqual(1 + 2, self.Size())

  def testRepeatedComposites(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    # (foreign_message_0 is only kept for the add() side effect.)
    foreign_message_0 = self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedCompositesDelete(self):
    # Empty message. 2 bytes tag plus 1 byte length.
    foreign_message_0 = self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    foreign_message_1 = self.proto.repeated_nested_message.add()
    foreign_message_1.bb = 9
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
    # Independent copy used at the end to exercise slice/negative deletion.
    repeated_nested_message = copy.deepcopy(
        self.proto.repeated_nested_message)

    # Remaining element: 2 bytes tag plus 1 byte length plus 1 byte bb tag
    # 1 byte int.
    del self.proto.repeated_nested_message[0]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    # Now add a new message.
    foreign_message_2 = self.proto.repeated_nested_message.add()
    foreign_message_2.bb = 12

    # Two elements, each:
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())

    # Remaining element: 2 bytes tag plus 1 byte length plus 1 byte bb tag
    # 1 byte int.
    del self.proto.repeated_nested_message[1]
    self.assertEqual(2 + 1 + 1 + 1, self.Size())

    del self.proto.repeated_nested_message[0]
    self.assertEqual(0, self.Size())

    self.assertEqual(2, len(repeated_nested_message))
    del repeated_nested_message[0:1]
    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(1, len(repeated_nested_message))
    del repeated_nested_message[-1]
    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
    if api_implementation.Type() == 'python':
      self.assertEqual(0, len(repeated_nested_message))

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    group_0 = self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
2452 group_1 = self.proto.repeatedgroup.add() 2453 group_1.a = 7 2454 self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) 2455 2456 def testExtensions(self): 2457 proto = unittest_pb2.TestAllExtensions() 2458 self.assertEqual(0, proto.ByteSize()) 2459 extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. 2460 proto.Extensions[extension] = 23 2461 # 1 byte for tag, 1 byte for value. 2462 self.assertEqual(2, proto.ByteSize()) 2463 field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[ 2464 'optional_int32'] 2465 with self.assertRaises(KeyError): 2466 proto.Extensions[field] = 23 2467 2468 def testCacheInvalidationForNonrepeatedScalar(self): 2469 # Test non-extension. 2470 self.proto.optional_int32 = 1 2471 self.assertEqual(2, self.proto.ByteSize()) 2472 self.proto.optional_int32 = 128 2473 self.assertEqual(3, self.proto.ByteSize()) 2474 self.proto.ClearField('optional_int32') 2475 self.assertEqual(0, self.proto.ByteSize()) 2476 2477 # Test within extension. 2478 extension = more_extensions_pb2.optional_int_extension 2479 self.extended_proto.Extensions[extension] = 1 2480 self.assertEqual(2, self.extended_proto.ByteSize()) 2481 self.extended_proto.Extensions[extension] = 128 2482 self.assertEqual(3, self.extended_proto.ByteSize()) 2483 self.extended_proto.ClearExtension(extension) 2484 self.assertEqual(0, self.extended_proto.ByteSize()) 2485 2486 def testCacheInvalidationForRepeatedScalar(self): 2487 # Test non-extension. 2488 self.proto.repeated_int32.append(1) 2489 self.assertEqual(3, self.proto.ByteSize()) 2490 self.proto.repeated_int32.append(1) 2491 self.assertEqual(6, self.proto.ByteSize()) 2492 self.proto.repeated_int32[1] = 128 2493 self.assertEqual(7, self.proto.ByteSize()) 2494 self.proto.ClearField('repeated_int32') 2495 self.assertEqual(0, self.proto.ByteSize()) 2496 2497 # Test within extension. 
2498 extension = more_extensions_pb2.repeated_int_extension 2499 repeated = self.extended_proto.Extensions[extension] 2500 repeated.append(1) 2501 self.assertEqual(2, self.extended_proto.ByteSize()) 2502 repeated.append(1) 2503 self.assertEqual(4, self.extended_proto.ByteSize()) 2504 repeated[1] = 128 2505 self.assertEqual(5, self.extended_proto.ByteSize()) 2506 self.extended_proto.ClearExtension(extension) 2507 self.assertEqual(0, self.extended_proto.ByteSize()) 2508 2509 def testCacheInvalidationForNonrepeatedMessage(self): 2510 # Test non-extension. 2511 self.proto.optional_foreign_message.c = 1 2512 self.assertEqual(5, self.proto.ByteSize()) 2513 self.proto.optional_foreign_message.c = 128 2514 self.assertEqual(6, self.proto.ByteSize()) 2515 self.proto.optional_foreign_message.ClearField('c') 2516 self.assertEqual(3, self.proto.ByteSize()) 2517 self.proto.ClearField('optional_foreign_message') 2518 self.assertEqual(0, self.proto.ByteSize()) 2519 2520 if api_implementation.Type() == 'python': 2521 # This is only possible in pure-Python implementation of the API. 2522 child = self.proto.optional_foreign_message 2523 self.proto.ClearField('optional_foreign_message') 2524 child.c = 128 2525 self.assertEqual(0, self.proto.ByteSize()) 2526 2527 # Test within extension. 2528 extension = more_extensions_pb2.optional_message_extension 2529 child = self.extended_proto.Extensions[extension] 2530 self.assertEqual(0, self.extended_proto.ByteSize()) 2531 child.foreign_message_int = 1 2532 self.assertEqual(4, self.extended_proto.ByteSize()) 2533 child.foreign_message_int = 128 2534 self.assertEqual(5, self.extended_proto.ByteSize()) 2535 self.extended_proto.ClearExtension(extension) 2536 self.assertEqual(0, self.extended_proto.ByteSize()) 2537 2538 def testCacheInvalidationForRepeatedMessage(self): 2539 # Test non-extension. 
2540 child0 = self.proto.repeated_foreign_message.add() 2541 self.assertEqual(3, self.proto.ByteSize()) 2542 self.proto.repeated_foreign_message.add() 2543 self.assertEqual(6, self.proto.ByteSize()) 2544 child0.c = 1 2545 self.assertEqual(8, self.proto.ByteSize()) 2546 self.proto.ClearField('repeated_foreign_message') 2547 self.assertEqual(0, self.proto.ByteSize()) 2548 2549 # Test within extension. 2550 extension = more_extensions_pb2.repeated_message_extension 2551 child_list = self.extended_proto.Extensions[extension] 2552 child0 = child_list.add() 2553 self.assertEqual(2, self.extended_proto.ByteSize()) 2554 child_list.add() 2555 self.assertEqual(4, self.extended_proto.ByteSize()) 2556 child0.foreign_message_int = 1 2557 self.assertEqual(6, self.extended_proto.ByteSize()) 2558 child0.ClearField('foreign_message_int') 2559 self.assertEqual(4, self.extended_proto.ByteSize()) 2560 self.extended_proto.ClearExtension(extension) 2561 self.assertEqual(0, self.extended_proto.ByteSize()) 2562 2563 def testPackedRepeatedScalars(self): 2564 self.assertEqual(0, self.packed_proto.ByteSize()) 2565 2566 self.packed_proto.packed_int32.append(10) # 1 byte. 2567 self.packed_proto.packed_int32.append(128) # 2 bytes. 2568 # The tag is 2 bytes (the field number is 90), and the varint 2569 # storing the length is 1 byte. 2570 int_size = 1 + 2 + 3 2571 self.assertEqual(int_size, self.packed_proto.ByteSize()) 2572 2573 self.packed_proto.packed_double.append(4.2) # 8 bytes 2574 self.packed_proto.packed_double.append(3.25) # 8 bytes 2575 # 2 more tag bytes, 1 more length byte. 
2576 double_size = 8 + 8 + 3 2577 self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) 2578 2579 self.packed_proto.ClearField('packed_int32') 2580 self.assertEqual(double_size, self.packed_proto.ByteSize()) 2581 2582 def testPackedExtensions(self): 2583 self.assertEqual(0, self.packed_extended_proto.ByteSize()) 2584 extension = self.packed_extended_proto.Extensions[ 2585 unittest_pb2.packed_fixed32_extension] 2586 extension.extend([1, 2, 3, 4]) # 16 bytes 2587 # Tag is 3 bytes. 2588 self.assertEqual(19, self.packed_extended_proto.ByteSize()) 2589 2590 2591# Issues to be sure to cover include: 2592# * Handling of unrecognized tags ("uninterpreted_bytes"). 2593# * Handling of MessageSets. 2594# * Consistent ordering of tags in the wire format, 2595# including ordering between extensions and non-extension 2596# fields. 2597# * Consistent serialization of negative numbers, especially 2598# negative int32s. 2599# * Handling of empty submessages (with and without "has" 2600# bits set). 
2601 2602@testing_refleaks.TestCase 2603class SerializationTest(unittest.TestCase): 2604 2605 def testSerializeEmtpyMessage(self): 2606 first_proto = unittest_pb2.TestAllTypes() 2607 second_proto = unittest_pb2.TestAllTypes() 2608 serialized = first_proto.SerializeToString() 2609 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2610 self.assertEqual( 2611 len(serialized), 2612 second_proto.MergeFromString(serialized)) 2613 self.assertEqual(first_proto, second_proto) 2614 2615 def testSerializeAllFields(self): 2616 first_proto = unittest_pb2.TestAllTypes() 2617 second_proto = unittest_pb2.TestAllTypes() 2618 test_util.SetAllFields(first_proto) 2619 serialized = first_proto.SerializeToString() 2620 self.assertEqual(first_proto.ByteSize(), len(serialized)) 2621 self.assertEqual( 2622 len(serialized), 2623 second_proto.MergeFromString(serialized)) 2624 self.assertEqual(first_proto, second_proto) 2625 2626 def testSerializeAllExtensions(self): 2627 first_proto = unittest_pb2.TestAllExtensions() 2628 second_proto = unittest_pb2.TestAllExtensions() 2629 test_util.SetAllExtensions(first_proto) 2630 serialized = first_proto.SerializeToString() 2631 self.assertEqual( 2632 len(serialized), 2633 second_proto.MergeFromString(serialized)) 2634 self.assertEqual(first_proto, second_proto) 2635 2636 def testSerializeWithOptionalGroup(self): 2637 first_proto = unittest_pb2.TestAllTypes() 2638 second_proto = unittest_pb2.TestAllTypes() 2639 first_proto.optionalgroup.a = 242 2640 serialized = first_proto.SerializeToString() 2641 self.assertEqual( 2642 len(serialized), 2643 second_proto.MergeFromString(serialized)) 2644 self.assertEqual(first_proto, second_proto) 2645 2646 def testSerializeNegativeValues(self): 2647 first_proto = unittest_pb2.TestAllTypes() 2648 2649 first_proto.optional_int32 = -1 2650 first_proto.optional_int64 = -(2 << 40) 2651 first_proto.optional_sint32 = -3 2652 first_proto.optional_sint64 = -(4 << 40) 2653 first_proto.optional_sfixed32 = -5 2654 
first_proto.optional_sfixed64 = -(6 << 40) 2655 2656 second_proto = unittest_pb2.TestAllTypes.FromString( 2657 first_proto.SerializeToString()) 2658 2659 self.assertEqual(first_proto, second_proto) 2660 2661 def testParseTruncated(self): 2662 # This test is only applicable for the Python implementation of the API. 2663 if api_implementation.Type() != 'python': 2664 return 2665 2666 first_proto = unittest_pb2.TestAllTypes() 2667 test_util.SetAllFields(first_proto) 2668 serialized = memoryview(first_proto.SerializeToString()) 2669 2670 for truncation_point in range(len(serialized) + 1): 2671 try: 2672 second_proto = unittest_pb2.TestAllTypes() 2673 unknown_fields = unittest_pb2.TestEmptyMessage() 2674 pos = second_proto._InternalParse(serialized, 0, truncation_point) 2675 # If we didn't raise an error then we read exactly the amount expected. 2676 self.assertEqual(truncation_point, pos) 2677 2678 # Parsing to unknown fields should not throw if parsing to known fields 2679 # did not. 2680 try: 2681 pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) 2682 self.assertEqual(truncation_point, pos2) 2683 except message.DecodeError: 2684 self.fail('Parsing unknown fields failed when parsing known fields ' 2685 'did not.') 2686 except message.DecodeError: 2687 # Parsing unknown fields should also fail. 2688 self.assertRaises(message.DecodeError, unknown_fields._InternalParse, 2689 serialized, 0, truncation_point) 2690 2691 def testCanonicalSerializationOrder(self): 2692 proto = more_messages_pb2.OutOfOrderFields() 2693 # These are also their tag numbers. Even though we're setting these in 2694 # reverse-tag order AND they're listed in reverse tag-order in the .proto 2695 # file, they should nonetheless be serialized in tag order. 
2696 proto.optional_sint32 = 5 2697 proto.Extensions[more_messages_pb2.optional_uint64] = 4 2698 proto.optional_uint32 = 3 2699 proto.Extensions[more_messages_pb2.optional_int64] = 2 2700 proto.optional_int32 = 1 2701 serialized = proto.SerializeToString() 2702 self.assertEqual(proto.ByteSize(), len(serialized)) 2703 d = _MiniDecoder(serialized) 2704 ReadTag = d.ReadFieldNumberAndWireType 2705 self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) 2706 self.assertEqual(1, d.ReadInt32()) 2707 self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) 2708 self.assertEqual(2, d.ReadInt64()) 2709 self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) 2710 self.assertEqual(3, d.ReadUInt32()) 2711 self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) 2712 self.assertEqual(4, d.ReadUInt64()) 2713 self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) 2714 self.assertEqual(5, d.ReadSInt32()) 2715 2716 def testCanonicalSerializationOrderSameAsCpp(self): 2717 # Copy of the same test we use for C++. 2718 proto = unittest_pb2.TestFieldOrderings() 2719 test_util.SetAllFieldsAndExtensions(proto) 2720 serialized = proto.SerializeToString() 2721 test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) 2722 2723 def testMergeFromStringWhenFieldsAlreadySet(self): 2724 first_proto = unittest_pb2.TestAllTypes() 2725 first_proto.repeated_string.append('foobar') 2726 first_proto.optional_int32 = 23 2727 first_proto.optional_nested_message.bb = 42 2728 serialized = first_proto.SerializeToString() 2729 2730 second_proto = unittest_pb2.TestAllTypes() 2731 second_proto.repeated_string.append('baz') 2732 second_proto.optional_int32 = 100 2733 second_proto.optional_nested_message.bb = 999 2734 2735 bytes_parsed = second_proto.MergeFromString(serialized) 2736 self.assertEqual(len(serialized), bytes_parsed) 2737 2738 # Ensure that we append to repeated fields. 
2739 self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) 2740 # Ensure that we overwrite nonrepeatd scalars. 2741 self.assertEqual(23, second_proto.optional_int32) 2742 # Ensure that we recursively call MergeFromString() on 2743 # submessages. 2744 self.assertEqual(42, second_proto.optional_nested_message.bb) 2745 2746 def testMessageSetWireFormat(self): 2747 proto = message_set_extensions_pb2.TestMessageSet() 2748 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2749 extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 2750 extension1 = extension_message1.message_set_extension 2751 extension2 = extension_message2.message_set_extension 2752 extension3 = message_set_extensions_pb2.message_set_extension3 2753 proto.Extensions[extension1].i = 123 2754 proto.Extensions[extension2].str = 'foo' 2755 proto.Extensions[extension3].text = 'bar' 2756 2757 # Serialize using the MessageSet wire format (this is specified in the 2758 # .proto file). 
2759 serialized = proto.SerializeToString() 2760 2761 raw = unittest_mset_pb2.RawMessageSet() 2762 self.assertEqual(False, 2763 raw.DESCRIPTOR.GetOptions().message_set_wire_format) 2764 self.assertEqual( 2765 len(serialized), 2766 raw.MergeFromString(serialized)) 2767 self.assertEqual(3, len(raw.item)) 2768 2769 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2770 self.assertEqual( 2771 len(raw.item[0].message), 2772 message1.MergeFromString(raw.item[0].message)) 2773 self.assertEqual(123, message1.i) 2774 2775 message2 = message_set_extensions_pb2.TestMessageSetExtension2() 2776 self.assertEqual( 2777 len(raw.item[1].message), 2778 message2.MergeFromString(raw.item[1].message)) 2779 self.assertEqual('foo', message2.str) 2780 2781 message3 = message_set_extensions_pb2.TestMessageSetExtension3() 2782 self.assertEqual( 2783 len(raw.item[2].message), 2784 message3.MergeFromString(raw.item[2].message)) 2785 self.assertEqual('bar', message3.text) 2786 2787 # Deserialize using the MessageSet wire format. 2788 proto2 = message_set_extensions_pb2.TestMessageSet() 2789 self.assertEqual( 2790 len(serialized), 2791 proto2.MergeFromString(serialized)) 2792 self.assertEqual(123, proto2.Extensions[extension1].i) 2793 self.assertEqual('foo', proto2.Extensions[extension2].str) 2794 self.assertEqual('bar', proto2.Extensions[extension3].text) 2795 2796 # Check byte size. 2797 self.assertEqual(proto2.ByteSize(), len(serialized)) 2798 self.assertEqual(proto.ByteSize(), len(serialized)) 2799 2800 def testMessageSetWireFormatUnknownExtension(self): 2801 # Create a message using the message set wire format with an unknown 2802 # message. 2803 raw = unittest_mset_pb2.RawMessageSet() 2804 2805 # Add an item. 
2806 item = raw.item.add() 2807 item.type_id = 98418603 2808 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2809 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2810 message1.i = 12345 2811 item.message = message1.SerializeToString() 2812 2813 # Add a second, unknown extension. 2814 item = raw.item.add() 2815 item.type_id = 98418604 2816 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2817 message1 = message_set_extensions_pb2.TestMessageSetExtension1() 2818 message1.i = 12346 2819 item.message = message1.SerializeToString() 2820 2821 # Add another unknown extension. 2822 item = raw.item.add() 2823 item.type_id = 98418605 2824 message1 = message_set_extensions_pb2.TestMessageSetExtension2() 2825 message1.str = 'foo' 2826 item.message = message1.SerializeToString() 2827 2828 serialized = raw.SerializeToString() 2829 2830 # Parse message using the message set wire format. 2831 proto = message_set_extensions_pb2.TestMessageSet() 2832 self.assertEqual( 2833 len(serialized), 2834 proto.MergeFromString(serialized)) 2835 2836 # Check that the message parsed well. 2837 extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 2838 extension1 = extension_message1.message_set_extension 2839 self.assertEqual(12345, proto.Extensions[extension1].i) 2840 2841 def testUnknownFields(self): 2842 proto = unittest_pb2.TestAllTypes() 2843 test_util.SetAllFields(proto) 2844 2845 serialized = proto.SerializeToString() 2846 2847 # The empty message should be parsable with all of the fields 2848 # unknown. 2849 proto2 = unittest_pb2.TestEmptyMessage() 2850 2851 # Parsing this message should succeed. 2852 self.assertEqual( 2853 len(serialized), 2854 proto2.MergeFromString(serialized)) 2855 2856 # Now test with a int64 field set. 
2857 proto = unittest_pb2.TestAllTypes() 2858 proto.optional_int64 = 0x0fffffffffffffff 2859 serialized = proto.SerializeToString() 2860 # The empty message should be parsable with all of the fields 2861 # unknown. 2862 proto2 = unittest_pb2.TestEmptyMessage() 2863 # Parsing this message should succeed. 2864 self.assertEqual( 2865 len(serialized), 2866 proto2.MergeFromString(serialized)) 2867 2868 def _CheckRaises(self, exc_class, callable_obj, exception): 2869 """This method checks if the exception type and message are as expected.""" 2870 try: 2871 callable_obj() 2872 except exc_class as ex: 2873 # Check if the exception message is the right one. 2874 self.assertEqual(exception, str(ex)) 2875 return 2876 else: 2877 raise self.failureException('%s not raised' % str(exc_class)) 2878 2879 def testSerializeUninitialized(self): 2880 proto = unittest_pb2.TestRequired() 2881 self._CheckRaises( 2882 message.EncodeError, 2883 proto.SerializeToString, 2884 'Message protobuf_unittest.TestRequired is missing required fields: ' 2885 'a,b,c') 2886 # Shouldn't raise exceptions. 2887 partial = proto.SerializePartialToString() 2888 2889 proto2 = unittest_pb2.TestRequired() 2890 self.assertFalse(proto2.HasField('a')) 2891 # proto2 ParseFromString does not check that required fields are set. 2892 proto2.ParseFromString(partial) 2893 self.assertFalse(proto2.HasField('a')) 2894 2895 proto.a = 1 2896 self._CheckRaises( 2897 message.EncodeError, 2898 proto.SerializeToString, 2899 'Message protobuf_unittest.TestRequired is missing required fields: b,c') 2900 # Shouldn't raise exceptions. 2901 partial = proto.SerializePartialToString() 2902 2903 proto.b = 2 2904 self._CheckRaises( 2905 message.EncodeError, 2906 proto.SerializeToString, 2907 'Message protobuf_unittest.TestRequired is missing required fields: c') 2908 # Shouldn't raise exceptions. 
2909 partial = proto.SerializePartialToString() 2910 2911 proto.c = 3 2912 serialized = proto.SerializeToString() 2913 # Shouldn't raise exceptions. 2914 partial = proto.SerializePartialToString() 2915 2916 proto2 = unittest_pb2.TestRequired() 2917 self.assertEqual( 2918 len(serialized), 2919 proto2.MergeFromString(serialized)) 2920 self.assertEqual(1, proto2.a) 2921 self.assertEqual(2, proto2.b) 2922 self.assertEqual(3, proto2.c) 2923 self.assertEqual( 2924 len(partial), 2925 proto2.MergeFromString(partial)) 2926 self.assertEqual(1, proto2.a) 2927 self.assertEqual(2, proto2.b) 2928 self.assertEqual(3, proto2.c) 2929 2930 def testSerializeUninitializedSubMessage(self): 2931 proto = unittest_pb2.TestRequiredForeign() 2932 2933 # Sub-message doesn't exist yet, so this succeeds. 2934 proto.SerializeToString() 2935 2936 proto.optional_message.a = 1 2937 self._CheckRaises( 2938 message.EncodeError, 2939 proto.SerializeToString, 2940 'Message protobuf_unittest.TestRequiredForeign ' 2941 'is missing required fields: ' 2942 'optional_message.b,optional_message.c') 2943 2944 proto.optional_message.b = 2 2945 proto.optional_message.c = 3 2946 proto.SerializeToString() 2947 2948 proto.repeated_message.add().a = 1 2949 proto.repeated_message.add().b = 2 2950 self._CheckRaises( 2951 message.EncodeError, 2952 proto.SerializeToString, 2953 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 2954 'repeated_message[0].b,repeated_message[0].c,' 2955 'repeated_message[1].a,repeated_message[1].c') 2956 2957 proto.repeated_message[0].b = 2 2958 proto.repeated_message[0].c = 3 2959 proto.repeated_message[1].a = 1 2960 proto.repeated_message[1].c = 3 2961 proto.SerializeToString() 2962 2963 def testSerializeAllPackedFields(self): 2964 first_proto = unittest_pb2.TestPackedTypes() 2965 second_proto = unittest_pb2.TestPackedTypes() 2966 test_util.SetAllPackedFields(first_proto) 2967 serialized = first_proto.SerializeToString() 2968 
self.assertEqual(first_proto.ByteSize(), len(serialized)) 2969 bytes_read = second_proto.MergeFromString(serialized) 2970 self.assertEqual(second_proto.ByteSize(), bytes_read) 2971 self.assertEqual(first_proto, second_proto) 2972 2973 def testSerializeAllPackedExtensions(self): 2974 first_proto = unittest_pb2.TestPackedExtensions() 2975 second_proto = unittest_pb2.TestPackedExtensions() 2976 test_util.SetAllPackedExtensions(first_proto) 2977 serialized = first_proto.SerializeToString() 2978 bytes_read = second_proto.MergeFromString(serialized) 2979 self.assertEqual(second_proto.ByteSize(), bytes_read) 2980 self.assertEqual(first_proto, second_proto) 2981 2982 def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): 2983 first_proto = unittest_pb2.TestPackedTypes() 2984 first_proto.packed_int32.extend([1, 2]) 2985 first_proto.packed_double.append(3.0) 2986 serialized = first_proto.SerializeToString() 2987 2988 second_proto = unittest_pb2.TestPackedTypes() 2989 second_proto.packed_int32.append(3) 2990 second_proto.packed_double.extend([1.0, 2.0]) 2991 second_proto.packed_sint32.append(4) 2992 2993 self.assertEqual( 2994 len(serialized), 2995 second_proto.MergeFromString(serialized)) 2996 self.assertEqual([3, 1, 2], second_proto.packed_int32) 2997 self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) 2998 self.assertEqual([4], second_proto.packed_sint32) 2999 3000 def testPackedFieldsWireFormat(self): 3001 proto = unittest_pb2.TestPackedTypes() 3002 proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes 3003 proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes 3004 proto.packed_float.append(2.0) # 4 bytes, will be before double 3005 serialized = proto.SerializeToString() 3006 self.assertEqual(proto.ByteSize(), len(serialized)) 3007 d = _MiniDecoder(serialized) 3008 ReadTag = d.ReadFieldNumberAndWireType 3009 self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 3010 self.assertEqual(1+1+1+2, d.ReadInt32()) 3011 
self.assertEqual(1, d.ReadInt32()) 3012 self.assertEqual(2, d.ReadInt32()) 3013 self.assertEqual(150, d.ReadInt32()) 3014 self.assertEqual(3, d.ReadInt32()) 3015 self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 3016 self.assertEqual(4, d.ReadInt32()) 3017 self.assertEqual(2.0, d.ReadFloat()) 3018 self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) 3019 self.assertEqual(8+8, d.ReadInt32()) 3020 self.assertEqual(1.0, d.ReadDouble()) 3021 self.assertEqual(1000.0, d.ReadDouble()) 3022 self.assertTrue(d.EndOfStream()) 3023 3024 def testParsePackedFromUnpacked(self): 3025 unpacked = unittest_pb2.TestUnpackedTypes() 3026 test_util.SetAllUnpackedFields(unpacked) 3027 packed = unittest_pb2.TestPackedTypes() 3028 serialized = unpacked.SerializeToString() 3029 self.assertEqual( 3030 len(serialized), 3031 packed.MergeFromString(serialized)) 3032 expected = unittest_pb2.TestPackedTypes() 3033 test_util.SetAllPackedFields(expected) 3034 self.assertEqual(expected, packed) 3035 3036 def testParseUnpackedFromPacked(self): 3037 packed = unittest_pb2.TestPackedTypes() 3038 test_util.SetAllPackedFields(packed) 3039 unpacked = unittest_pb2.TestUnpackedTypes() 3040 serialized = packed.SerializeToString() 3041 self.assertEqual( 3042 len(serialized), 3043 unpacked.MergeFromString(serialized)) 3044 expected = unittest_pb2.TestUnpackedTypes() 3045 test_util.SetAllUnpackedFields(expected) 3046 self.assertEqual(expected, unpacked) 3047 3048 def testFieldNumbers(self): 3049 proto = unittest_pb2.TestAllTypes() 3050 self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) 3051 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) 3052 self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) 3053 self.assertEqual( 3054 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) 3055 self.assertEqual( 3056 unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) 3057 
self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) 3058 self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) 3059 self.assertEqual( 3060 unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) 3061 self.assertEqual( 3062 unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) 3063 3064 def testExtensionFieldNumbers(self): 3065 self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) 3066 self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) 3067 self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) 3068 self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) 3069 self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) 3070 self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) 3071 self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) 3072 self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) 3073 self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) 3074 self.assertEqual( 3075 unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) 3076 self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) 3077 self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 3078 21) 3079 self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) 3080 self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) 3081 self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) 3082 self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) 3083 self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) 3084 self.assertEqual( 3085 unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) 3086 self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) 3087 self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 3088 51) 3089 3090 def testFieldProperties(self): 3091 
cls = unittest_pb2.TestAllTypes 3092 self.assertIs(cls.optional_int32.DESCRIPTOR, 3093 cls.DESCRIPTOR.fields_by_name['optional_int32']) 3094 self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER, 3095 cls.optional_int32.DESCRIPTOR.number) 3096 self.assertIs(cls.optional_nested_message.DESCRIPTOR, 3097 cls.DESCRIPTOR.fields_by_name['optional_nested_message']) 3098 self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 3099 cls.optional_nested_message.DESCRIPTOR.number) 3100 self.assertIs(cls.repeated_int32.DESCRIPTOR, 3101 cls.DESCRIPTOR.fields_by_name['repeated_int32']) 3102 self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER, 3103 cls.repeated_int32.DESCRIPTOR.number) 3104 3105 def testFieldDataDescriptor(self): 3106 msg = unittest_pb2.TestAllTypes() 3107 msg.optional_int32 = 42 3108 self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42) 3109 unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25) 3110 self.assertEqual(msg.optional_int32, 25) 3111 with self.assertRaises(AttributeError): 3112 del msg.optional_int32 3113 try: 3114 unittest_pb2.ForeignMessage.c.__get__(msg) 3115 except TypeError: 3116 pass # The cpp implementation cannot mix fields from other messages. 3117 # This test exercises a specific check that avoids a crash. 3118 else: 3119 pass # The python implementation allows fields from other messages. 3120 # This is useless, but works. 
3121 3122 def testInitKwargs(self): 3123 proto = unittest_pb2.TestAllTypes( 3124 optional_int32=1, 3125 optional_string='foo', 3126 optional_bool=True, 3127 optional_bytes=b'bar', 3128 optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), 3129 optional_foreign_message=unittest_pb2.ForeignMessage(c=1), 3130 optional_nested_enum=unittest_pb2.TestAllTypes.FOO, 3131 optional_foreign_enum=unittest_pb2.FOREIGN_FOO, 3132 repeated_int32=[1, 2, 3]) 3133 self.assertTrue(proto.IsInitialized()) 3134 self.assertTrue(proto.HasField('optional_int32')) 3135 self.assertTrue(proto.HasField('optional_string')) 3136 self.assertTrue(proto.HasField('optional_bool')) 3137 self.assertTrue(proto.HasField('optional_bytes')) 3138 self.assertTrue(proto.HasField('optional_nested_message')) 3139 self.assertTrue(proto.HasField('optional_foreign_message')) 3140 self.assertTrue(proto.HasField('optional_nested_enum')) 3141 self.assertTrue(proto.HasField('optional_foreign_enum')) 3142 self.assertEqual(1, proto.optional_int32) 3143 self.assertEqual('foo', proto.optional_string) 3144 self.assertEqual(True, proto.optional_bool) 3145 self.assertEqual(b'bar', proto.optional_bytes) 3146 self.assertEqual(1, proto.optional_nested_message.bb) 3147 self.assertEqual(1, proto.optional_foreign_message.c) 3148 self.assertEqual(unittest_pb2.TestAllTypes.FOO, 3149 proto.optional_nested_enum) 3150 self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) 3151 self.assertEqual([1, 2, 3], proto.repeated_int32) 3152 3153 def testInitArgsUnknownFieldName(self): 3154 def InitalizeEmptyMessageWithExtraKeywordArg(): 3155 unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') 3156 self._CheckRaises( 3157 ValueError, 3158 InitalizeEmptyMessageWithExtraKeywordArg, 3159 'Protocol message TestEmptyMessage has no "unknown" field.') 3160 3161 def testInitRequiredKwargs(self): 3162 proto = unittest_pb2.TestRequired(a=1, b=1, c=1) 3163 self.assertTrue(proto.IsInitialized()) 3164 
self.assertTrue(proto.HasField('a')) 3165 self.assertTrue(proto.HasField('b')) 3166 self.assertTrue(proto.HasField('c')) 3167 self.assertFalse(proto.HasField('dummy2')) 3168 self.assertEqual(1, proto.a) 3169 self.assertEqual(1, proto.b) 3170 self.assertEqual(1, proto.c) 3171 3172 def testInitRequiredForeignKwargs(self): 3173 proto = unittest_pb2.TestRequiredForeign( 3174 optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) 3175 self.assertTrue(proto.IsInitialized()) 3176 self.assertTrue(proto.HasField('optional_message')) 3177 self.assertTrue(proto.optional_message.IsInitialized()) 3178 self.assertTrue(proto.optional_message.HasField('a')) 3179 self.assertTrue(proto.optional_message.HasField('b')) 3180 self.assertTrue(proto.optional_message.HasField('c')) 3181 self.assertFalse(proto.optional_message.HasField('dummy2')) 3182 self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), 3183 proto.optional_message) 3184 self.assertEqual(1, proto.optional_message.a) 3185 self.assertEqual(1, proto.optional_message.b) 3186 self.assertEqual(1, proto.optional_message.c) 3187 3188 def testInitRepeatedKwargs(self): 3189 proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) 3190 self.assertTrue(proto.IsInitialized()) 3191 self.assertEqual(1, proto.repeated_int32[0]) 3192 self.assertEqual(2, proto.repeated_int32[1]) 3193 self.assertEqual(3, proto.repeated_int32[2]) 3194 3195 3196@testing_refleaks.TestCase 3197class OptionsTest(unittest.TestCase): 3198 3199 def testMessageOptions(self): 3200 proto = message_set_extensions_pb2.TestMessageSet() 3201 self.assertEqual(True, 3202 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3203 proto = unittest_pb2.TestAllTypes() 3204 self.assertEqual(False, 3205 proto.DESCRIPTOR.GetOptions().message_set_wire_format) 3206 3207 def testPackedOptions(self): 3208 proto = unittest_pb2.TestAllTypes() 3209 proto.optional_int32 = 1 3210 proto.optional_double = 3.0 3211 for field_descriptor, _ in proto.ListFields(): 3212 
self.assertEqual(False, field_descriptor.GetOptions().packed) 3213 3214 proto = unittest_pb2.TestPackedTypes() 3215 proto.packed_int32.append(1) 3216 proto.packed_double.append(3.0) 3217 for field_descriptor, _ in proto.ListFields(): 3218 self.assertEqual(True, field_descriptor.GetOptions().packed) 3219 self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, 3220 field_descriptor.label) 3221 3222 3223 3224@testing_refleaks.TestCase 3225class ClassAPITest(unittest.TestCase): 3226 3227 @unittest.skipIf( 3228 api_implementation.Type() != 'python', 3229 'C++ implementation requires a call to MakeDescriptor()') 3230 @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable') 3231 def testMakeClassWithNestedDescriptor(self): 3232 leaf_desc = descriptor.Descriptor( 3233 'leaf', 'package.parent.child.leaf', '', 3234 containing_type=None, fields=[], 3235 nested_types=[], enum_types=[], 3236 extensions=[], 3237 # pylint: disable=protected-access 3238 create_key=descriptor._internal_create_key) 3239 child_desc = descriptor.Descriptor( 3240 'child', 'package.parent.child', '', 3241 containing_type=None, fields=[], 3242 nested_types=[leaf_desc], enum_types=[], 3243 extensions=[], 3244 # pylint: disable=protected-access 3245 create_key=descriptor._internal_create_key) 3246 sibling_desc = descriptor.Descriptor( 3247 'sibling', 'package.parent.sibling', 3248 '', containing_type=None, fields=[], 3249 nested_types=[], enum_types=[], 3250 extensions=[], 3251 # pylint: disable=protected-access 3252 create_key=descriptor._internal_create_key) 3253 parent_desc = descriptor.Descriptor( 3254 'parent', 'package.parent', '', 3255 containing_type=None, fields=[], 3256 nested_types=[child_desc, sibling_desc], 3257 enum_types=[], extensions=[], 3258 # pylint: disable=protected-access 3259 create_key=descriptor._internal_create_key) 3260 reflection.MakeClass(parent_desc) 3261 3262 def _GetSerializedFileDescriptor(self, name): 3263 """Get a serialized representation of a test 
FileDescriptorProto. 3264 3265 Args: 3266 name: All calls to this must use a unique message name, to avoid 3267 collisions in the cpp descriptor pool. 3268 Returns: 3269 A string containing the serialized form of a test FileDescriptorProto. 3270 """ 3271 file_descriptor_str = ( 3272 'message_type {' 3273 ' name: "' + name + '"' 3274 ' field {' 3275 ' name: "flat"' 3276 ' number: 1' 3277 ' label: LABEL_REPEATED' 3278 ' type: TYPE_UINT32' 3279 ' }' 3280 ' field {' 3281 ' name: "bar"' 3282 ' number: 2' 3283 ' label: LABEL_OPTIONAL' 3284 ' type: TYPE_MESSAGE' 3285 ' type_name: "Bar"' 3286 ' }' 3287 ' nested_type {' 3288 ' name: "Bar"' 3289 ' field {' 3290 ' name: "baz"' 3291 ' number: 3' 3292 ' label: LABEL_OPTIONAL' 3293 ' type: TYPE_MESSAGE' 3294 ' type_name: "Baz"' 3295 ' }' 3296 ' nested_type {' 3297 ' name: "Baz"' 3298 ' enum_type {' 3299 ' name: "deep_enum"' 3300 ' value {' 3301 ' name: "VALUE_A"' 3302 ' number: 0' 3303 ' }' 3304 ' }' 3305 ' field {' 3306 ' name: "deep"' 3307 ' number: 4' 3308 ' label: LABEL_OPTIONAL' 3309 ' type: TYPE_UINT32' 3310 ' }' 3311 ' }' 3312 ' }' 3313 '}') 3314 file_descriptor = descriptor_pb2.FileDescriptorProto() 3315 text_format.Merge(file_descriptor_str, file_descriptor) 3316 return file_descriptor.SerializeToString() 3317 3318 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3319 # This test can only run once; the second time, it raises errors about 3320 # conflicting message descriptors. 3321 def testParsingFlatClassWithExplicitClassDeclaration(self): 3322 """Test that the generated class can parse a flat message.""" 3323 # TODO(xiaofeng): This test fails with cpp implementation in the call 3324 # of six.with_metaclass(). The other two callsites of with_metaclass 3325 # in this file are both excluded from cpp test, so it might be expected 3326 # to fail. Need someone more familiar with the python code to take a 3327 # look at this. 
3328 if api_implementation.Type() != 'python': 3329 return 3330 file_descriptor = descriptor_pb2.FileDescriptorProto() 3331 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) 3332 msg_descriptor = descriptor.MakeDescriptor( 3333 file_descriptor.message_type[0]) 3334 3335 class MessageClass( 3336 message.Message, metaclass=reflection.GeneratedProtocolMessageType): 3337 DESCRIPTOR = msg_descriptor 3338 msg = MessageClass() 3339 msg_str = ( 3340 'flat: 0 ' 3341 'flat: 1 ' 3342 'flat: 2 ') 3343 text_format.Merge(msg_str, msg) 3344 self.assertEqual(msg.flat, [0, 1, 2]) 3345 3346 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3347 def testParsingFlatClass(self): 3348 """Test that the generated class can parse a flat message.""" 3349 file_descriptor = descriptor_pb2.FileDescriptorProto() 3350 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) 3351 msg_descriptor = descriptor.MakeDescriptor( 3352 file_descriptor.message_type[0]) 3353 msg_class = reflection.MakeClass(msg_descriptor) 3354 msg = msg_class() 3355 msg_str = ( 3356 'flat: 0 ' 3357 'flat: 1 ' 3358 'flat: 2 ') 3359 text_format.Merge(msg_str, msg) 3360 self.assertEqual(msg.flat, [0, 1, 2]) 3361 3362 @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') 3363 def testParsingNestedClass(self): 3364 """Test that the generated class can parse a nested message.""" 3365 file_descriptor = descriptor_pb2.FileDescriptorProto() 3366 file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) 3367 msg_descriptor = descriptor.MakeDescriptor( 3368 file_descriptor.message_type[0]) 3369 msg_class = reflection.MakeClass(msg_descriptor) 3370 msg = msg_class() 3371 msg_str = ( 3372 'bar {' 3373 ' baz {' 3374 ' deep: 4' 3375 ' }' 3376 '}') 3377 text_format.Merge(msg_str, msg) 3378 self.assertEqual(msg.bar.baz.deep, 4) 3379 3380if __name__ == '__main__': 3381 unittest.main() 3382