xref: /aosp_15_r20/external/cronet/third_party/protobuf/python/google/protobuf/internal/reflection_test.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# -*- coding: utf-8 -*-
2# Protocol Buffers - Google's data interchange format
3# Copyright 2008 Google Inc.  All rights reserved.
4# https://developers.google.com/protocol-buffers/
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are
8# met:
9#
10#     * Redistributions of source code must retain the above copyright
11# notice, this list of conditions and the following disclaimer.
12#     * Redistributions in binary form must reproduce the above
13# copyright notice, this list of conditions and the following disclaimer
14# in the documentation and/or other materials provided with the
15# distribution.
16#     * Neither the name of Google Inc. nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
32"""Unittest for reflection.py, which also indirectly tests the output of the
33pure-Python protocol compiler.
34"""
35
36import copy
37import gc
38import operator
39import struct
40import sys
41import warnings
42import unittest
43
44from google.protobuf import unittest_import_pb2
45from google.protobuf import unittest_mset_pb2
46from google.protobuf import unittest_pb2
47from google.protobuf import unittest_proto3_arena_pb2
48from google.protobuf import descriptor_pb2
49from google.protobuf import descriptor
50from google.protobuf import message
51from google.protobuf import reflection
52from google.protobuf import text_format
53from google.protobuf.internal import api_implementation
54from google.protobuf.internal import more_extensions_pb2
55from google.protobuf.internal import more_messages_pb2
56from google.protobuf.internal import message_set_extensions_pb2
57from google.protobuf.internal import wire_format
58from google.protobuf.internal import test_util
59from google.protobuf.internal import testing_refleaks
60from google.protobuf.internal import decoder
61from google.protobuf.internal import _parameterized
62
63
64warnings.simplefilter('error', DeprecationWarning)
65
66
67class _MiniDecoder(object):
68  """Decodes a stream of values from a string.
69
70  Once upon a time we actually had a class called decoder.Decoder.  Then we
71  got rid of it during a redesign that made decoding much, much faster overall.
72  But a couple tests in this file used it to check that the serialized form of
73  a message was correct.  So, this class implements just the methods that were
74  used by said tests, so that we don't have to rewrite the tests.
75  """
76
77  def __init__(self, bytes):
78    self._bytes = bytes
79    self._pos = 0
80
81  def ReadVarint(self):
82    result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
83    return result
84
85  ReadInt32 = ReadVarint
86  ReadInt64 = ReadVarint
87  ReadUInt32 = ReadVarint
88  ReadUInt64 = ReadVarint
89
90  def ReadSInt64(self):
91    return wire_format.ZigZagDecode(self.ReadVarint())
92
93  ReadSInt32 = ReadSInt64
94
95  def ReadFieldNumberAndWireType(self):
96    return wire_format.UnpackTag(self.ReadVarint())
97
98  def ReadFloat(self):
99    result = struct.unpack('<f', self._bytes[self._pos:self._pos+4])[0]
100    self._pos += 4
101    return result
102
103  def ReadDouble(self):
104    result = struct.unpack('<d', self._bytes[self._pos:self._pos+8])[0]
105    self._pos += 8
106    return result
107
108  def EndOfStream(self):
109    return self._pos == len(self._bytes)
110
111
112@_parameterized.named_parameters(
113    ('_proto2', unittest_pb2),
114    ('_proto3', unittest_proto3_arena_pb2))
115@testing_refleaks.TestCase
116class ReflectionTest(unittest.TestCase):
117
118  def assertListsEqual(self, values, others):
119    self.assertEqual(len(values), len(others))
120    for i in range(len(values)):
121      self.assertEqual(values[i], others[i])
122
123  def testScalarConstructor(self, message_module):
124    # Constructor with only scalar types should succeed.
125    proto = message_module.TestAllTypes(
126        optional_int32=24,
127        optional_double=54.321,
128        optional_string='optional_string',
129        optional_float=None)
130
131    self.assertEqual(24, proto.optional_int32)
132    self.assertEqual(54.321, proto.optional_double)
133    self.assertEqual('optional_string', proto.optional_string)
134    if message_module is unittest_pb2:
135      self.assertFalse(proto.HasField("optional_float"))
136
137  def testRepeatedScalarConstructor(self, message_module):
138    # Constructor with only repeated scalar types should succeed.
139    proto = message_module.TestAllTypes(
140        repeated_int32=[1, 2, 3, 4],
141        repeated_double=[1.23, 54.321],
142        repeated_bool=[True, False, False],
143        repeated_string=["optional_string"],
144        repeated_float=None)
145
146    self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32))
147    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
148    self.assertEqual([True, False, False], list(proto.repeated_bool))
149    self.assertEqual(["optional_string"], list(proto.repeated_string))
150    self.assertEqual([], list(proto.repeated_float))
151
152  def testMixedConstructor(self, message_module):
153    # Constructor with only mixed types should succeed.
154    proto = message_module.TestAllTypes(
155        optional_int32=24,
156        optional_string='optional_string',
157        repeated_double=[1.23, 54.321],
158        repeated_bool=[True, False, False],
159        repeated_nested_message=[
160            message_module.TestAllTypes.NestedMessage(
161                bb=message_module.TestAllTypes.FOO),
162            message_module.TestAllTypes.NestedMessage(
163                bb=message_module.TestAllTypes.BAR)],
164        repeated_foreign_message=[
165            message_module.ForeignMessage(c=-43),
166            message_module.ForeignMessage(c=45324),
167            message_module.ForeignMessage(c=12)],
168        optional_nested_message=None)
169
170    self.assertEqual(24, proto.optional_int32)
171    self.assertEqual('optional_string', proto.optional_string)
172    self.assertEqual([1.23, 54.321], list(proto.repeated_double))
173    self.assertEqual([True, False, False], list(proto.repeated_bool))
174    self.assertEqual(
175        [message_module.TestAllTypes.NestedMessage(
176            bb=message_module.TestAllTypes.FOO),
177         message_module.TestAllTypes.NestedMessage(
178             bb=message_module.TestAllTypes.BAR)],
179        list(proto.repeated_nested_message))
180    self.assertEqual(
181        [message_module.ForeignMessage(c=-43),
182         message_module.ForeignMessage(c=45324),
183         message_module.ForeignMessage(c=12)],
184        list(proto.repeated_foreign_message))
185    self.assertFalse(proto.HasField("optional_nested_message"))
186
187  def testConstructorTypeError(self, message_module):
188    self.assertRaises(
189        TypeError, message_module.TestAllTypes, optional_int32='foo')
190    self.assertRaises(
191        TypeError, message_module.TestAllTypes, optional_string=1234)
192    self.assertRaises(
193        TypeError, message_module.TestAllTypes, optional_nested_message=1234)
194    self.assertRaises(
195        TypeError, message_module.TestAllTypes, repeated_int32=1234)
196    self.assertRaises(
197        TypeError, message_module.TestAllTypes, repeated_int32=['foo'])
198    self.assertRaises(
199        TypeError, message_module.TestAllTypes, repeated_string=1234)
200    self.assertRaises(
201        TypeError, message_module.TestAllTypes, repeated_string=[1234])
202    self.assertRaises(
203        TypeError, message_module.TestAllTypes, repeated_nested_message=1234)
204    self.assertRaises(
205        TypeError, message_module.TestAllTypes, repeated_nested_message=[1234])
206
207  def testConstructorInvalidatesCachedByteSize(self, message_module):
208    message = message_module.TestAllTypes(optional_int32=12)
209    self.assertEqual(2, message.ByteSize())
210
211    message = message_module.TestAllTypes(
212        optional_nested_message=message_module.TestAllTypes.NestedMessage())
213    self.assertEqual(3, message.ByteSize())
214
215    message = message_module.TestAllTypes(repeated_int32=[12])
216    # TODO(jieluo): Add this test back for proto3
217    if message_module is unittest_pb2:
218      self.assertEqual(3, message.ByteSize())
219
220    message = message_module.TestAllTypes(
221        repeated_nested_message=[message_module.TestAllTypes.NestedMessage()])
222    self.assertEqual(3, message.ByteSize())
223
224  def testReferencesToNestedMessage(self, message_module):
225    proto = message_module.TestAllTypes()
226    nested = proto.optional_nested_message
227    del proto
228    # A previous version had a bug where this would raise an exception when
229    # hitting a now-dead weak reference.
230    nested.bb = 23
231
232  def testOneOf(self, message_module):
233    proto = message_module.TestAllTypes()
234    proto.oneof_uint32 = 10
235    proto.oneof_nested_message.bb = 11
236    self.assertEqual(11, proto.oneof_nested_message.bb)
237    self.assertFalse(proto.HasField('oneof_uint32'))
238    nested = proto.oneof_nested_message
239    proto.oneof_string = 'abc'
240    self.assertEqual('abc', proto.oneof_string)
241    self.assertEqual(11, nested.bb)
242    self.assertFalse(proto.HasField('oneof_nested_message'))
243
244  def testGetDefaultMessageAfterDisconnectingDefaultMessage(
245      self, message_module):
246    proto = message_module.TestAllTypes()
247    nested = proto.optional_nested_message
248    proto.ClearField('optional_nested_message')
249    del proto
250    del nested
251    # Force a garbage collect so that the underlying CMessages are freed along
252    # with the Messages they point to. This is to make sure we're not deleting
253    # default message instances.
254    gc.collect()
255    proto = message_module.TestAllTypes()
256    nested = proto.optional_nested_message
257
258  def testDisconnectingNestedMessageAfterSettingField(self, message_module):
259    proto = message_module.TestAllTypes()
260    nested = proto.optional_nested_message
261    nested.bb = 5
262    self.assertTrue(proto.HasField('optional_nested_message'))
263    proto.ClearField('optional_nested_message')  # Should disconnect from parent
264    self.assertEqual(5, nested.bb)
265    self.assertEqual(0, proto.optional_nested_message.bb)
266    self.assertIsNot(nested, proto.optional_nested_message)
267    nested.bb = 23
268    self.assertFalse(proto.HasField('optional_nested_message'))
269    self.assertEqual(0, proto.optional_nested_message.bb)
270
271  def testDisconnectingNestedMessageBeforeGettingField(self, message_module):
272    proto = message_module.TestAllTypes()
273    self.assertFalse(proto.HasField('optional_nested_message'))
274    proto.ClearField('optional_nested_message')
275    self.assertFalse(proto.HasField('optional_nested_message'))
276
277  def testDisconnectingNestedMessageAfterMerge(self, message_module):
278    # This test exercises the code path that does not use ReleaseMessage().
279    # The underlying fear is that if we use ReleaseMessage() incorrectly,
280    # we will have memory leaks.  It's hard to check that that doesn't happen,
281    # but at least we can exercise that code path to make sure it works.
282    proto1 = message_module.TestAllTypes()
283    proto2 = message_module.TestAllTypes()
284    proto2.optional_nested_message.bb = 5
285    proto1.MergeFrom(proto2)
286    self.assertTrue(proto1.HasField('optional_nested_message'))
287    proto1.ClearField('optional_nested_message')
288    self.assertFalse(proto1.HasField('optional_nested_message'))
289
290  def testDisconnectingLazyNestedMessage(self, message_module):
291    # This test exercises releasing a nested message that is lazy. This test
292    # only exercises real code in the C++ implementation as Python does not
293    # support lazy parsing, but the current C++ implementation results in
294    # memory corruption and a crash.
295    if api_implementation.Type() != 'python':
296      return
297    proto = message_module.TestAllTypes()
298    proto.optional_lazy_message.bb = 5
299    proto.ClearField('optional_lazy_message')
300    del proto
301    gc.collect()
302
303  def testSingularListFields(self, message_module):
304    proto = message_module.TestAllTypes()
305    proto.optional_fixed32 = 1
306    proto.optional_int32 = 5
307    proto.optional_string = 'foo'
308    # Access sub-message but don't set it yet.
309    nested_message = proto.optional_nested_message
310    self.assertEqual(
311      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
312        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
313        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
314      proto.ListFields())
315
316    proto.optional_nested_message.bb = 123
317    self.assertEqual(
318      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 5),
319        (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
320        (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
321        (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
322             nested_message) ],
323      proto.ListFields())
324
325  def testRepeatedListFields(self, message_module):
326    proto = message_module.TestAllTypes()
327    proto.repeated_fixed32.append(1)
328    proto.repeated_int32.append(5)
329    proto.repeated_int32.append(11)
330    proto.repeated_string.extend(['foo', 'bar'])
331    proto.repeated_string.extend([])
332    proto.repeated_string.append('baz')
333    proto.repeated_string.extend(str(x) for x in range(2))
334    proto.optional_int32 = 21
335    proto.repeated_bool  # Access but don't set anything; should not be listed.
336    self.assertEqual(
337      [ (proto.DESCRIPTOR.fields_by_name['optional_int32'  ], 21),
338        (proto.DESCRIPTOR.fields_by_name['repeated_int32'  ], [5, 11]),
339        (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
340        (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
341          ['foo', 'bar', 'baz', '0', '1']) ],
342      proto.ListFields())
343
344  def testClearFieldWithUnknownFieldName(self, message_module):
345    proto = message_module.TestAllTypes()
346    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
347    self.assertRaises(ValueError, proto.ClearField, b'nonexistent_field')
348
349  def testDisallowedAssignments(self, message_module):
350    # It's illegal to assign values directly to repeated fields
351    # or to nonrepeated composite fields.  Ensure that this fails.
352    proto = message_module.TestAllTypes()
353    # Repeated fields.
354    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
355    # Lists shouldn't work, either.
356    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
357    # Composite fields.
358    self.assertRaises(AttributeError, setattr, proto,
359                      'optional_nested_message', 23)
360    # Assignment to a repeated nested message field without specifying
361    # the index in the array of nested messages.
362    self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
363                      'bb', 34)
364    # Assignment to an attribute of a repeated field.
365    self.assertRaises(AttributeError, setattr, proto.repeated_float,
366                      'some_attribute', 34)
367    # proto.nonexistent_field = 23 should fail as well.
368    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
369
370  def testSingleScalarTypeSafety(self, message_module):
371    proto = message_module.TestAllTypes()
372    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
373    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
374    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
375    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
376    self.assertRaises(TypeError, setattr, proto, 'optional_bool', 'foo')
377    self.assertRaises(TypeError, setattr, proto, 'optional_float', 'foo')
378    self.assertRaises(TypeError, setattr, proto, 'optional_double', 'foo')
379    # TODO(jieluo): Fix type checking difference for python and c extension
380    if (api_implementation.Type() == 'python' or
381        (sys.version_info.major, sys.version_info.minor) >= (3, 10)):
382      self.assertRaises(TypeError, setattr, proto, 'optional_bool', 1.1)
383    else:
384      proto.optional_bool = 1.1
385
386  def assertIntegerTypes(self, integer_fn, message_module):
387    """Verifies setting of scalar integers.
388
389    Args:
390      integer_fn: A function to wrap the integers that will be assigned.
391      message_module: unittest_pb2 or unittest_proto3_arena_pb2
392    """
393    def TestGetAndDeserialize(field_name, value, expected_type):
394      proto = message_module.TestAllTypes()
395      value = integer_fn(value)
396      setattr(proto, field_name, value)
397      self.assertIsInstance(getattr(proto, field_name), expected_type)
398      proto2 = message_module.TestAllTypes()
399      proto2.ParseFromString(proto.SerializeToString())
400      self.assertIsInstance(getattr(proto2, field_name), expected_type)
401
402    TestGetAndDeserialize('optional_int32', 1, int)
403    TestGetAndDeserialize('optional_int32', 1 << 30, int)
404    TestGetAndDeserialize('optional_uint32', 1 << 30, int)
405    integer_64 = int
406    if struct.calcsize('L') == 4:
407      # Python only has signed ints, so 32-bit python can't fit an uint32
408      # in an int.
409      TestGetAndDeserialize('optional_uint32', 1 << 31, integer_64)
410    else:
411      # 64-bit python can fit uint32 inside an int
412      TestGetAndDeserialize('optional_uint32', 1 << 31, int)
413    TestGetAndDeserialize('optional_int64', 1 << 30, integer_64)
414    TestGetAndDeserialize('optional_int64', 1 << 60, integer_64)
415    TestGetAndDeserialize('optional_uint64', 1 << 30, integer_64)
416    TestGetAndDeserialize('optional_uint64', 1 << 60, integer_64)
417
418  def testIntegerTypes(self, message_module):
419    self.assertIntegerTypes(lambda x: x, message_module)
420
421  def testNonStandardIntegerTypes(self, message_module):
422    self.assertIntegerTypes(test_util.NonStandardInteger, message_module)
423
424  def testIllegalValuesForIntegers(self, message_module):
425    pb = message_module.TestAllTypes()
426
427    # Strings are illegal, even when the represent an integer.
428    with self.assertRaises(TypeError):
429      pb.optional_uint64 = '2'
430
431    # The exact error should propagate with a poorly written custom integer.
432    with self.assertRaisesRegex(RuntimeError, 'my_error'):
433      pb.optional_uint64 = test_util.NonStandardInteger(5, 'my_error')
434
435  def assetIntegerBoundsChecking(self, integer_fn, message_module):
436    """Verifies bounds checking for scalar integer fields.
437
438    Args:
439      integer_fn: A function to wrap the integers that will be assigned.
440      message_module: unittest_pb2 or unittest_proto3_arena_pb2
441    """
442    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
443      pb = message_module.TestAllTypes()
444      expected_min = integer_fn(expected_min)
445      expected_max = integer_fn(expected_max)
446      setattr(pb, field_name, expected_min)
447      self.assertEqual(expected_min, getattr(pb, field_name))
448      setattr(pb, field_name, expected_max)
449      self.assertEqual(expected_max, getattr(pb, field_name))
450      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
451                        expected_min - 1)
452      self.assertRaises((ValueError, TypeError), setattr, pb, field_name,
453                        expected_max + 1)
454
455    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
456    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
457    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
458    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
459    # A bit of white-box testing since -1 is an int and not a long in C++ and
460    # so goes down a different path.
461    pb = message_module.TestAllTypes()
462    with self.assertRaises((ValueError, TypeError)):
463      pb.optional_uint64 = integer_fn(-(1 << 63))
464
465    pb = message_module.TestAllTypes()
466    pb.optional_nested_enum = integer_fn(1)
467    self.assertEqual(1, pb.optional_nested_enum)
468
469  def testSingleScalarBoundsChecking(self, message_module):
470    self.assetIntegerBoundsChecking(lambda x: x, message_module)
471
472  def testNonStandardSingleScalarBoundsChecking(self, message_module):
473    self.assetIntegerBoundsChecking(
474        test_util.NonStandardInteger, message_module)
475
476  def testRepeatedScalarTypeSafety(self, message_module):
477    proto = message_module.TestAllTypes()
478    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
479    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
480    self.assertRaises(TypeError, proto.repeated_string, 10)
481    self.assertRaises(TypeError, proto.repeated_bytes, 10)
482
483    proto.repeated_int32.append(10)
484    proto.repeated_int32[0] = 23
485    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
486    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
487    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, [])
488    self.assertRaises(TypeError, proto.repeated_int32.__setitem__,
489                      'index', 23)
490
491    proto.repeated_string.append('2')
492    self.assertRaises(TypeError, proto.repeated_string.__setitem__, 0, 10)
493
494    # Repeated enums tests.
495    #proto.repeated_nested_enum.append(0)
496
497  def testSingleScalarGettersAndSetters(self, message_module):
498    proto = message_module.TestAllTypes()
499    self.assertEqual(0, proto.optional_int32)
500    proto.optional_int32 = 1
501    self.assertEqual(1, proto.optional_int32)
502
503    proto.optional_uint64 = 0xffffffffffff
504    self.assertEqual(0xffffffffffff, proto.optional_uint64)
505    proto.optional_uint64 = 0xffffffffffffffff
506    self.assertEqual(0xffffffffffffffff, proto.optional_uint64)
507    # TODO(robinson): Test all other scalar field types.
508
509  def testEnums(self, message_module):
510    proto = message_module.TestAllTypes()
511    self.assertEqual(1, proto.FOO)
512    self.assertEqual(1, message_module.TestAllTypes.FOO)
513    self.assertEqual(2, proto.BAR)
514    self.assertEqual(2, message_module.TestAllTypes.BAR)
515    self.assertEqual(3, proto.BAZ)
516    self.assertEqual(3, message_module.TestAllTypes.BAZ)
517
518  def testEnum_Name(self, message_module):
519    self.assertEqual(
520        'FOREIGN_FOO',
521        message_module.ForeignEnum.Name(message_module.FOREIGN_FOO))
522    self.assertEqual(
523        'FOREIGN_BAR',
524        message_module.ForeignEnum.Name(message_module.FOREIGN_BAR))
525    self.assertEqual(
526        'FOREIGN_BAZ',
527        message_module.ForeignEnum.Name(message_module.FOREIGN_BAZ))
528    self.assertRaises(ValueError,
529                      message_module.ForeignEnum.Name, 11312)
530
531    proto = message_module.TestAllTypes()
532    self.assertEqual('FOO',
533                     proto.NestedEnum.Name(proto.FOO))
534    self.assertEqual('FOO',
535                     message_module.TestAllTypes.NestedEnum.Name(proto.FOO))
536    self.assertEqual('BAR',
537                     proto.NestedEnum.Name(proto.BAR))
538    self.assertEqual('BAR',
539                     message_module.TestAllTypes.NestedEnum.Name(proto.BAR))
540    self.assertEqual('BAZ',
541                     proto.NestedEnum.Name(proto.BAZ))
542    self.assertEqual('BAZ',
543                     message_module.TestAllTypes.NestedEnum.Name(proto.BAZ))
544    self.assertRaises(ValueError,
545                      proto.NestedEnum.Name, 11312)
546    self.assertRaises(ValueError,
547                      message_module.TestAllTypes.NestedEnum.Name, 11312)
548
549    # Check some coercion cases.
550    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
551                      11312.0)
552    self.assertRaises(TypeError, message_module.TestAllTypes.NestedEnum.Name,
553                      None)
554    self.assertEqual('FOO', message_module.TestAllTypes.NestedEnum.Name(True))
555
556  def testEnum_Value(self, message_module):
557    self.assertEqual(message_module.FOREIGN_FOO,
558                     message_module.ForeignEnum.Value('FOREIGN_FOO'))
559    self.assertEqual(message_module.FOREIGN_FOO,
560                     message_module.ForeignEnum.FOREIGN_FOO)
561
562    self.assertEqual(message_module.FOREIGN_BAR,
563                     message_module.ForeignEnum.Value('FOREIGN_BAR'))
564    self.assertEqual(message_module.FOREIGN_BAR,
565                     message_module.ForeignEnum.FOREIGN_BAR)
566
567    self.assertEqual(message_module.FOREIGN_BAZ,
568                     message_module.ForeignEnum.Value('FOREIGN_BAZ'))
569    self.assertEqual(message_module.FOREIGN_BAZ,
570                     message_module.ForeignEnum.FOREIGN_BAZ)
571
572    self.assertRaises(ValueError,
573                      message_module.ForeignEnum.Value, 'FO')
574    with self.assertRaises(AttributeError):
575      message_module.ForeignEnum.FO
576
577    proto = message_module.TestAllTypes()
578    self.assertEqual(proto.FOO,
579                     proto.NestedEnum.Value('FOO'))
580    self.assertEqual(proto.FOO,
581                     proto.NestedEnum.FOO)
582
583    self.assertEqual(proto.FOO,
584                     message_module.TestAllTypes.NestedEnum.Value('FOO'))
585    self.assertEqual(proto.FOO,
586                     message_module.TestAllTypes.NestedEnum.FOO)
587
588    self.assertEqual(proto.BAR,
589                     proto.NestedEnum.Value('BAR'))
590    self.assertEqual(proto.BAR,
591                     proto.NestedEnum.BAR)
592
593    self.assertEqual(proto.BAR,
594                     message_module.TestAllTypes.NestedEnum.Value('BAR'))
595    self.assertEqual(proto.BAR,
596                     message_module.TestAllTypes.NestedEnum.BAR)
597
598    self.assertEqual(proto.BAZ,
599                     proto.NestedEnum.Value('BAZ'))
600    self.assertEqual(proto.BAZ,
601                     proto.NestedEnum.BAZ)
602
603    self.assertEqual(proto.BAZ,
604                     message_module.TestAllTypes.NestedEnum.Value('BAZ'))
605    self.assertEqual(proto.BAZ,
606                     message_module.TestAllTypes.NestedEnum.BAZ)
607
608    self.assertRaises(ValueError,
609                      proto.NestedEnum.Value, 'Foo')
610    with self.assertRaises(AttributeError):
611      proto.NestedEnum.Value.Foo
612
613    self.assertRaises(ValueError,
614                      message_module.TestAllTypes.NestedEnum.Value, 'Foo')
615    with self.assertRaises(AttributeError):
616      message_module.TestAllTypes.NestedEnum.Value.Foo
617
618  def testEnum_KeysAndValues(self, message_module):
619    if message_module == unittest_pb2:
620      keys = ['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ']
621      values = [4, 5, 6]
622      items = [('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)]
623    else:
624      keys = ['FOREIGN_ZERO', 'FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ']
625      values = [0, 4, 5, 6]
626      items = [('FOREIGN_ZERO', 0), ('FOREIGN_FOO', 4),
627               ('FOREIGN_BAR', 5), ('FOREIGN_BAZ', 6)]
628    self.assertEqual(keys,
629                     list(message_module.ForeignEnum.keys()))
630    self.assertEqual(values,
631                     list(message_module.ForeignEnum.values()))
632    self.assertEqual(items,
633                     list(message_module.ForeignEnum.items()))
634
635    proto = message_module.TestAllTypes()
636    if message_module == unittest_pb2:
637      keys = ['FOO', 'BAR', 'BAZ', 'NEG']
638      values = [1, 2, 3, -1]
639      items = [('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
640    else:
641      keys = ['ZERO', 'FOO', 'BAR', 'BAZ', 'NEG']
642      values = [0, 1, 2, 3, -1]
643      items = [('ZERO', 0), ('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)]
644    self.assertEqual(keys, list(proto.NestedEnum.keys()))
645    self.assertEqual(values, list(proto.NestedEnum.values()))
646    self.assertEqual(items,
647                     list(proto.NestedEnum.items()))
648
649  def testStaticParseFrom(self, message_module):
650    proto1 = message_module.TestAllTypes()
651    test_util.SetAllFields(proto1)
652
653    string1 = proto1.SerializeToString()
654    proto2 = message_module.TestAllTypes.FromString(string1)
655
656    # Messages should be equal.
657    self.assertEqual(proto2, proto1)
658
659  def testMergeFromSingularField(self, message_module):
660    # Test merge with just a singular field.
661    proto1 = message_module.TestAllTypes()
662    proto1.optional_int32 = 1
663
664    proto2 = message_module.TestAllTypes()
665    # This shouldn't get overwritten.
666    proto2.optional_string = 'value'
667
668    proto2.MergeFrom(proto1)
669    self.assertEqual(1, proto2.optional_int32)
670    self.assertEqual('value', proto2.optional_string)
671
672  def testMergeFromRepeatedField(self, message_module):
673    # Test merge with just a repeated field.
674    proto1 = message_module.TestAllTypes()
675    proto1.repeated_int32.append(1)
676    proto1.repeated_int32.append(2)
677
678    proto2 = message_module.TestAllTypes()
679    proto2.repeated_int32.append(0)
680    proto2.MergeFrom(proto1)
681
682    self.assertEqual(0, proto2.repeated_int32[0])
683    self.assertEqual(1, proto2.repeated_int32[1])
684    self.assertEqual(2, proto2.repeated_int32[2])
685
686  def testMergeFromRepeatedNestedMessage(self, message_module):
687    # Test merge with a repeated nested message.
688    proto1 = message_module.TestAllTypes()
689    m = proto1.repeated_nested_message.add()
690    m.bb = 123
691    m = proto1.repeated_nested_message.add()
692    m.bb = 321
693
694    proto2 = message_module.TestAllTypes()
695    m = proto2.repeated_nested_message.add()
696    m.bb = 999
697    proto2.MergeFrom(proto1)
698    self.assertEqual(999, proto2.repeated_nested_message[0].bb)
699    self.assertEqual(123, proto2.repeated_nested_message[1].bb)
700    self.assertEqual(321, proto2.repeated_nested_message[2].bb)
701
702    proto3 = message_module.TestAllTypes()
703    proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message)
704    self.assertEqual(999, proto3.repeated_nested_message[0].bb)
705    self.assertEqual(123, proto3.repeated_nested_message[1].bb)
706    self.assertEqual(321, proto3.repeated_nested_message[2].bb)
707
708  def testMergeFromAllFields(self, message_module):
709    # With all fields set.
710    proto1 = message_module.TestAllTypes()
711    test_util.SetAllFields(proto1)
712    proto2 = message_module.TestAllTypes()
713    proto2.MergeFrom(proto1)
714
715    # Messages should be equal.
716    self.assertEqual(proto2, proto1)
717
718    # Serialized string should be equal too.
719    string1 = proto1.SerializeToString()
720    string2 = proto2.SerializeToString()
721    self.assertEqual(string1, string2)
722
723  def testMergeFromBug(self, message_module):
724    message1 = message_module.TestAllTypes()
725    message2 = message_module.TestAllTypes()
726
727    # Cause optional_nested_message to be instantiated within message1, even
728    # though it is not considered to be "present".
729    message1.optional_nested_message
730    self.assertFalse(message1.HasField('optional_nested_message'))
731
732    # Merge into message2.  This should not instantiate the field is message2.
733    message2.MergeFrom(message1)
734    self.assertFalse(message2.HasField('optional_nested_message'))
735
736  def testCopyFromSingularField(self, message_module):
737    # Test copy with just a singular field.
738    proto1 = message_module.TestAllTypes()
739    proto1.optional_int32 = 1
740    proto1.optional_string = 'important-text'
741
742    proto2 = message_module.TestAllTypes()
743    proto2.optional_string = 'value'
744
745    proto2.CopyFrom(proto1)
746    self.assertEqual(1, proto2.optional_int32)
747    self.assertEqual('important-text', proto2.optional_string)
748
749  def testCopyFromRepeatedField(self, message_module):
750    # Test copy with a repeated field.
751    proto1 = message_module.TestAllTypes()
752    proto1.repeated_int32.append(1)
753    proto1.repeated_int32.append(2)
754
755    proto2 = message_module.TestAllTypes()
756    proto2.repeated_int32.append(0)
757    proto2.CopyFrom(proto1)
758
759    self.assertEqual(1, proto2.repeated_int32[0])
760    self.assertEqual(2, proto2.repeated_int32[1])
761
762  def testCopyFromAllFields(self, message_module):
763    # With all fields set.
764    proto1 = message_module.TestAllTypes()
765    test_util.SetAllFields(proto1)
766    proto2 = message_module.TestAllTypes()
767    proto2.CopyFrom(proto1)
768
769    # Messages should be equal.
770    self.assertEqual(proto2, proto1)
771
772    # Serialized string should be equal too.
773    string1 = proto1.SerializeToString()
774    string2 = proto2.SerializeToString()
775    self.assertEqual(string1, string2)
776
777  def testCopyFromSelf(self, message_module):
778    proto1 = message_module.TestAllTypes()
779    proto1.repeated_int32.append(1)
780    proto1.optional_int32 = 2
781    proto1.optional_string = 'important-text'
782
783    proto1.CopyFrom(proto1)
784    self.assertEqual(1, proto1.repeated_int32[0])
785    self.assertEqual(2, proto1.optional_int32)
786    self.assertEqual('important-text', proto1.optional_string)
787
788  def testDeepCopy(self, message_module):
789    proto1 = message_module.TestAllTypes()
790    proto1.optional_int32 = 1
791    proto2 = copy.deepcopy(proto1)
792    self.assertEqual(1, proto2.optional_int32)
793
794    proto1.repeated_int32.append(2)
795    proto1.repeated_int32.append(3)
796    container = copy.deepcopy(proto1.repeated_int32)
797    self.assertEqual([2, 3], container)
798    container.remove(container[0])
799    self.assertEqual([3], container)
800
801    message1 = proto1.repeated_nested_message.add()
802    message1.bb = 1
803    messages = copy.deepcopy(proto1.repeated_nested_message)
804    self.assertEqual(proto1.repeated_nested_message, messages)
805    message1.bb = 2
806    self.assertNotEqual(proto1.repeated_nested_message, messages)
807    messages.remove(messages[0])
808    self.assertEqual(len(messages), 0)
809
810    # TODO(anuraag): Implement deepcopy for extension dict
811
812  def testDisconnectingBeforeClear(self, message_module):
813    proto = message_module.TestAllTypes()
814    nested = proto.optional_nested_message
815    proto.Clear()
816    self.assertIsNot(nested, proto.optional_nested_message)
817    nested.bb = 23
818    self.assertFalse(proto.HasField('optional_nested_message'))
819    self.assertEqual(0, proto.optional_nested_message.bb)
820
821    proto = message_module.TestAllTypes()
822    nested = proto.optional_nested_message
823    nested.bb = 5
824    foreign = proto.optional_foreign_message
825    foreign.c = 6
826    proto.Clear()
827    self.assertIsNot(nested, proto.optional_nested_message)
828    self.assertIsNot(foreign, proto.optional_foreign_message)
829    self.assertEqual(5, nested.bb)
830    self.assertEqual(6, foreign.c)
831    nested.bb = 15
832    foreign.c = 16
833    self.assertFalse(proto.HasField('optional_nested_message'))
834    self.assertEqual(0, proto.optional_nested_message.bb)
835    self.assertFalse(proto.HasField('optional_foreign_message'))
836    self.assertEqual(0, proto.optional_foreign_message.c)
837
838  def testStringUTF8Encoding(self, message_module):
839    proto = message_module.TestAllTypes()
840
841    # Assignment of a unicode object to a field of type 'bytes' is not allowed.
842    self.assertRaises(TypeError,
843                      setattr, proto, 'optional_bytes', u'unicode object')
844
845    # Check that the default value is of python's 'unicode' type.
846    self.assertEqual(type(proto.optional_string), str)
847
848    proto.optional_string = str('Testing')
849    self.assertEqual(proto.optional_string, str('Testing'))
850
851    # Assign a value of type 'str' which can be encoded in UTF-8.
852    proto.optional_string = str('Testing')
853    self.assertEqual(proto.optional_string, str('Testing'))
854
855    # Try to assign a 'bytes' object which contains non-UTF-8.
856    self.assertRaises(ValueError,
857                      setattr, proto, 'optional_string', b'a\x80a')
858    # No exception: Assign already encoded UTF-8 bytes to a string field.
859    utf8_bytes = u'Тест'.encode('utf-8')
860    proto.optional_string = utf8_bytes
861    # No exception: Assign the a non-ascii unicode object.
862    proto.optional_string = u'Тест'
863    # No exception thrown (normal str assignment containing ASCII).
864    proto.optional_string = 'abc'
865
866  def testBytesInTextFormat(self, message_module):
867    proto = message_module.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff')
868    self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', str(proto))
869
870  def testEmptyNestedMessage(self, message_module):
871    proto = message_module.TestAllTypes()
872    proto.optional_nested_message.MergeFrom(
873        message_module.TestAllTypes.NestedMessage())
874    self.assertTrue(proto.HasField('optional_nested_message'))
875
876    proto = message_module.TestAllTypes()
877    proto.optional_nested_message.CopyFrom(
878        message_module.TestAllTypes.NestedMessage())
879    self.assertTrue(proto.HasField('optional_nested_message'))
880
881    proto = message_module.TestAllTypes()
882    bytes_read = proto.optional_nested_message.MergeFromString(b'')
883    self.assertEqual(0, bytes_read)
884    self.assertTrue(proto.HasField('optional_nested_message'))
885
886    proto = message_module.TestAllTypes()
887    proto.optional_nested_message.ParseFromString(b'')
888    self.assertTrue(proto.HasField('optional_nested_message'))
889
890    serialized = proto.SerializeToString()
891    proto2 = message_module.TestAllTypes()
892    self.assertEqual(
893        len(serialized),
894        proto2.MergeFromString(serialized))
895    self.assertTrue(proto2.HasField('optional_nested_message'))
896
897
898# Class to test proto2-only features (required, extensions, etc.)
899@testing_refleaks.TestCase
900class Proto2ReflectionTest(unittest.TestCase):
901
902  def testRepeatedCompositeConstructor(self):
903    # Constructor with only repeated composite types should succeed.
904    proto = unittest_pb2.TestAllTypes(
905        repeated_nested_message=[
906            unittest_pb2.TestAllTypes.NestedMessage(
907                bb=unittest_pb2.TestAllTypes.FOO),
908            unittest_pb2.TestAllTypes.NestedMessage(
909                bb=unittest_pb2.TestAllTypes.BAR)],
910        repeated_foreign_message=[
911            unittest_pb2.ForeignMessage(c=-43),
912            unittest_pb2.ForeignMessage(c=45324),
913            unittest_pb2.ForeignMessage(c=12)],
914        repeatedgroup=[
915            unittest_pb2.TestAllTypes.RepeatedGroup(),
916            unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
917            unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
918
919    self.assertEqual(
920        [unittest_pb2.TestAllTypes.NestedMessage(
921            bb=unittest_pb2.TestAllTypes.FOO),
922         unittest_pb2.TestAllTypes.NestedMessage(
923             bb=unittest_pb2.TestAllTypes.BAR)],
924        list(proto.repeated_nested_message))
925    self.assertEqual(
926        [unittest_pb2.ForeignMessage(c=-43),
927         unittest_pb2.ForeignMessage(c=45324),
928         unittest_pb2.ForeignMessage(c=12)],
929        list(proto.repeated_foreign_message))
930    self.assertEqual(
931        [unittest_pb2.TestAllTypes.RepeatedGroup(),
932         unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
933         unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
934        list(proto.repeatedgroup))
935
936  def assertListsEqual(self, values, others):
937    self.assertEqual(len(values), len(others))
938    for i in range(len(values)):
939      self.assertEqual(values[i], others[i])
940
941  def testSimpleHasBits(self):
942    # Test a scalar.
943    proto = unittest_pb2.TestAllTypes()
944    self.assertFalse(proto.HasField('optional_int32'))
945    self.assertEqual(0, proto.optional_int32)
946    # HasField() shouldn't be true if all we've done is
947    # read the default value.
948    self.assertFalse(proto.HasField('optional_int32'))
949    proto.optional_int32 = 1
950    # Setting a value however *should* set the "has" bit.
951    self.assertTrue(proto.HasField('optional_int32'))
952    proto.ClearField('optional_int32')
953    # And clearing that value should unset the "has" bit.
954    self.assertFalse(proto.HasField('optional_int32'))
955
956  def testHasBitsWithSinglyNestedScalar(self):
957    # Helper used to test foreign messages and groups.
958    #
959    # composite_field_name should be the name of a non-repeated
960    # composite (i.e., foreign or group) field in TestAllTypes,
961    # and scalar_field_name should be the name of an integer-valued
962    # scalar field within that composite.
963    #
964    # I never thought I'd miss C++ macros and templates so much. :(
965    # This helper is semantically just:
966    #
967    #   assert proto.composite_field.scalar_field == 0
968    #   assert not proto.composite_field.HasField('scalar_field')
969    #   assert not proto.HasField('composite_field')
970    #
971    #   proto.composite_field.scalar_field = 10
972    #   old_composite_field = proto.composite_field
973    #
974    #   assert proto.composite_field.scalar_field == 10
975    #   assert proto.composite_field.HasField('scalar_field')
976    #   assert proto.HasField('composite_field')
977    #
978    #   proto.ClearField('composite_field')
979    #
980    #   assert not proto.composite_field.HasField('scalar_field')
981    #   assert not proto.HasField('composite_field')
982    #   assert proto.composite_field.scalar_field == 0
983    #
984    #   # Now ensure that ClearField('composite_field') disconnected
985    #   # the old field object from the object tree...
986    #   assert old_composite_field is not proto.composite_field
987    #   old_composite_field.scalar_field = 20
988    #   assert not proto.composite_field.HasField('scalar_field')
989    #   assert not proto.HasField('composite_field')
990    def TestCompositeHasBits(composite_field_name, scalar_field_name):
991      proto = unittest_pb2.TestAllTypes()
992      # First, check that we can get the scalar value, and see that it's the
993      # default (0), but that proto.HasField('omposite') and
994      # proto.composite.HasField('scalar') will still return False.
995      composite_field = getattr(proto, composite_field_name)
996      original_scalar_value = getattr(composite_field, scalar_field_name)
997      self.assertEqual(0, original_scalar_value)
998      # Assert that the composite object does not "have" the scalar.
999      self.assertFalse(composite_field.HasField(scalar_field_name))
1000      # Assert that proto does not "have" the composite field.
1001      self.assertFalse(proto.HasField(composite_field_name))
1002
1003      # Now set the scalar within the composite field.  Ensure that the setting
1004      # is reflected, and that proto.HasField('composite') and
1005      # proto.composite.HasField('scalar') now both return True.
1006      new_val = 20
1007      setattr(composite_field, scalar_field_name, new_val)
1008      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
1009      # Hold on to a reference to the current composite_field object.
1010      old_composite_field = composite_field
1011      # Assert that the has methods now return true.
1012      self.assertTrue(composite_field.HasField(scalar_field_name))
1013      self.assertTrue(proto.HasField(composite_field_name))
1014
1015      # Now call the clear method...
1016      proto.ClearField(composite_field_name)
1017
1018      # ...and ensure that the "has" bits are all back to False...
1019      composite_field = getattr(proto, composite_field_name)
1020      self.assertFalse(composite_field.HasField(scalar_field_name))
1021      self.assertFalse(proto.HasField(composite_field_name))
1022      # ...and ensure that the scalar field has returned to its default.
1023      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1024
1025      self.assertIsNot(old_composite_field, composite_field)
1026      setattr(old_composite_field, scalar_field_name, new_val)
1027      self.assertFalse(composite_field.HasField(scalar_field_name))
1028      self.assertFalse(proto.HasField(composite_field_name))
1029      self.assertEqual(0, getattr(composite_field, scalar_field_name))
1030
1031    # Test simple, single-level nesting when we set a scalar.
1032    TestCompositeHasBits('optionalgroup', 'a')
1033    TestCompositeHasBits('optional_nested_message', 'bb')
1034    TestCompositeHasBits('optional_foreign_message', 'c')
1035    TestCompositeHasBits('optional_import_message', 'd')
1036
1037  def testHasBitsWhenModifyingRepeatedFields(self):
1038    # Test nesting when we add an element to a repeated field in a submessage.
1039    proto = unittest_pb2.TestNestedMessageHasBits()
1040    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
1041    self.assertEqual(
1042        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
1043    self.assertTrue(proto.HasField('optional_nested_message'))
1044
1045    # Do the same test, but with a repeated composite field within the
1046    # submessage.
1047    proto.ClearField('optional_nested_message')
1048    self.assertFalse(proto.HasField('optional_nested_message'))
1049    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
1050    self.assertTrue(proto.HasField('optional_nested_message'))
1051
1052  def testHasBitsForManyLevelsOfNesting(self):
1053    # Test nesting many levels deep.
1054    recursive_proto = unittest_pb2.TestMutualRecursionA()
1055    self.assertFalse(recursive_proto.HasField('bb'))
1056    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
1057    self.assertFalse(recursive_proto.HasField('bb'))
1058    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
1059    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
1060    self.assertTrue(recursive_proto.HasField('bb'))
1061    self.assertTrue(recursive_proto.bb.HasField('a'))
1062    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
1063    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
1064    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
1065    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
1066    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
1067
1068  def testSingularListExtensions(self):
1069    proto = unittest_pb2.TestAllExtensions()
1070    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
1071    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 5
1072    proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
1073    self.assertEqual(
1074      [ (unittest_pb2.optional_int32_extension  , 5),
1075        (unittest_pb2.optional_fixed32_extension, 1),
1076        (unittest_pb2.optional_string_extension , 'foo') ],
1077      proto.ListFields())
1078    del proto.Extensions[unittest_pb2.optional_fixed32_extension]
1079    self.assertEqual(
1080        [(unittest_pb2.optional_int32_extension, 5),
1081         (unittest_pb2.optional_string_extension, 'foo')],
1082        proto.ListFields())
1083
1084  def testRepeatedListExtensions(self):
1085    proto = unittest_pb2.TestAllExtensions()
1086    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
1087    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(5)
1088    proto.Extensions[unittest_pb2.repeated_int32_extension  ].append(11)
1089    proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
1090    proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
1091    proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
1092    proto.Extensions[unittest_pb2.optional_int32_extension  ] = 21
1093    self.assertEqual(
1094      [ (unittest_pb2.optional_int32_extension  , 21),
1095        (unittest_pb2.repeated_int32_extension  , [5, 11]),
1096        (unittest_pb2.repeated_fixed32_extension, [1]),
1097        (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
1098      proto.ListFields())
1099    del proto.Extensions[unittest_pb2.repeated_int32_extension]
1100    del proto.Extensions[unittest_pb2.repeated_string_extension]
1101    self.assertEqual(
1102        [(unittest_pb2.optional_int32_extension, 21),
1103         (unittest_pb2.repeated_fixed32_extension, [1])],
1104        proto.ListFields())
1105
1106  def testListFieldsAndExtensions(self):
1107    proto = unittest_pb2.TestFieldOrderings()
1108    test_util.SetAllFieldsAndExtensions(proto)
1109    unittest_pb2.my_extension_int
1110    self.assertEqual(
1111      [ (proto.DESCRIPTOR.fields_by_name['my_int'   ], 1),
1112        (unittest_pb2.my_extension_int               , 23),
1113        (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
1114        (unittest_pb2.my_extension_string            , 'bar'),
1115        (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
1116      proto.ListFields())
1117
1118  def testDefaultValues(self):
1119    proto = unittest_pb2.TestAllTypes()
1120    self.assertEqual(0, proto.optional_int32)
1121    self.assertEqual(0, proto.optional_int64)
1122    self.assertEqual(0, proto.optional_uint32)
1123    self.assertEqual(0, proto.optional_uint64)
1124    self.assertEqual(0, proto.optional_sint32)
1125    self.assertEqual(0, proto.optional_sint64)
1126    self.assertEqual(0, proto.optional_fixed32)
1127    self.assertEqual(0, proto.optional_fixed64)
1128    self.assertEqual(0, proto.optional_sfixed32)
1129    self.assertEqual(0, proto.optional_sfixed64)
1130    self.assertEqual(0.0, proto.optional_float)
1131    self.assertEqual(0.0, proto.optional_double)
1132    self.assertEqual(False, proto.optional_bool)
1133    self.assertEqual('', proto.optional_string)
1134    self.assertEqual(b'', proto.optional_bytes)
1135
1136    self.assertEqual(41, proto.default_int32)
1137    self.assertEqual(42, proto.default_int64)
1138    self.assertEqual(43, proto.default_uint32)
1139    self.assertEqual(44, proto.default_uint64)
1140    self.assertEqual(-45, proto.default_sint32)
1141    self.assertEqual(46, proto.default_sint64)
1142    self.assertEqual(47, proto.default_fixed32)
1143    self.assertEqual(48, proto.default_fixed64)
1144    self.assertEqual(49, proto.default_sfixed32)
1145    self.assertEqual(-50, proto.default_sfixed64)
1146    self.assertEqual(51.5, proto.default_float)
1147    self.assertEqual(52e3, proto.default_double)
1148    self.assertEqual(True, proto.default_bool)
1149    self.assertEqual('hello', proto.default_string)
1150    self.assertEqual(b'world', proto.default_bytes)
1151    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
1152    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
1153    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
1154                     proto.default_import_enum)
1155
1156    proto = unittest_pb2.TestExtremeDefaultValues()
1157    self.assertEqual(u'\u1234', proto.utf8_string)
1158
1159  def testHasFieldWithUnknownFieldName(self):
1160    proto = unittest_pb2.TestAllTypes()
1161    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
1162
1163  def testClearRemovesChildren(self):
1164    # Make sure there aren't any implementation bugs that are only partially
1165    # clearing the message (which can happen in the more complex C++
1166    # implementation which has parallel message lists).
1167    proto = unittest_pb2.TestRequiredForeign()
1168    for i in range(10):
1169      proto.repeated_message.add()
1170    proto2 = unittest_pb2.TestRequiredForeign()
1171    proto.CopyFrom(proto2)
1172    self.assertRaises(IndexError, lambda: proto.repeated_message[5])
1173
1174  def testSingleScalarClearField(self):
1175    proto = unittest_pb2.TestAllTypes()
1176    # Should be allowed to clear something that's not there (a no-op).
1177    proto.ClearField('optional_int32')
1178    proto.optional_int32 = 1
1179    self.assertTrue(proto.HasField('optional_int32'))
1180    proto.ClearField('optional_int32')
1181    self.assertEqual(0, proto.optional_int32)
1182    self.assertFalse(proto.HasField('optional_int32'))
1183    # TODO(robinson): Test all other scalar field types.
1184
1185  def testRepeatedScalars(self):
1186    proto = unittest_pb2.TestAllTypes()
1187
1188    self.assertFalse(proto.repeated_int32)
1189    self.assertEqual(0, len(proto.repeated_int32))
1190    proto.repeated_int32.append(5)
1191    proto.repeated_int32.append(10)
1192    proto.repeated_int32.append(15)
1193    self.assertTrue(proto.repeated_int32)
1194    self.assertEqual(3, len(proto.repeated_int32))
1195
1196    self.assertEqual([5, 10, 15], proto.repeated_int32)
1197
1198    # Test single retrieval.
1199    self.assertEqual(5, proto.repeated_int32[0])
1200    self.assertEqual(15, proto.repeated_int32[-1])
1201    # Test out-of-bounds indices.
1202    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
1203    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
1204    # Test incorrect types passed to __getitem__.
1205    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
1206    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
1207
1208    # Test single assignment.
1209    proto.repeated_int32[1] = 20
1210    self.assertEqual([5, 20, 15], proto.repeated_int32)
1211
1212    # Test insertion.
1213    proto.repeated_int32.insert(1, 25)
1214    self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
1215
1216    # Test slice retrieval.
1217    proto.repeated_int32.append(30)
1218    self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
1219    self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
1220
1221    # Test slice assignment with an iterator
1222    proto.repeated_int32[1:4] = (i for i in range(3))
1223    self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
1224
1225    # Test slice assignment.
1226    proto.repeated_int32[1:4] = [35, 40, 45]
1227    self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
1228
1229    # Test that we can use the field as an iterator.
1230    result = []
1231    for i in proto.repeated_int32:
1232      result.append(i)
1233    self.assertEqual([5, 35, 40, 45, 30], result)
1234
1235    # Test single deletion.
1236    del proto.repeated_int32[2]
1237    self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
1238
1239    # Test slice deletion.
1240    del proto.repeated_int32[2:]
1241    self.assertEqual([5, 35], proto.repeated_int32)
1242
1243    # Test extending.
1244    proto.repeated_int32.extend([3, 13])
1245    self.assertEqual([5, 35, 3, 13], proto.repeated_int32)
1246
1247    # Test clearing.
1248    proto.ClearField('repeated_int32')
1249    self.assertFalse(proto.repeated_int32)
1250    self.assertEqual(0, len(proto.repeated_int32))
1251
1252    proto.repeated_int32.append(1)
1253    self.assertEqual(1, proto.repeated_int32[-1])
1254    # Test assignment to a negative index.
1255    proto.repeated_int32[-1] = 2
1256    self.assertEqual(2, proto.repeated_int32[-1])
1257
1258    # Test deletion at negative indices.
1259    proto.repeated_int32[:] = [0, 1, 2, 3]
1260    del proto.repeated_int32[-1]
1261    self.assertEqual([0, 1, 2], proto.repeated_int32)
1262
1263    del proto.repeated_int32[-2]
1264    self.assertEqual([0, 2], proto.repeated_int32)
1265
1266    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3)
1267    self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300)
1268
1269    del proto.repeated_int32[-2:-1]
1270    self.assertEqual([2], proto.repeated_int32)
1271
1272    del proto.repeated_int32[100:10000]
1273    self.assertEqual([2], proto.repeated_int32)
1274
1275  def testRepeatedScalarsRemove(self):
1276    proto = unittest_pb2.TestAllTypes()
1277
1278    self.assertFalse(proto.repeated_int32)
1279    self.assertEqual(0, len(proto.repeated_int32))
1280    proto.repeated_int32.append(5)
1281    proto.repeated_int32.append(10)
1282    proto.repeated_int32.append(5)
1283    proto.repeated_int32.append(5)
1284
1285    self.assertEqual(4, len(proto.repeated_int32))
1286    proto.repeated_int32.remove(5)
1287    self.assertEqual(3, len(proto.repeated_int32))
1288    self.assertEqual(10, proto.repeated_int32[0])
1289    self.assertEqual(5, proto.repeated_int32[1])
1290    self.assertEqual(5, proto.repeated_int32[2])
1291
1292    proto.repeated_int32.remove(5)
1293    self.assertEqual(2, len(proto.repeated_int32))
1294    self.assertEqual(10, proto.repeated_int32[0])
1295    self.assertEqual(5, proto.repeated_int32[1])
1296
1297    proto.repeated_int32.remove(10)
1298    self.assertEqual(1, len(proto.repeated_int32))
1299    self.assertEqual(5, proto.repeated_int32[0])
1300
1301    # Remove a non-existent element.
1302    self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
1303
1304  def testRepeatedScalarsReverse_Empty(self):
1305    proto = unittest_pb2.TestAllTypes()
1306
1307    self.assertFalse(proto.repeated_int32)
1308    self.assertEqual(0, len(proto.repeated_int32))
1309
1310    self.assertIsNone(proto.repeated_int32.reverse())
1311
1312    self.assertFalse(proto.repeated_int32)
1313    self.assertEqual(0, len(proto.repeated_int32))
1314
1315  def testRepeatedScalarsReverse_NonEmpty(self):
1316    proto = unittest_pb2.TestAllTypes()
1317
1318    self.assertFalse(proto.repeated_int32)
1319    self.assertEqual(0, len(proto.repeated_int32))
1320
1321    proto.repeated_int32.append(1)
1322    proto.repeated_int32.append(2)
1323    proto.repeated_int32.append(3)
1324    proto.repeated_int32.append(4)
1325
1326    self.assertEqual(4, len(proto.repeated_int32))
1327
1328    self.assertIsNone(proto.repeated_int32.reverse())
1329
1330    self.assertEqual(4, len(proto.repeated_int32))
1331    self.assertEqual(4, proto.repeated_int32[0])
1332    self.assertEqual(3, proto.repeated_int32[1])
1333    self.assertEqual(2, proto.repeated_int32[2])
1334    self.assertEqual(1, proto.repeated_int32[3])
1335
1336  def testRepeatedComposites(self):
1337    proto = unittest_pb2.TestAllTypes()
1338    self.assertFalse(proto.repeated_nested_message)
1339    self.assertEqual(0, len(proto.repeated_nested_message))
1340    m0 = proto.repeated_nested_message.add()
1341    m1 = proto.repeated_nested_message.add()
1342    self.assertTrue(proto.repeated_nested_message)
1343    self.assertEqual(2, len(proto.repeated_nested_message))
1344    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1345    self.assertIsInstance(m0, unittest_pb2.TestAllTypes.NestedMessage)
1346
1347    # Test out-of-bounds indices.
1348    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1349                      1234)
1350    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
1351                      -1234)
1352
1353    # Test incorrect types passed to __getitem__.
1354    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1355                      'foo')
1356    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
1357                      None)
1358
1359    # Test slice retrieval.
1360    m2 = proto.repeated_nested_message.add()
1361    m3 = proto.repeated_nested_message.add()
1362    m4 = proto.repeated_nested_message.add()
1363    self.assertListsEqual(
1364        [m1, m2, m3], proto.repeated_nested_message[1:4])
1365    self.assertListsEqual(
1366        [m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
1367    self.assertListsEqual(
1368        [m0, m1], proto.repeated_nested_message[:2])
1369    self.assertListsEqual(
1370        [m2, m3, m4], proto.repeated_nested_message[2:])
1371    self.assertEqual(
1372        m0, proto.repeated_nested_message[0])
1373    self.assertListsEqual(
1374        [m0], proto.repeated_nested_message[:1])
1375
1376    # Test that we can use the field as an iterator.
1377    result = []
1378    for i in proto.repeated_nested_message:
1379      result.append(i)
1380    self.assertListsEqual([m0, m1, m2, m3, m4], result)
1381
1382    # Test single deletion.
1383    del proto.repeated_nested_message[2]
1384    self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message)
1385
1386    # Test slice deletion.
1387    del proto.repeated_nested_message[2:]
1388    self.assertListsEqual([m0, m1], proto.repeated_nested_message)
1389
1390    # Test extending.
1391    n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1)
1392    n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2)
1393    proto.repeated_nested_message.extend([n1,n2])
1394    self.assertEqual(4, len(proto.repeated_nested_message))
1395    self.assertEqual(n1, proto.repeated_nested_message[2])
1396    self.assertEqual(n2, proto.repeated_nested_message[3])
1397    self.assertRaises(TypeError,
1398                      proto.repeated_nested_message.extend, n1)
1399    self.assertRaises(TypeError,
1400                      proto.repeated_nested_message.extend, [0])
1401    wrong_message_type = unittest_pb2.TestAllTypes()
1402    self.assertRaises(TypeError,
1403                      proto.repeated_nested_message.extend,
1404                      [wrong_message_type])
1405
1406    # Test clearing.
1407    proto.ClearField('repeated_nested_message')
1408    self.assertFalse(proto.repeated_nested_message)
1409    self.assertEqual(0, len(proto.repeated_nested_message))
1410
1411    # Test constructing an element while adding it.
1412    proto.repeated_nested_message.add(bb=23)
1413    self.assertEqual(1, len(proto.repeated_nested_message))
1414    self.assertEqual(23, proto.repeated_nested_message[0].bb)
1415    self.assertRaises(TypeError, proto.repeated_nested_message.add, 23)
1416    with self.assertRaises(Exception):
1417      proto.repeated_nested_message[0] = 23
1418
1419  def testRepeatedCompositeRemove(self):
1420    proto = unittest_pb2.TestAllTypes()
1421
1422    self.assertEqual(0, len(proto.repeated_nested_message))
1423    m0 = proto.repeated_nested_message.add()
1424    # Need to set some differentiating variable so m0 != m1 != m2:
1425    m0.bb = len(proto.repeated_nested_message)
1426    m1 = proto.repeated_nested_message.add()
1427    m1.bb = len(proto.repeated_nested_message)
1428    self.assertTrue(m0 != m1)
1429    m2 = proto.repeated_nested_message.add()
1430    m2.bb = len(proto.repeated_nested_message)
1431    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1432
1433    self.assertEqual(3, len(proto.repeated_nested_message))
1434    proto.repeated_nested_message.remove(m0)
1435    self.assertEqual(2, len(proto.repeated_nested_message))
1436    self.assertEqual(m1, proto.repeated_nested_message[0])
1437    self.assertEqual(m2, proto.repeated_nested_message[1])
1438
1439    # Removing m0 again or removing None should raise error
1440    self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0)
1441    self.assertRaises(ValueError, proto.repeated_nested_message.remove, None)
1442    self.assertEqual(2, len(proto.repeated_nested_message))
1443
1444    proto.repeated_nested_message.remove(m2)
1445    self.assertEqual(1, len(proto.repeated_nested_message))
1446    self.assertEqual(m1, proto.repeated_nested_message[0])
1447
1448  def testRepeatedCompositeReverse_Empty(self):
1449    proto = unittest_pb2.TestAllTypes()
1450
1451    self.assertFalse(proto.repeated_nested_message)
1452    self.assertEqual(0, len(proto.repeated_nested_message))
1453
1454    self.assertIsNone(proto.repeated_nested_message.reverse())
1455
1456    self.assertFalse(proto.repeated_nested_message)
1457    self.assertEqual(0, len(proto.repeated_nested_message))
1458
1459  def testRepeatedCompositeReverse_NonEmpty(self):
1460    proto = unittest_pb2.TestAllTypes()
1461
1462    self.assertFalse(proto.repeated_nested_message)
1463    self.assertEqual(0, len(proto.repeated_nested_message))
1464
1465    m0 = proto.repeated_nested_message.add()
1466    m0.bb = len(proto.repeated_nested_message)
1467    m1 = proto.repeated_nested_message.add()
1468    m1.bb = len(proto.repeated_nested_message)
1469    m2 = proto.repeated_nested_message.add()
1470    m2.bb = len(proto.repeated_nested_message)
1471    self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message)
1472
1473    self.assertIsNone(proto.repeated_nested_message.reverse())
1474
1475    self.assertListsEqual([m2, m1, m0], proto.repeated_nested_message)
1476
1477  def testHandWrittenReflection(self):
1478    # Hand written extensions are only supported by the pure-Python
1479    # implementation of the API.
1480    if api_implementation.Type() != 'python':
1481      return
1482
1483    FieldDescriptor = descriptor.FieldDescriptor
1484    foo_field_descriptor = FieldDescriptor(
1485        name='foo_field', full_name='MyProto.foo_field',
1486        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
1487        cpp_type=FieldDescriptor.CPPTYPE_INT64,
1488        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
1489        containing_type=None, message_type=None, enum_type=None,
1490        is_extension=False, extension_scope=None,
1491        options=descriptor_pb2.FieldOptions(),
1492        # pylint: disable=protected-access
1493        create_key=descriptor._internal_create_key)
1494    mydescriptor = descriptor.Descriptor(
1495        name='MyProto', full_name='MyProto', filename='ignored',
1496        containing_type=None, nested_types=[], enum_types=[],
1497        fields=[foo_field_descriptor], extensions=[],
1498        options=descriptor_pb2.MessageOptions(),
1499        # pylint: disable=protected-access
1500        create_key=descriptor._internal_create_key)
1501
1502    class MyProtoClass(
1503        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
1504      DESCRIPTOR = mydescriptor
1505    myproto_instance = MyProtoClass()
1506    self.assertEqual(0, myproto_instance.foo_field)
1507    self.assertFalse(myproto_instance.HasField('foo_field'))
1508    myproto_instance.foo_field = 23
1509    self.assertEqual(23, myproto_instance.foo_field)
1510    self.assertTrue(myproto_instance.HasField('foo_field'))
1511
1512  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
1513  def testDescriptorProtoSupport(self):
1514    # Hand written descriptors/reflection are only supported by the pure-Python
1515    # implementation of the API.
1516    if api_implementation.Type() != 'python':
1517      return
1518
1519    def AddDescriptorField(proto, field_name, field_type):
1520      AddDescriptorField.field_index += 1
1521      new_field = proto.field.add()
1522      new_field.name = field_name
1523      new_field.type = field_type
1524      new_field.number = AddDescriptorField.field_index
1525      new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL
1526
1527    AddDescriptorField.field_index = 0
1528
1529    desc_proto = descriptor_pb2.DescriptorProto()
1530    desc_proto.name = 'Car'
1531    fdp = descriptor_pb2.FieldDescriptorProto
1532    AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING)
1533    AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64)
1534    AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL)
1535    AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE)
1536    # Add a repeated field
1537    AddDescriptorField.field_index += 1
1538    new_field = desc_proto.field.add()
1539    new_field.name = 'owners'
1540    new_field.type = fdp.TYPE_STRING
1541    new_field.number = AddDescriptorField.field_index
1542    new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
1543
1544    desc = descriptor.MakeDescriptor(desc_proto)
1545    self.assertTrue('name' in desc.fields_by_name)
1546    self.assertTrue('year' in desc.fields_by_name)
1547    self.assertTrue('automatic' in desc.fields_by_name)
1548    self.assertTrue('price' in desc.fields_by_name)
1549    self.assertTrue('owners' in desc.fields_by_name)
1550
1551    class CarMessage(
1552        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
1553      DESCRIPTOR = desc
1554
1555    prius = CarMessage()
1556    prius.name = 'prius'
1557    prius.year = 2010
1558    prius.automatic = True
1559    prius.price = 25134.75
1560    prius.owners.extend(['bob', 'susan'])
1561
1562    serialized_prius = prius.SerializeToString()
1563    new_prius = reflection.ParseMessage(desc, serialized_prius)
1564    self.assertIsNot(new_prius, prius)
1565    self.assertEqual(prius, new_prius)
1566
1567    # these are unnecessary assuming message equality works as advertised but
1568    # explicitly check to be safe since we're mucking about in metaclass foo
1569    self.assertEqual(prius.name, new_prius.name)
1570    self.assertEqual(prius.year, new_prius.year)
1571    self.assertEqual(prius.automatic, new_prius.automatic)
1572    self.assertEqual(prius.price, new_prius.price)
1573    self.assertEqual(prius.owners, new_prius.owners)
1574
1575  def testExtensionDelete(self):
1576    extendee_proto = more_extensions_pb2.ExtendedMessage()
1577
1578    extension_int32 = more_extensions_pb2.optional_int_extension
1579    extendee_proto.Extensions[extension_int32] = 23
1580
1581    extension_repeated = more_extensions_pb2.repeated_int_extension
1582    extendee_proto.Extensions[extension_repeated].append(11)
1583
1584    extension_msg = more_extensions_pb2.optional_message_extension
1585    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1586
1587    self.assertEqual(len(extendee_proto.Extensions), 3)
1588    del extendee_proto.Extensions[extension_msg]
1589    self.assertEqual(len(extendee_proto.Extensions), 2)
1590    del extendee_proto.Extensions[extension_repeated]
1591    self.assertEqual(len(extendee_proto.Extensions), 1)
1592    # Delete a none exist extension. It is OK to "del m.Extensions[ext]"
1593    # even if the extension is not present in the message; we don't
1594    # raise KeyError. This is consistent with "m.Extensions[ext]"
1595    # returning a default value even if we did not set anything.
1596    del extendee_proto.Extensions[extension_repeated]
1597    self.assertEqual(len(extendee_proto.Extensions), 1)
1598    del extendee_proto.Extensions[extension_int32]
1599    self.assertEqual(len(extendee_proto.Extensions), 0)
1600
1601  def testExtensionIter(self):
1602    extendee_proto = more_extensions_pb2.ExtendedMessage()
1603
1604    extension_int32 = more_extensions_pb2.optional_int_extension
1605    extendee_proto.Extensions[extension_int32] = 23
1606
1607    extension_repeated = more_extensions_pb2.repeated_int_extension
1608    extendee_proto.Extensions[extension_repeated].append(11)
1609
1610    extension_msg = more_extensions_pb2.optional_message_extension
1611    extendee_proto.Extensions[extension_msg].foreign_message_int = 56
1612
1613    # Set some normal fields.
1614    extendee_proto.optional_int32 = 1
1615    extendee_proto.repeated_string.append('hi')
1616
1617    expected = (extension_int32, extension_msg, extension_repeated)
1618    count = 0
1619    for item in extendee_proto.Extensions:
1620      self.assertEqual(item.name, expected[count].name)
1621      self.assertIn(item, extendee_proto.Extensions)
1622      count += 1
1623    self.assertEqual(count, 3)
1624
1625  def testExtensionContainsError(self):
1626    extendee_proto = more_extensions_pb2.ExtendedMessage()
1627    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, 0)
1628
1629    field = more_extensions_pb2.ExtendedMessage.DESCRIPTOR.fields_by_name[
1630        'optional_int32']
1631    self.assertRaises(KeyError, extendee_proto.Extensions.__contains__, field)
1632
1633  def testTopLevelExtensionsForOptionalScalar(self):
1634    extendee_proto = unittest_pb2.TestAllExtensions()
1635    extension = unittest_pb2.optional_int32_extension
1636    self.assertFalse(extendee_proto.HasExtension(extension))
1637    self.assertNotIn(extension, extendee_proto.Extensions)
1638    self.assertEqual(0, extendee_proto.Extensions[extension])
1639    # As with normal scalar fields, just doing a read doesn't actually set the
1640    # "has" bit.
1641    self.assertFalse(extendee_proto.HasExtension(extension))
1642    self.assertNotIn(extension, extendee_proto.Extensions)
1643    # Actually set the thing.
1644    extendee_proto.Extensions[extension] = 23
1645    self.assertEqual(23, extendee_proto.Extensions[extension])
1646    self.assertTrue(extendee_proto.HasExtension(extension))
1647    self.assertIn(extension, extendee_proto.Extensions)
1648    # Ensure that clearing works as well.
1649    extendee_proto.ClearExtension(extension)
1650    self.assertEqual(0, extendee_proto.Extensions[extension])
1651    self.assertFalse(extendee_proto.HasExtension(extension))
1652    self.assertNotIn(extension, extendee_proto.Extensions)
1653
1654  def testTopLevelExtensionsForRepeatedScalar(self):
1655    extendee_proto = unittest_pb2.TestAllExtensions()
1656    extension = unittest_pb2.repeated_string_extension
1657    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1658    self.assertNotIn(extension, extendee_proto.Extensions)
1659    extendee_proto.Extensions[extension].append('foo')
1660    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
1661    self.assertIn(extension, extendee_proto.Extensions)
1662    string_list = extendee_proto.Extensions[extension]
1663    extendee_proto.ClearExtension(extension)
1664    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1665    self.assertNotIn(extension, extendee_proto.Extensions)
1666    self.assertIsNot(string_list, extendee_proto.Extensions[extension])
1667    # Shouldn't be allowed to do Extensions[extension] = 'a'
1668    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1669                      extension, 'a')
1670
1671  def testTopLevelExtensionsForOptionalMessage(self):
1672    extendee_proto = unittest_pb2.TestAllExtensions()
1673    extension = unittest_pb2.optional_foreign_message_extension
1674    self.assertFalse(extendee_proto.HasExtension(extension))
1675    self.assertNotIn(extension, extendee_proto.Extensions)
1676    self.assertEqual(0, extendee_proto.Extensions[extension].c)
1677    # As with normal (non-extension) fields, merely reading from the
1678    # thing shouldn't set the "has" bit.
1679    self.assertFalse(extendee_proto.HasExtension(extension))
1680    self.assertNotIn(extension, extendee_proto.Extensions)
1681    extendee_proto.Extensions[extension].c = 23
1682    self.assertEqual(23, extendee_proto.Extensions[extension].c)
1683    self.assertTrue(extendee_proto.HasExtension(extension))
1684    self.assertIn(extension, extendee_proto.Extensions)
1685    # Save a reference here.
1686    foreign_message = extendee_proto.Extensions[extension]
1687    extendee_proto.ClearExtension(extension)
1688    self.assertIsNot(foreign_message, extendee_proto.Extensions[extension])
1689    # Setting a field on foreign_message now shouldn't set
1690    # any "has" bits on extendee_proto.
1691    foreign_message.c = 42
1692    self.assertEqual(42, foreign_message.c)
1693    self.assertTrue(foreign_message.HasField('c'))
1694    self.assertFalse(extendee_proto.HasExtension(extension))
1695    self.assertNotIn(extension, extendee_proto.Extensions)
1696    # Shouldn't be allowed to do Extensions[extension] = 'a'
1697    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1698                      extension, 'a')
1699
1700  def testTopLevelExtensionsForRepeatedMessage(self):
1701    extendee_proto = unittest_pb2.TestAllExtensions()
1702    extension = unittest_pb2.repeatedgroup_extension
1703    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1704    group = extendee_proto.Extensions[extension].add()
1705    group.a = 23
1706    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
1707    group.a = 42
1708    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
1709    group_list = extendee_proto.Extensions[extension]
1710    extendee_proto.ClearExtension(extension)
1711    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
1712    self.assertIsNot(group_list, extendee_proto.Extensions[extension])
1713    # Shouldn't be allowed to do Extensions[extension] = 'a'
1714    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
1715                      extension, 'a')
1716
1717  def testNestedExtensions(self):
1718    extendee_proto = unittest_pb2.TestAllExtensions()
1719    extension = unittest_pb2.TestRequired.single
1720
1721    # We just test the non-repeated case.
1722    self.assertFalse(extendee_proto.HasExtension(extension))
1723    self.assertNotIn(extension, extendee_proto.Extensions)
1724    required = extendee_proto.Extensions[extension]
1725    self.assertEqual(0, required.a)
1726    self.assertFalse(extendee_proto.HasExtension(extension))
1727    self.assertNotIn(extension, extendee_proto.Extensions)
1728    required.a = 23
1729    self.assertEqual(23, extendee_proto.Extensions[extension].a)
1730    self.assertTrue(extendee_proto.HasExtension(extension))
1731    self.assertIn(extension, extendee_proto.Extensions)
1732    extendee_proto.ClearExtension(extension)
1733    self.assertIsNot(required, extendee_proto.Extensions[extension])
1734    self.assertFalse(extendee_proto.HasExtension(extension))
1735    self.assertNotIn(extension, extendee_proto.Extensions)
1736
1737  def testRegisteredExtensions(self):
1738    pool = unittest_pb2.DESCRIPTOR.pool
1739    self.assertTrue(
1740        pool.FindExtensionByNumber(
1741            unittest_pb2.TestAllExtensions.DESCRIPTOR, 1))
1742    self.assertIs(
1743        pool.FindExtensionByName(
1744            'protobuf_unittest.optional_int32_extension').containing_type,
1745        unittest_pb2.TestAllExtensions.DESCRIPTOR)
1746    # Make sure extensions haven't been registered into types that shouldn't
1747    # have any.
1748    self.assertEqual(0, len(
1749        pool.FindAllExtensions(unittest_pb2.TestAllTypes.DESCRIPTOR)))
1750
1751  # If message A directly contains message B, and
1752  # a.HasField('b') is currently False, then mutating any
1753  # extension in B should change a.HasField('b') to True
1754  # (and so on up the object tree).
1755  def testHasBitsForAncestorsOfExtendedMessage(self):
1756    # Optional scalar extension.
1757    toplevel = more_extensions_pb2.TopLevelMessage()
1758    self.assertFalse(toplevel.HasField('submessage'))
1759    self.assertEqual(0, toplevel.submessage.Extensions[
1760        more_extensions_pb2.optional_int_extension])
1761    self.assertFalse(toplevel.HasField('submessage'))
1762    toplevel.submessage.Extensions[
1763        more_extensions_pb2.optional_int_extension] = 23
1764    self.assertEqual(23, toplevel.submessage.Extensions[
1765        more_extensions_pb2.optional_int_extension])
1766    self.assertTrue(toplevel.HasField('submessage'))
1767
1768    # Repeated scalar extension.
1769    toplevel = more_extensions_pb2.TopLevelMessage()
1770    self.assertFalse(toplevel.HasField('submessage'))
1771    self.assertEqual([], toplevel.submessage.Extensions[
1772        more_extensions_pb2.repeated_int_extension])
1773    self.assertFalse(toplevel.HasField('submessage'))
1774    toplevel.submessage.Extensions[
1775        more_extensions_pb2.repeated_int_extension].append(23)
1776    self.assertEqual([23], toplevel.submessage.Extensions[
1777        more_extensions_pb2.repeated_int_extension])
1778    self.assertTrue(toplevel.HasField('submessage'))
1779
1780    # Optional message extension.
1781    toplevel = more_extensions_pb2.TopLevelMessage()
1782    self.assertFalse(toplevel.HasField('submessage'))
1783    self.assertEqual(0, toplevel.submessage.Extensions[
1784        more_extensions_pb2.optional_message_extension].foreign_message_int)
1785    self.assertFalse(toplevel.HasField('submessage'))
1786    toplevel.submessage.Extensions[
1787        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
1788    self.assertEqual(23, toplevel.submessage.Extensions[
1789        more_extensions_pb2.optional_message_extension].foreign_message_int)
1790    self.assertTrue(toplevel.HasField('submessage'))
1791
1792    # Repeated message extension.
1793    toplevel = more_extensions_pb2.TopLevelMessage()
1794    self.assertFalse(toplevel.HasField('submessage'))
1795    self.assertEqual(0, len(toplevel.submessage.Extensions[
1796        more_extensions_pb2.repeated_message_extension]))
1797    self.assertFalse(toplevel.HasField('submessage'))
1798    foreign = toplevel.submessage.Extensions[
1799        more_extensions_pb2.repeated_message_extension].add()
1800    self.assertEqual(foreign, toplevel.submessage.Extensions[
1801        more_extensions_pb2.repeated_message_extension][0])
1802    self.assertTrue(toplevel.HasField('submessage'))
1803
1804  def testDisconnectionAfterClearingEmptyMessage(self):
1805    toplevel = more_extensions_pb2.TopLevelMessage()
1806    extendee_proto = toplevel.submessage
1807    extension = more_extensions_pb2.optional_message_extension
1808    extension_proto = extendee_proto.Extensions[extension]
1809    extendee_proto.ClearExtension(extension)
1810    extension_proto.foreign_message_int = 23
1811
1812    self.assertIsNot(extension_proto, extendee_proto.Extensions[extension])
1813
1814  def testExtensionFailureModes(self):
1815    extendee_proto = unittest_pb2.TestAllExtensions()
1816
1817    # Try non-extension-handle arguments to HasExtension,
1818    # ClearExtension(), and Extensions[]...
1819    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
1820    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
1821    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
1822    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
1823
1824    # Try something that *is* an extension handle, just not for
1825    # this message...
1826    for unknown_handle in (more_extensions_pb2.optional_int_extension,
1827                           more_extensions_pb2.optional_message_extension,
1828                           more_extensions_pb2.repeated_int_extension,
1829                           more_extensions_pb2.repeated_message_extension):
1830      self.assertRaises(KeyError, extendee_proto.HasExtension,
1831                        unknown_handle)
1832      self.assertRaises(KeyError, extendee_proto.ClearExtension,
1833                        unknown_handle)
1834      self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
1835                        unknown_handle)
1836      self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
1837                        unknown_handle, 5)
1838
1839    # Try call HasExtension() with a valid handle, but for a
1840    # *repeated* field.  (Just as with non-extension repeated
1841    # fields, Has*() isn't supported for extension repeated fields).
1842    self.assertRaises(KeyError, extendee_proto.HasExtension,
1843                      unittest_pb2.repeated_string_extension)
1844
1845  def testMergeFromOptionalGroup(self):
1846    # Test merge with an optional group.
1847    proto1 = unittest_pb2.TestAllTypes()
1848    proto1.optionalgroup.a = 12
1849    proto2 = unittest_pb2.TestAllTypes()
1850    proto2.MergeFrom(proto1)
1851    self.assertEqual(12, proto2.optionalgroup.a)
1852
1853  def testMergeFromExtensionsSingular(self):
1854    proto1 = unittest_pb2.TestAllExtensions()
1855    proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
1856
1857    proto2 = unittest_pb2.TestAllExtensions()
1858    proto2.MergeFrom(proto1)
1859    self.assertEqual(
1860        1, proto2.Extensions[unittest_pb2.optional_int32_extension])
1861
1862  def testMergeFromExtensionsRepeated(self):
1863    proto1 = unittest_pb2.TestAllExtensions()
1864    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
1865    proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
1866
1867    proto2 = unittest_pb2.TestAllExtensions()
1868    proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
1869    proto2.MergeFrom(proto1)
1870    self.assertEqual(
1871        3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
1872    self.assertEqual(
1873        0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
1874    self.assertEqual(
1875        1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
1876    self.assertEqual(
1877        2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
1878
1879  def testMergeFromExtensionsNestedMessage(self):
1880    proto1 = unittest_pb2.TestAllExtensions()
1881    ext1 = proto1.Extensions[
1882        unittest_pb2.repeated_nested_message_extension]
1883    m = ext1.add()
1884    m.bb = 222
1885    m = ext1.add()
1886    m.bb = 333
1887
1888    proto2 = unittest_pb2.TestAllExtensions()
1889    ext2 = proto2.Extensions[
1890        unittest_pb2.repeated_nested_message_extension]
1891    m = ext2.add()
1892    m.bb = 111
1893
1894    proto2.MergeFrom(proto1)
1895    ext2 = proto2.Extensions[
1896        unittest_pb2.repeated_nested_message_extension]
1897    self.assertEqual(3, len(ext2))
1898    self.assertEqual(111, ext2[0].bb)
1899    self.assertEqual(222, ext2[1].bb)
1900    self.assertEqual(333, ext2[2].bb)
1901
1902  def testCopyFromBadType(self):
1903    # The python implementation doesn't raise an exception in this
1904    # case. In theory it should.
1905    if api_implementation.Type() == 'python':
1906      return
1907    proto1 = unittest_pb2.TestAllTypes()
1908    proto2 = unittest_pb2.TestAllExtensions()
1909    self.assertRaises(TypeError, proto1.CopyFrom, proto2)
1910
1911  def testClear(self):
1912    proto = unittest_pb2.TestAllTypes()
1913    # C++ implementation does not support lazy fields right now so leave it
1914    # out for now.
1915    if api_implementation.Type() == 'python':
1916      test_util.SetAllFields(proto)
1917    else:
1918      test_util.SetAllNonLazyFields(proto)
1919    # Clear the message.
1920    proto.Clear()
1921    self.assertEqual(proto.ByteSize(), 0)
1922    empty_proto = unittest_pb2.TestAllTypes()
1923    self.assertEqual(proto, empty_proto)
1924
1925    # Test if extensions which were set are cleared.
1926    proto = unittest_pb2.TestAllExtensions()
1927    test_util.SetAllExtensions(proto)
1928    # Clear the message.
1929    proto.Clear()
1930    self.assertEqual(proto.ByteSize(), 0)
1931    empty_proto = unittest_pb2.TestAllExtensions()
1932    self.assertEqual(proto, empty_proto)
1933
1934  def testDisconnectingInOneof(self):
1935    m = unittest_pb2.TestOneof2()  # This message has two messages in a oneof.
1936    m.foo_message.moo_int = 5
1937    sub_message = m.foo_message
1938    # Accessing another message's field does not clear the first one
1939    self.assertEqual(m.foo_lazy_message.moo_int, 0)
1940    self.assertEqual(m.foo_message.moo_int, 5)
1941    # But mutating another message in the oneof detaches the first one.
1942    m.foo_lazy_message.moo_int = 6
1943    self.assertEqual(m.foo_message.moo_int, 0)
1944    # The reference we got above was detached and is still valid.
1945    self.assertEqual(sub_message.moo_int, 5)
1946    sub_message.moo_int = 7
1947
1948  def assertInitialized(self, proto):
1949    self.assertTrue(proto.IsInitialized())
1950    # Neither method should raise an exception.
1951    proto.SerializeToString()
1952    proto.SerializePartialToString()
1953
1954  def assertNotInitialized(self, proto, error_size=None):
1955    errors = []
1956    self.assertFalse(proto.IsInitialized())
1957    self.assertFalse(proto.IsInitialized(errors))
1958    self.assertEqual(error_size, len(errors))
1959    self.assertRaises(message.EncodeError, proto.SerializeToString)
1960    # "Partial" serialization doesn't care if message is uninitialized.
1961    proto.SerializePartialToString()
1962
1963  def testIsInitialized(self):
1964    # Trivial cases - all optional fields and extensions.
1965    proto = unittest_pb2.TestAllTypes()
1966    self.assertInitialized(proto)
1967    proto = unittest_pb2.TestAllExtensions()
1968    self.assertInitialized(proto)
1969
1970    # The case of uninitialized required fields.
1971    proto = unittest_pb2.TestRequired()
1972    self.assertNotInitialized(proto, 3)
1973    proto.a = proto.b = proto.c = 2
1974    self.assertInitialized(proto)
1975
1976    # The case of uninitialized submessage.
1977    proto = unittest_pb2.TestRequiredForeign()
1978    self.assertInitialized(proto)
1979    proto.optional_message.a = 1
1980    self.assertNotInitialized(proto, 2)
1981    proto.optional_message.b = 0
1982    proto.optional_message.c = 0
1983    self.assertInitialized(proto)
1984
1985    # Uninitialized repeated submessage.
1986    message1 = proto.repeated_message.add()
1987    self.assertNotInitialized(proto, 3)
1988    message1.a = message1.b = message1.c = 0
1989    self.assertInitialized(proto)
1990
1991    # Uninitialized repeated group in an extension.
1992    proto = unittest_pb2.TestAllExtensions()
1993    extension = unittest_pb2.TestRequired.multi
1994    message1 = proto.Extensions[extension].add()
1995    message2 = proto.Extensions[extension].add()
1996    self.assertNotInitialized(proto, 6)
1997    message1.a = 1
1998    message1.b = 1
1999    message1.c = 1
2000    self.assertNotInitialized(proto, 3)
2001    message2.a = 2
2002    message2.b = 2
2003    message2.c = 2
2004    self.assertInitialized(proto)
2005
2006    # Uninitialized nonrepeated message in an extension.
2007    proto = unittest_pb2.TestAllExtensions()
2008    extension = unittest_pb2.TestRequired.single
2009    proto.Extensions[extension].a = 1
2010    self.assertNotInitialized(proto, 2)
2011    proto.Extensions[extension].b = 2
2012    proto.Extensions[extension].c = 3
2013    self.assertInitialized(proto)
2014
2015    # Try passing an errors list.
2016    errors = []
2017    proto = unittest_pb2.TestRequired()
2018    self.assertFalse(proto.IsInitialized(errors))
2019    self.assertEqual(errors, ['a', 'b', 'c'])
2020    self.assertRaises(TypeError, proto.IsInitialized, 1, 2, 3)
2021
2022  @unittest.skipIf(
2023      api_implementation.Type() == 'python',
2024      'Errors are only available from the most recent C++ implementation.')
2025  def testFileDescriptorErrors(self):
2026    file_name = 'test_file_descriptor_errors.proto'
2027    package_name = 'test_file_descriptor_errors.proto'
2028    file_descriptor_proto = descriptor_pb2.FileDescriptorProto()
2029    file_descriptor_proto.name = file_name
2030    file_descriptor_proto.package = package_name
2031    m1 = file_descriptor_proto.message_type.add()
2032    m1.name = 'msg1'
2033    # Compiles the proto into the C++ descriptor pool
2034    descriptor.FileDescriptor(
2035        file_name,
2036        package_name,
2037        serialized_pb=file_descriptor_proto.SerializeToString())
2038    # Add a FileDescriptorProto that has duplicate symbols
2039    another_file_name = 'another_test_file_descriptor_errors.proto'
2040    file_descriptor_proto.name = another_file_name
2041    m2 = file_descriptor_proto.message_type.add()
2042    m2.name = 'msg2'
2043    with self.assertRaises(TypeError) as cm:
2044      descriptor.FileDescriptor(
2045          another_file_name,
2046          package_name,
2047          serialized_pb=file_descriptor_proto.SerializeToString())
2048      self.assertTrue(hasattr(cm, 'exception'), '%s not raised' %
2049                      getattr(cm.expected, '__name__', cm.expected))
2050      self.assertIn('test_file_descriptor_errors.proto', str(cm.exception))
2051      # Error message will say something about this definition being a
2052      # duplicate, though we don't check the message exactly to avoid a
2053      # dependency on the C++ logging code.
2054      self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
2055
2056  def testStringUTF8Serialization(self):
2057    proto = message_set_extensions_pb2.TestMessageSet()
2058    extension_message = message_set_extensions_pb2.TestMessageSetExtension2
2059    extension = extension_message.message_set_extension
2060
2061    test_utf8 = u'Тест'
2062    test_utf8_bytes = test_utf8.encode('utf-8')
2063
2064    # 'Test' in another language, using UTF-8 charset.
2065    proto.Extensions[extension].str = test_utf8
2066
2067    # Serialize using the MessageSet wire format (this is specified in the
2068    # .proto file).
2069    serialized = proto.SerializeToString()
2070
2071    # Check byte size.
2072    self.assertEqual(proto.ByteSize(), len(serialized))
2073
2074    raw = unittest_mset_pb2.RawMessageSet()
2075    bytes_read = raw.MergeFromString(serialized)
2076    self.assertEqual(len(serialized), bytes_read)
2077
2078    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2079
2080    self.assertEqual(1, len(raw.item))
2081    # Check that the type_id is the same as the tag ID in the .proto file.
2082    self.assertEqual(raw.item[0].type_id, 98418634)
2083
2084    # Check the actual bytes on the wire.
2085    self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes))
2086    bytes_read = message2.MergeFromString(raw.item[0].message)
2087    self.assertEqual(len(raw.item[0].message), bytes_read)
2088
2089    self.assertEqual(type(message2.str), str)
2090    self.assertEqual(message2.str, test_utf8)
2091
2092    # The pure Python API throws an exception on MergeFromString(),
2093    # if any of the string fields of the message can't be UTF-8 decoded.
2094    # The C++ implementation of the API has no way to check that on
2095    # MergeFromString and thus has no way to throw the exception.
2096    #
2097    # The pure Python API always returns objects of type 'unicode' (UTF-8
2098    # encoded), or 'bytes' (in 7 bit ASCII).
2099    badbytes = raw.item[0].message.replace(
2100        test_utf8_bytes, len(test_utf8_bytes) * b'\xff')
2101
2102    unicode_decode_failed = False
2103    try:
2104      message2.MergeFromString(badbytes)
2105    except UnicodeDecodeError:
2106      unicode_decode_failed = True
2107    string_field = message2.str
2108    self.assertTrue(unicode_decode_failed or type(string_field) is bytes)
2109
2110  def testSetInParent(self):
2111    proto = unittest_pb2.TestAllTypes()
2112    self.assertFalse(proto.HasField('optionalgroup'))
2113    proto.optionalgroup.SetInParent()
2114    self.assertTrue(proto.HasField('optionalgroup'))
2115
2116  def testPackageInitializationImport(self):
2117    """Test that we can import nested messages from their __init__.py.
2118
2119    Such setup is not trivial since at the time of processing of __init__.py one
2120    can't refer to its submodules by name in code, so expressions like
2121    google.protobuf.internal.import_test_package.inner_pb2
2122    don't work. They do work in imports, so we have assign an alias at import
2123    and then use that alias in generated code.
2124    """
2125    # We import here since it's the import that used to fail, and we want
2126    # the failure to have the right context.
2127    # pylint: disable=g-import-not-at-top
2128    from google.protobuf.internal import import_test_package
2129    # pylint: enable=g-import-not-at-top
2130    msg = import_test_package.myproto.Outer()
2131    # Just check the default value.
2132    self.assertEqual(57, msg.inner.value)
2133
2134#  Since we had so many tests for protocol buffer equality, we broke these out
2135#  into separate TestCase classes.
2136
2137
2138@testing_refleaks.TestCase
2139class TestAllTypesEqualityTest(unittest.TestCase):
2140
2141  def setUp(self):
2142    self.first_proto = unittest_pb2.TestAllTypes()
2143    self.second_proto = unittest_pb2.TestAllTypes()
2144
2145  def testNotHashable(self):
2146    self.assertRaises(TypeError, hash, self.first_proto)
2147
2148  def testSelfEquality(self):
2149    self.assertEqual(self.first_proto, self.first_proto)
2150
2151  def testEmptyProtosEqual(self):
2152    self.assertEqual(self.first_proto, self.second_proto)
2153
2154
2155@testing_refleaks.TestCase
2156class FullProtosEqualityTest(unittest.TestCase):
2157
2158  """Equality tests using completely-full protos as a starting point."""
2159
2160  def setUp(self):
2161    self.first_proto = unittest_pb2.TestAllTypes()
2162    self.second_proto = unittest_pb2.TestAllTypes()
2163    test_util.SetAllFields(self.first_proto)
2164    test_util.SetAllFields(self.second_proto)
2165
2166  def testNotHashable(self):
2167    self.assertRaises(TypeError, hash, self.first_proto)
2168
2169  def testNoneNotEqual(self):
2170    self.assertNotEqual(self.first_proto, None)
2171    self.assertNotEqual(None, self.second_proto)
2172
2173  def testNotEqualToOtherMessage(self):
2174    third_proto = unittest_pb2.TestRequired()
2175    self.assertNotEqual(self.first_proto, third_proto)
2176    self.assertNotEqual(third_proto, self.second_proto)
2177
2178  def testAllFieldsFilledEquality(self):
2179    self.assertEqual(self.first_proto, self.second_proto)
2180
2181  def testNonRepeatedScalar(self):
2182    # Nonrepeated scalar field change should cause inequality.
2183    self.first_proto.optional_int32 += 1
2184    self.assertNotEqual(self.first_proto, self.second_proto)
2185    # ...as should clearing a field.
2186    self.first_proto.ClearField('optional_int32')
2187    self.assertNotEqual(self.first_proto, self.second_proto)
2188
2189  def testNonRepeatedComposite(self):
2190    # Change a nonrepeated composite field.
2191    self.first_proto.optional_nested_message.bb += 1
2192    self.assertNotEqual(self.first_proto, self.second_proto)
2193    self.first_proto.optional_nested_message.bb -= 1
2194    self.assertEqual(self.first_proto, self.second_proto)
2195    # Clear a field in the nested message.
2196    self.first_proto.optional_nested_message.ClearField('bb')
2197    self.assertNotEqual(self.first_proto, self.second_proto)
2198    self.first_proto.optional_nested_message.bb = (
2199        self.second_proto.optional_nested_message.bb)
2200    self.assertEqual(self.first_proto, self.second_proto)
2201    # Remove the nested message entirely.
2202    self.first_proto.ClearField('optional_nested_message')
2203    self.assertNotEqual(self.first_proto, self.second_proto)
2204
2205  def testRepeatedScalar(self):
2206    # Change a repeated scalar field.
2207    self.first_proto.repeated_int32.append(5)
2208    self.assertNotEqual(self.first_proto, self.second_proto)
2209    self.first_proto.ClearField('repeated_int32')
2210    self.assertNotEqual(self.first_proto, self.second_proto)
2211
2212  def testRepeatedComposite(self):
2213    # Change value within a repeated composite field.
2214    self.first_proto.repeated_nested_message[0].bb += 1
2215    self.assertNotEqual(self.first_proto, self.second_proto)
2216    self.first_proto.repeated_nested_message[0].bb -= 1
2217    self.assertEqual(self.first_proto, self.second_proto)
2218    # Add a value to a repeated composite field.
2219    self.first_proto.repeated_nested_message.add()
2220    self.assertNotEqual(self.first_proto, self.second_proto)
2221    self.second_proto.repeated_nested_message.add()
2222    self.assertEqual(self.first_proto, self.second_proto)
2223
2224  def testNonRepeatedScalarHasBits(self):
2225    # Ensure that we test "has" bits as well as value for
2226    # nonrepeated scalar field.
2227    self.first_proto.ClearField('optional_int32')
2228    self.second_proto.optional_int32 = 0
2229    self.assertNotEqual(self.first_proto, self.second_proto)
2230
2231  def testNonRepeatedCompositeHasBits(self):
2232    # Ensure that we test "has" bits as well as value for
2233    # nonrepeated composite field.
2234    self.first_proto.ClearField('optional_nested_message')
2235    self.second_proto.optional_nested_message.ClearField('bb')
2236    self.assertNotEqual(self.first_proto, self.second_proto)
2237    self.first_proto.optional_nested_message.bb = 0
2238    self.first_proto.optional_nested_message.ClearField('bb')
2239    self.assertEqual(self.first_proto, self.second_proto)
2240
2241
2242@testing_refleaks.TestCase
2243class ExtensionEqualityTest(unittest.TestCase):
2244
2245  def testExtensionEquality(self):
2246    first_proto = unittest_pb2.TestAllExtensions()
2247    second_proto = unittest_pb2.TestAllExtensions()
2248    self.assertEqual(first_proto, second_proto)
2249    test_util.SetAllExtensions(first_proto)
2250    self.assertNotEqual(first_proto, second_proto)
2251    test_util.SetAllExtensions(second_proto)
2252    self.assertEqual(first_proto, second_proto)
2253
2254    # Ensure that we check value equality.
2255    first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
2256    self.assertNotEqual(first_proto, second_proto)
2257    first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
2258    self.assertEqual(first_proto, second_proto)
2259
2260    # Ensure that we also look at "has" bits.
2261    first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
2262    second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
2263    self.assertNotEqual(first_proto, second_proto)
2264    first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
2265    self.assertEqual(first_proto, second_proto)
2266
2267    # Ensure that differences in cached values
2268    # don't matter if "has" bits are both false.
2269    first_proto = unittest_pb2.TestAllExtensions()
2270    second_proto = unittest_pb2.TestAllExtensions()
2271    self.assertEqual(
2272        0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
2273    self.assertEqual(first_proto, second_proto)
2274
2275
2276@testing_refleaks.TestCase
2277class MutualRecursionEqualityTest(unittest.TestCase):
2278
2279  def testEqualityWithMutualRecursion(self):
2280    first_proto = unittest_pb2.TestMutualRecursionA()
2281    second_proto = unittest_pb2.TestMutualRecursionA()
2282    self.assertEqual(first_proto, second_proto)
2283    first_proto.bb.a.bb.optional_int32 = 23
2284    self.assertNotEqual(first_proto, second_proto)
2285    second_proto.bb.a.bb.optional_int32 = 23
2286    self.assertEqual(first_proto, second_proto)
2287
2288
2289@testing_refleaks.TestCase
2290class ByteSizeTest(unittest.TestCase):
2291
2292  def setUp(self):
2293    self.proto = unittest_pb2.TestAllTypes()
2294    self.extended_proto = more_extensions_pb2.ExtendedMessage()
2295    self.packed_proto = unittest_pb2.TestPackedTypes()
2296    self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
2297
2298  def Size(self):
2299    return self.proto.ByteSize()
2300
2301  def testEmptyMessage(self):
2302    self.assertEqual(0, self.proto.ByteSize())
2303
2304  def testSizedOnKwargs(self):
2305    # Use a separate message to ensure testing right after creation.
2306    proto = unittest_pb2.TestAllTypes()
2307    self.assertEqual(0, proto.ByteSize())
2308    proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1)
2309    # One byte for the tag, one to encode varint 1.
2310    self.assertEqual(2, proto_kwargs.ByteSize())
2311
2312  def testVarints(self):
2313    def Test(i, expected_varint_size):
2314      self.proto.Clear()
2315      self.proto.optional_int64 = i
2316      # Add one to the varint size for the tag info
2317      # for tag 1.
2318      self.assertEqual(expected_varint_size + 1, self.Size())
2319    Test(0, 1)
2320    Test(1, 1)
2321    for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
2322      Test((1 << i) - 1, num_bytes)
2323    Test(-1, 10)
2324    Test(-2, 10)
2325    Test(-(1 << 63), 10)
2326
2327  def testStrings(self):
2328    self.proto.optional_string = ''
2329    # Need one byte for tag info (tag #14), and one byte for length.
2330    self.assertEqual(2, self.Size())
2331
2332    self.proto.optional_string = 'abc'
2333    # Need one byte for tag info (tag #14), and one byte for length.
2334    self.assertEqual(2 + len(self.proto.optional_string), self.Size())
2335
2336    self.proto.optional_string = 'x' * 128
2337    # Need one byte for tag info (tag #14), and TWO bytes for length.
2338    self.assertEqual(3 + len(self.proto.optional_string), self.Size())
2339
2340  def testOtherNumerics(self):
2341    self.proto.optional_fixed32 = 1234
2342    # One byte for tag and 4 bytes for fixed32.
2343    self.assertEqual(5, self.Size())
2344    self.proto = unittest_pb2.TestAllTypes()
2345
2346    self.proto.optional_fixed64 = 1234
2347    # One byte for tag and 8 bytes for fixed64.
2348    self.assertEqual(9, self.Size())
2349    self.proto = unittest_pb2.TestAllTypes()
2350
2351    self.proto.optional_float = 1.234
2352    # One byte for tag and 4 bytes for float.
2353    self.assertEqual(5, self.Size())
2354    self.proto = unittest_pb2.TestAllTypes()
2355
2356    self.proto.optional_double = 1.234
2357    # One byte for tag and 8 bytes for float.
2358    self.assertEqual(9, self.Size())
2359    self.proto = unittest_pb2.TestAllTypes()
2360
2361    self.proto.optional_sint32 = 64
2362    # One byte for tag and 2 bytes for zig-zag-encoded 64.
2363    self.assertEqual(3, self.Size())
2364    self.proto = unittest_pb2.TestAllTypes()
2365
2366  def testComposites(self):
2367    # 3 bytes.
2368    self.proto.optional_nested_message.bb = (1 << 14)
2369    # Plus one byte for bb tag.
2370    # Plus 1 byte for optional_nested_message serialized size.
2371    # Plus two bytes for optional_nested_message tag.
2372    self.assertEqual(3 + 1 + 1 + 2, self.Size())
2373
2374  def testGroups(self):
2375    # 4 bytes.
2376    self.proto.optionalgroup.a = (1 << 21)
2377    # Plus two bytes for |a| tag.
2378    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
2379    self.assertEqual(4 + 2 + 2*2, self.Size())
2380
2381  def testRepeatedScalars(self):
2382    self.proto.repeated_int32.append(10)  # 1 byte.
2383    self.proto.repeated_int32.append(128)  # 2 bytes.
2384    # Also need 2 bytes for each entry for tag.
2385    self.assertEqual(1 + 2 + 2*2, self.Size())
2386
2387  def testRepeatedScalarsExtend(self):
2388    self.proto.repeated_int32.extend([10, 128])  # 3 bytes.
2389    # Also need 2 bytes for each entry for tag.
2390    self.assertEqual(1 + 2 + 2*2, self.Size())
2391
2392  def testRepeatedScalarsRemove(self):
2393    self.proto.repeated_int32.append(10)  # 1 byte.
2394    self.proto.repeated_int32.append(128)  # 2 bytes.
2395    # Also need 2 bytes for each entry for tag.
2396    self.assertEqual(1 + 2 + 2*2, self.Size())
2397    self.proto.repeated_int32.remove(128)
2398    self.assertEqual(1 + 2, self.Size())
2399
2400  def testRepeatedComposites(self):
2401    # Empty message.  2 bytes tag plus 1 byte length.
2402    foreign_message_0 = self.proto.repeated_nested_message.add()
2403    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2404    foreign_message_1 = self.proto.repeated_nested_message.add()
2405    foreign_message_1.bb = 7
2406    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2407
2408  def testRepeatedCompositesDelete(self):
2409    # Empty message.  2 bytes tag plus 1 byte length.
2410    foreign_message_0 = self.proto.repeated_nested_message.add()
2411    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2412    foreign_message_1 = self.proto.repeated_nested_message.add()
2413    foreign_message_1.bb = 9
2414    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
2415    repeated_nested_message = copy.deepcopy(
2416        self.proto.repeated_nested_message)
2417
2418    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2419    del self.proto.repeated_nested_message[0]
2420    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2421
2422    # Now add a new message.
2423    foreign_message_2 = self.proto.repeated_nested_message.add()
2424    foreign_message_2.bb = 12
2425
2426    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2427    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2428    self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
2429
2430    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
2431    del self.proto.repeated_nested_message[1]
2432    self.assertEqual(2 + 1 + 1 + 1, self.Size())
2433
2434    del self.proto.repeated_nested_message[0]
2435    self.assertEqual(0, self.Size())
2436
2437    self.assertEqual(2, len(repeated_nested_message))
2438    del repeated_nested_message[0:1]
2439    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
2440    if api_implementation.Type() == 'python':
2441      self.assertEqual(1, len(repeated_nested_message))
2442    del repeated_nested_message[-1]
2443    # TODO(jieluo): Fix cpp extension bug when delete repeated message.
2444    if api_implementation.Type() == 'python':
2445      self.assertEqual(0, len(repeated_nested_message))
2446
2447  def testRepeatedGroups(self):
2448    # 2-byte START_GROUP plus 2-byte END_GROUP.
2449    group_0 = self.proto.repeatedgroup.add()
2450    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
2451    # plus 2-byte END_GROUP.
2452    group_1 = self.proto.repeatedgroup.add()
2453    group_1.a =  7
2454    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
2455
2456  def testExtensions(self):
2457    proto = unittest_pb2.TestAllExtensions()
2458    self.assertEqual(0, proto.ByteSize())
2459    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
2460    proto.Extensions[extension] = 23
2461    # 1 byte for tag, 1 byte for value.
2462    self.assertEqual(2, proto.ByteSize())
2463    field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
2464        'optional_int32']
2465    with self.assertRaises(KeyError):
2466      proto.Extensions[field] = 23
2467
2468  def testCacheInvalidationForNonrepeatedScalar(self):
2469    # Test non-extension.
2470    self.proto.optional_int32 = 1
2471    self.assertEqual(2, self.proto.ByteSize())
2472    self.proto.optional_int32 = 128
2473    self.assertEqual(3, self.proto.ByteSize())
2474    self.proto.ClearField('optional_int32')
2475    self.assertEqual(0, self.proto.ByteSize())
2476
2477    # Test within extension.
2478    extension = more_extensions_pb2.optional_int_extension
2479    self.extended_proto.Extensions[extension] = 1
2480    self.assertEqual(2, self.extended_proto.ByteSize())
2481    self.extended_proto.Extensions[extension] = 128
2482    self.assertEqual(3, self.extended_proto.ByteSize())
2483    self.extended_proto.ClearExtension(extension)
2484    self.assertEqual(0, self.extended_proto.ByteSize())
2485
2486  def testCacheInvalidationForRepeatedScalar(self):
2487    # Test non-extension.
2488    self.proto.repeated_int32.append(1)
2489    self.assertEqual(3, self.proto.ByteSize())
2490    self.proto.repeated_int32.append(1)
2491    self.assertEqual(6, self.proto.ByteSize())
2492    self.proto.repeated_int32[1] = 128
2493    self.assertEqual(7, self.proto.ByteSize())
2494    self.proto.ClearField('repeated_int32')
2495    self.assertEqual(0, self.proto.ByteSize())
2496
2497    # Test within extension.
2498    extension = more_extensions_pb2.repeated_int_extension
2499    repeated = self.extended_proto.Extensions[extension]
2500    repeated.append(1)
2501    self.assertEqual(2, self.extended_proto.ByteSize())
2502    repeated.append(1)
2503    self.assertEqual(4, self.extended_proto.ByteSize())
2504    repeated[1] = 128
2505    self.assertEqual(5, self.extended_proto.ByteSize())
2506    self.extended_proto.ClearExtension(extension)
2507    self.assertEqual(0, self.extended_proto.ByteSize())
2508
2509  def testCacheInvalidationForNonrepeatedMessage(self):
2510    # Test non-extension.
2511    self.proto.optional_foreign_message.c = 1
2512    self.assertEqual(5, self.proto.ByteSize())
2513    self.proto.optional_foreign_message.c = 128
2514    self.assertEqual(6, self.proto.ByteSize())
2515    self.proto.optional_foreign_message.ClearField('c')
2516    self.assertEqual(3, self.proto.ByteSize())
2517    self.proto.ClearField('optional_foreign_message')
2518    self.assertEqual(0, self.proto.ByteSize())
2519
2520    if api_implementation.Type() == 'python':
2521      # This is only possible in pure-Python implementation of the API.
2522      child = self.proto.optional_foreign_message
2523      self.proto.ClearField('optional_foreign_message')
2524      child.c = 128
2525      self.assertEqual(0, self.proto.ByteSize())
2526
2527    # Test within extension.
2528    extension = more_extensions_pb2.optional_message_extension
2529    child = self.extended_proto.Extensions[extension]
2530    self.assertEqual(0, self.extended_proto.ByteSize())
2531    child.foreign_message_int = 1
2532    self.assertEqual(4, self.extended_proto.ByteSize())
2533    child.foreign_message_int = 128
2534    self.assertEqual(5, self.extended_proto.ByteSize())
2535    self.extended_proto.ClearExtension(extension)
2536    self.assertEqual(0, self.extended_proto.ByteSize())
2537
2538  def testCacheInvalidationForRepeatedMessage(self):
2539    # Test non-extension.
2540    child0 = self.proto.repeated_foreign_message.add()
2541    self.assertEqual(3, self.proto.ByteSize())
2542    self.proto.repeated_foreign_message.add()
2543    self.assertEqual(6, self.proto.ByteSize())
2544    child0.c = 1
2545    self.assertEqual(8, self.proto.ByteSize())
2546    self.proto.ClearField('repeated_foreign_message')
2547    self.assertEqual(0, self.proto.ByteSize())
2548
2549    # Test within extension.
2550    extension = more_extensions_pb2.repeated_message_extension
2551    child_list = self.extended_proto.Extensions[extension]
2552    child0 = child_list.add()
2553    self.assertEqual(2, self.extended_proto.ByteSize())
2554    child_list.add()
2555    self.assertEqual(4, self.extended_proto.ByteSize())
2556    child0.foreign_message_int = 1
2557    self.assertEqual(6, self.extended_proto.ByteSize())
2558    child0.ClearField('foreign_message_int')
2559    self.assertEqual(4, self.extended_proto.ByteSize())
2560    self.extended_proto.ClearExtension(extension)
2561    self.assertEqual(0, self.extended_proto.ByteSize())
2562
2563  def testPackedRepeatedScalars(self):
2564    self.assertEqual(0, self.packed_proto.ByteSize())
2565
2566    self.packed_proto.packed_int32.append(10)   # 1 byte.
2567    self.packed_proto.packed_int32.append(128)  # 2 bytes.
2568    # The tag is 2 bytes (the field number is 90), and the varint
2569    # storing the length is 1 byte.
2570    int_size = 1 + 2 + 3
2571    self.assertEqual(int_size, self.packed_proto.ByteSize())
2572
2573    self.packed_proto.packed_double.append(4.2)   # 8 bytes
2574    self.packed_proto.packed_double.append(3.25)  # 8 bytes
2575    # 2 more tag bytes, 1 more length byte.
2576    double_size = 8 + 8 + 3
2577    self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
2578
2579    self.packed_proto.ClearField('packed_int32')
2580    self.assertEqual(double_size, self.packed_proto.ByteSize())
2581
2582  def testPackedExtensions(self):
2583    self.assertEqual(0, self.packed_extended_proto.ByteSize())
2584    extension = self.packed_extended_proto.Extensions[
2585        unittest_pb2.packed_fixed32_extension]
2586    extension.extend([1, 2, 3, 4])   # 16 bytes
2587    # Tag is 3 bytes.
2588    self.assertEqual(19, self.packed_extended_proto.ByteSize())
2589
2590
2591# Issues to be sure to cover include:
2592#   * Handling of unrecognized tags ("uninterpreted_bytes").
2593#   * Handling of MessageSets.
2594#   * Consistent ordering of tags in the wire format,
2595#     including ordering between extensions and non-extension
2596#     fields.
2597#   * Consistent serialization of negative numbers, especially
2598#     negative int32s.
2599#   * Handling of empty submessages (with and without "has"
2600#     bits set).
2601
2602@testing_refleaks.TestCase
2603class SerializationTest(unittest.TestCase):
2604
2605  def testSerializeEmtpyMessage(self):
2606    first_proto = unittest_pb2.TestAllTypes()
2607    second_proto = unittest_pb2.TestAllTypes()
2608    serialized = first_proto.SerializeToString()
2609    self.assertEqual(first_proto.ByteSize(), len(serialized))
2610    self.assertEqual(
2611        len(serialized),
2612        second_proto.MergeFromString(serialized))
2613    self.assertEqual(first_proto, second_proto)
2614
2615  def testSerializeAllFields(self):
2616    first_proto = unittest_pb2.TestAllTypes()
2617    second_proto = unittest_pb2.TestAllTypes()
2618    test_util.SetAllFields(first_proto)
2619    serialized = first_proto.SerializeToString()
2620    self.assertEqual(first_proto.ByteSize(), len(serialized))
2621    self.assertEqual(
2622        len(serialized),
2623        second_proto.MergeFromString(serialized))
2624    self.assertEqual(first_proto, second_proto)
2625
2626  def testSerializeAllExtensions(self):
2627    first_proto = unittest_pb2.TestAllExtensions()
2628    second_proto = unittest_pb2.TestAllExtensions()
2629    test_util.SetAllExtensions(first_proto)
2630    serialized = first_proto.SerializeToString()
2631    self.assertEqual(
2632        len(serialized),
2633        second_proto.MergeFromString(serialized))
2634    self.assertEqual(first_proto, second_proto)
2635
2636  def testSerializeWithOptionalGroup(self):
2637    first_proto = unittest_pb2.TestAllTypes()
2638    second_proto = unittest_pb2.TestAllTypes()
2639    first_proto.optionalgroup.a = 242
2640    serialized = first_proto.SerializeToString()
2641    self.assertEqual(
2642        len(serialized),
2643        second_proto.MergeFromString(serialized))
2644    self.assertEqual(first_proto, second_proto)
2645
2646  def testSerializeNegativeValues(self):
2647    first_proto = unittest_pb2.TestAllTypes()
2648
2649    first_proto.optional_int32 = -1
2650    first_proto.optional_int64 = -(2 << 40)
2651    first_proto.optional_sint32 = -3
2652    first_proto.optional_sint64 = -(4 << 40)
2653    first_proto.optional_sfixed32 = -5
2654    first_proto.optional_sfixed64 = -(6 << 40)
2655
2656    second_proto = unittest_pb2.TestAllTypes.FromString(
2657        first_proto.SerializeToString())
2658
2659    self.assertEqual(first_proto, second_proto)
2660
2661  def testParseTruncated(self):
2662    # This test is only applicable for the Python implementation of the API.
2663    if api_implementation.Type() != 'python':
2664      return
2665
2666    first_proto = unittest_pb2.TestAllTypes()
2667    test_util.SetAllFields(first_proto)
2668    serialized = memoryview(first_proto.SerializeToString())
2669
2670    for truncation_point in range(len(serialized) + 1):
2671      try:
2672        second_proto = unittest_pb2.TestAllTypes()
2673        unknown_fields = unittest_pb2.TestEmptyMessage()
2674        pos = second_proto._InternalParse(serialized, 0, truncation_point)
2675        # If we didn't raise an error then we read exactly the amount expected.
2676        self.assertEqual(truncation_point, pos)
2677
2678        # Parsing to unknown fields should not throw if parsing to known fields
2679        # did not.
2680        try:
2681          pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
2682          self.assertEqual(truncation_point, pos2)
2683        except message.DecodeError:
2684          self.fail('Parsing unknown fields failed when parsing known fields '
2685                    'did not.')
2686      except message.DecodeError:
2687        # Parsing unknown fields should also fail.
2688        self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
2689                          serialized, 0, truncation_point)
2690
2691  def testCanonicalSerializationOrder(self):
2692    proto = more_messages_pb2.OutOfOrderFields()
2693    # These are also their tag numbers.  Even though we're setting these in
2694    # reverse-tag order AND they're listed in reverse tag-order in the .proto
2695    # file, they should nonetheless be serialized in tag order.
2696    proto.optional_sint32 = 5
2697    proto.Extensions[more_messages_pb2.optional_uint64] = 4
2698    proto.optional_uint32 = 3
2699    proto.Extensions[more_messages_pb2.optional_int64] = 2
2700    proto.optional_int32 = 1
2701    serialized = proto.SerializeToString()
2702    self.assertEqual(proto.ByteSize(), len(serialized))
2703    d = _MiniDecoder(serialized)
2704    ReadTag = d.ReadFieldNumberAndWireType
2705    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
2706    self.assertEqual(1, d.ReadInt32())
2707    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
2708    self.assertEqual(2, d.ReadInt64())
2709    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
2710    self.assertEqual(3, d.ReadUInt32())
2711    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
2712    self.assertEqual(4, d.ReadUInt64())
2713    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
2714    self.assertEqual(5, d.ReadSInt32())
2715
2716  def testCanonicalSerializationOrderSameAsCpp(self):
2717    # Copy of the same test we use for C++.
2718    proto = unittest_pb2.TestFieldOrderings()
2719    test_util.SetAllFieldsAndExtensions(proto)
2720    serialized = proto.SerializeToString()
2721    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
2722
2723  def testMergeFromStringWhenFieldsAlreadySet(self):
2724    first_proto = unittest_pb2.TestAllTypes()
2725    first_proto.repeated_string.append('foobar')
2726    first_proto.optional_int32 = 23
2727    first_proto.optional_nested_message.bb = 42
2728    serialized = first_proto.SerializeToString()
2729
2730    second_proto = unittest_pb2.TestAllTypes()
2731    second_proto.repeated_string.append('baz')
2732    second_proto.optional_int32 = 100
2733    second_proto.optional_nested_message.bb = 999
2734
2735    bytes_parsed = second_proto.MergeFromString(serialized)
2736    self.assertEqual(len(serialized), bytes_parsed)
2737
2738    # Ensure that we append to repeated fields.
2739    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
2740    # Ensure that we overwrite nonrepeatd scalars.
2741    self.assertEqual(23, second_proto.optional_int32)
2742    # Ensure that we recursively call MergeFromString() on
2743    # submessages.
2744    self.assertEqual(42, second_proto.optional_nested_message.bb)
2745
2746  def testMessageSetWireFormat(self):
2747    proto = message_set_extensions_pb2.TestMessageSet()
2748    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2749    extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2
2750    extension1 = extension_message1.message_set_extension
2751    extension2 = extension_message2.message_set_extension
2752    extension3 = message_set_extensions_pb2.message_set_extension3
2753    proto.Extensions[extension1].i = 123
2754    proto.Extensions[extension2].str = 'foo'
2755    proto.Extensions[extension3].text = 'bar'
2756
2757    # Serialize using the MessageSet wire format (this is specified in the
2758    # .proto file).
2759    serialized = proto.SerializeToString()
2760
2761    raw = unittest_mset_pb2.RawMessageSet()
2762    self.assertEqual(False,
2763                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
2764    self.assertEqual(
2765        len(serialized),
2766        raw.MergeFromString(serialized))
2767    self.assertEqual(3, len(raw.item))
2768
2769    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2770    self.assertEqual(
2771        len(raw.item[0].message),
2772        message1.MergeFromString(raw.item[0].message))
2773    self.assertEqual(123, message1.i)
2774
2775    message2 = message_set_extensions_pb2.TestMessageSetExtension2()
2776    self.assertEqual(
2777        len(raw.item[1].message),
2778        message2.MergeFromString(raw.item[1].message))
2779    self.assertEqual('foo', message2.str)
2780
2781    message3 = message_set_extensions_pb2.TestMessageSetExtension3()
2782    self.assertEqual(
2783        len(raw.item[2].message),
2784        message3.MergeFromString(raw.item[2].message))
2785    self.assertEqual('bar', message3.text)
2786
2787    # Deserialize using the MessageSet wire format.
2788    proto2 = message_set_extensions_pb2.TestMessageSet()
2789    self.assertEqual(
2790        len(serialized),
2791        proto2.MergeFromString(serialized))
2792    self.assertEqual(123, proto2.Extensions[extension1].i)
2793    self.assertEqual('foo', proto2.Extensions[extension2].str)
2794    self.assertEqual('bar', proto2.Extensions[extension3].text)
2795
2796    # Check byte size.
2797    self.assertEqual(proto2.ByteSize(), len(serialized))
2798    self.assertEqual(proto.ByteSize(), len(serialized))
2799
2800  def testMessageSetWireFormatUnknownExtension(self):
2801    # Create a message using the message set wire format with an unknown
2802    # message.
2803    raw = unittest_mset_pb2.RawMessageSet()
2804
2805    # Add an item.
2806    item = raw.item.add()
2807    item.type_id = 98418603
2808    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2809    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2810    message1.i = 12345
2811    item.message = message1.SerializeToString()
2812
2813    # Add a second, unknown extension.
2814    item = raw.item.add()
2815    item.type_id = 98418604
2816    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2817    message1 = message_set_extensions_pb2.TestMessageSetExtension1()
2818    message1.i = 12346
2819    item.message = message1.SerializeToString()
2820
2821    # Add another unknown extension.
2822    item = raw.item.add()
2823    item.type_id = 98418605
2824    message1 = message_set_extensions_pb2.TestMessageSetExtension2()
2825    message1.str = 'foo'
2826    item.message = message1.SerializeToString()
2827
2828    serialized = raw.SerializeToString()
2829
2830    # Parse message using the message set wire format.
2831    proto = message_set_extensions_pb2.TestMessageSet()
2832    self.assertEqual(
2833        len(serialized),
2834        proto.MergeFromString(serialized))
2835
2836    # Check that the message parsed well.
2837    extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1
2838    extension1 = extension_message1.message_set_extension
2839    self.assertEqual(12345, proto.Extensions[extension1].i)
2840
2841  def testUnknownFields(self):
2842    proto = unittest_pb2.TestAllTypes()
2843    test_util.SetAllFields(proto)
2844
2845    serialized = proto.SerializeToString()
2846
2847    # The empty message should be parsable with all of the fields
2848    # unknown.
2849    proto2 = unittest_pb2.TestEmptyMessage()
2850
2851    # Parsing this message should succeed.
2852    self.assertEqual(
2853        len(serialized),
2854        proto2.MergeFromString(serialized))
2855
2856    # Now test with a int64 field set.
2857    proto = unittest_pb2.TestAllTypes()
2858    proto.optional_int64 = 0x0fffffffffffffff
2859    serialized = proto.SerializeToString()
2860    # The empty message should be parsable with all of the fields
2861    # unknown.
2862    proto2 = unittest_pb2.TestEmptyMessage()
2863    # Parsing this message should succeed.
2864    self.assertEqual(
2865        len(serialized),
2866        proto2.MergeFromString(serialized))
2867
2868  def _CheckRaises(self, exc_class, callable_obj, exception):
2869    """This method checks if the exception type and message are as expected."""
2870    try:
2871      callable_obj()
2872    except exc_class as ex:
2873      # Check if the exception message is the right one.
2874      self.assertEqual(exception, str(ex))
2875      return
2876    else:
2877      raise self.failureException('%s not raised' % str(exc_class))
2878
2879  def testSerializeUninitialized(self):
2880    proto = unittest_pb2.TestRequired()
2881    self._CheckRaises(
2882        message.EncodeError,
2883        proto.SerializeToString,
2884        'Message protobuf_unittest.TestRequired is missing required fields: '
2885        'a,b,c')
2886    # Shouldn't raise exceptions.
2887    partial = proto.SerializePartialToString()
2888
2889    proto2 = unittest_pb2.TestRequired()
2890    self.assertFalse(proto2.HasField('a'))
2891    # proto2 ParseFromString does not check that required fields are set.
2892    proto2.ParseFromString(partial)
2893    self.assertFalse(proto2.HasField('a'))
2894
2895    proto.a = 1
2896    self._CheckRaises(
2897        message.EncodeError,
2898        proto.SerializeToString,
2899        'Message protobuf_unittest.TestRequired is missing required fields: b,c')
2900    # Shouldn't raise exceptions.
2901    partial = proto.SerializePartialToString()
2902
2903    proto.b = 2
2904    self._CheckRaises(
2905        message.EncodeError,
2906        proto.SerializeToString,
2907        'Message protobuf_unittest.TestRequired is missing required fields: c')
2908    # Shouldn't raise exceptions.
2909    partial = proto.SerializePartialToString()
2910
2911    proto.c = 3
2912    serialized = proto.SerializeToString()
2913    # Shouldn't raise exceptions.
2914    partial = proto.SerializePartialToString()
2915
2916    proto2 = unittest_pb2.TestRequired()
2917    self.assertEqual(
2918        len(serialized),
2919        proto2.MergeFromString(serialized))
2920    self.assertEqual(1, proto2.a)
2921    self.assertEqual(2, proto2.b)
2922    self.assertEqual(3, proto2.c)
2923    self.assertEqual(
2924        len(partial),
2925        proto2.MergeFromString(partial))
2926    self.assertEqual(1, proto2.a)
2927    self.assertEqual(2, proto2.b)
2928    self.assertEqual(3, proto2.c)
2929
2930  def testSerializeUninitializedSubMessage(self):
2931    proto = unittest_pb2.TestRequiredForeign()
2932
2933    # Sub-message doesn't exist yet, so this succeeds.
2934    proto.SerializeToString()
2935
2936    proto.optional_message.a = 1
2937    self._CheckRaises(
2938        message.EncodeError,
2939        proto.SerializeToString,
2940        'Message protobuf_unittest.TestRequiredForeign '
2941        'is missing required fields: '
2942        'optional_message.b,optional_message.c')
2943
2944    proto.optional_message.b = 2
2945    proto.optional_message.c = 3
2946    proto.SerializeToString()
2947
2948    proto.repeated_message.add().a = 1
2949    proto.repeated_message.add().b = 2
2950    self._CheckRaises(
2951        message.EncodeError,
2952        proto.SerializeToString,
2953        'Message protobuf_unittest.TestRequiredForeign is missing required fields: '
2954        'repeated_message[0].b,repeated_message[0].c,'
2955        'repeated_message[1].a,repeated_message[1].c')
2956
2957    proto.repeated_message[0].b = 2
2958    proto.repeated_message[0].c = 3
2959    proto.repeated_message[1].a = 1
2960    proto.repeated_message[1].c = 3
2961    proto.SerializeToString()
2962
2963  def testSerializeAllPackedFields(self):
2964    first_proto = unittest_pb2.TestPackedTypes()
2965    second_proto = unittest_pb2.TestPackedTypes()
2966    test_util.SetAllPackedFields(first_proto)
2967    serialized = first_proto.SerializeToString()
2968    self.assertEqual(first_proto.ByteSize(), len(serialized))
2969    bytes_read = second_proto.MergeFromString(serialized)
2970    self.assertEqual(second_proto.ByteSize(), bytes_read)
2971    self.assertEqual(first_proto, second_proto)
2972
2973  def testSerializeAllPackedExtensions(self):
2974    first_proto = unittest_pb2.TestPackedExtensions()
2975    second_proto = unittest_pb2.TestPackedExtensions()
2976    test_util.SetAllPackedExtensions(first_proto)
2977    serialized = first_proto.SerializeToString()
2978    bytes_read = second_proto.MergeFromString(serialized)
2979    self.assertEqual(second_proto.ByteSize(), bytes_read)
2980    self.assertEqual(first_proto, second_proto)
2981
2982  def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
2983    first_proto = unittest_pb2.TestPackedTypes()
2984    first_proto.packed_int32.extend([1, 2])
2985    first_proto.packed_double.append(3.0)
2986    serialized = first_proto.SerializeToString()
2987
2988    second_proto = unittest_pb2.TestPackedTypes()
2989    second_proto.packed_int32.append(3)
2990    second_proto.packed_double.extend([1.0, 2.0])
2991    second_proto.packed_sint32.append(4)
2992
2993    self.assertEqual(
2994        len(serialized),
2995        second_proto.MergeFromString(serialized))
2996    self.assertEqual([3, 1, 2], second_proto.packed_int32)
2997    self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
2998    self.assertEqual([4], second_proto.packed_sint32)
2999
3000  def testPackedFieldsWireFormat(self):
3001    proto = unittest_pb2.TestPackedTypes()
3002    proto.packed_int32.extend([1, 2, 150, 3])  # 1 + 1 + 2 + 1 bytes
3003    proto.packed_double.extend([1.0, 1000.0])  # 8 + 8 bytes
3004    proto.packed_float.append(2.0)             # 4 bytes, will be before double
3005    serialized = proto.SerializeToString()
3006    self.assertEqual(proto.ByteSize(), len(serialized))
3007    d = _MiniDecoder(serialized)
3008    ReadTag = d.ReadFieldNumberAndWireType
3009    self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
3010    self.assertEqual(1+1+1+2, d.ReadInt32())
3011    self.assertEqual(1, d.ReadInt32())
3012    self.assertEqual(2, d.ReadInt32())
3013    self.assertEqual(150, d.ReadInt32())
3014    self.assertEqual(3, d.ReadInt32())
3015    self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
3016    self.assertEqual(4, d.ReadInt32())
3017    self.assertEqual(2.0, d.ReadFloat())
3018    self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
3019    self.assertEqual(8+8, d.ReadInt32())
3020    self.assertEqual(1.0, d.ReadDouble())
3021    self.assertEqual(1000.0, d.ReadDouble())
3022    self.assertTrue(d.EndOfStream())
3023
3024  def testParsePackedFromUnpacked(self):
3025    unpacked = unittest_pb2.TestUnpackedTypes()
3026    test_util.SetAllUnpackedFields(unpacked)
3027    packed = unittest_pb2.TestPackedTypes()
3028    serialized = unpacked.SerializeToString()
3029    self.assertEqual(
3030        len(serialized),
3031        packed.MergeFromString(serialized))
3032    expected = unittest_pb2.TestPackedTypes()
3033    test_util.SetAllPackedFields(expected)
3034    self.assertEqual(expected, packed)
3035
3036  def testParseUnpackedFromPacked(self):
3037    packed = unittest_pb2.TestPackedTypes()
3038    test_util.SetAllPackedFields(packed)
3039    unpacked = unittest_pb2.TestUnpackedTypes()
3040    serialized = packed.SerializeToString()
3041    self.assertEqual(
3042        len(serialized),
3043        unpacked.MergeFromString(serialized))
3044    expected = unittest_pb2.TestUnpackedTypes()
3045    test_util.SetAllUnpackedFields(expected)
3046    self.assertEqual(expected, unpacked)
3047
3048  def testFieldNumbers(self):
3049    proto = unittest_pb2.TestAllTypes()
3050    self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
3051    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
3052    self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
3053    self.assertEqual(
3054      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
3055    self.assertEqual(
3056      unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
3057    self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
3058    self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
3059    self.assertEqual(
3060      unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
3061    self.assertEqual(
3062      unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
3063
3064  def testExtensionFieldNumbers(self):
3065    self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
3066    self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
3067    self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
3068    self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
3069    self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
3070    self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
3071    self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
3072    self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
3073    self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
3074    self.assertEqual(
3075      unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
3076    self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
3077    self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3078      21)
3079    self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
3080    self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
3081    self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
3082    self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
3083    self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
3084    self.assertEqual(
3085      unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
3086    self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
3087    self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
3088      51)
3089
3090  def testFieldProperties(self):
3091    cls = unittest_pb2.TestAllTypes
3092    self.assertIs(cls.optional_int32.DESCRIPTOR,
3093                  cls.DESCRIPTOR.fields_by_name['optional_int32'])
3094    self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
3095                     cls.optional_int32.DESCRIPTOR.number)
3096    self.assertIs(cls.optional_nested_message.DESCRIPTOR,
3097                  cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
3098    self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
3099                     cls.optional_nested_message.DESCRIPTOR.number)
3100    self.assertIs(cls.repeated_int32.DESCRIPTOR,
3101                  cls.DESCRIPTOR.fields_by_name['repeated_int32'])
3102    self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
3103                     cls.repeated_int32.DESCRIPTOR.number)
3104
3105  def testFieldDataDescriptor(self):
3106    msg = unittest_pb2.TestAllTypes()
3107    msg.optional_int32 = 42
3108    self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
3109    unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
3110    self.assertEqual(msg.optional_int32, 25)
3111    with self.assertRaises(AttributeError):
3112      del msg.optional_int32
3113    try:
3114      unittest_pb2.ForeignMessage.c.__get__(msg)
3115    except TypeError:
3116      pass  # The cpp implementation cannot mix fields from other messages.
3117      # This test exercises a specific check that avoids a crash.
3118    else:
3119      pass  # The python implementation allows fields from other messages.
3120      # This is useless, but works.
3121
3122  def testInitKwargs(self):
3123    proto = unittest_pb2.TestAllTypes(
3124        optional_int32=1,
3125        optional_string='foo',
3126        optional_bool=True,
3127        optional_bytes=b'bar',
3128        optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
3129        optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
3130        optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
3131        optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
3132        repeated_int32=[1, 2, 3])
3133    self.assertTrue(proto.IsInitialized())
3134    self.assertTrue(proto.HasField('optional_int32'))
3135    self.assertTrue(proto.HasField('optional_string'))
3136    self.assertTrue(proto.HasField('optional_bool'))
3137    self.assertTrue(proto.HasField('optional_bytes'))
3138    self.assertTrue(proto.HasField('optional_nested_message'))
3139    self.assertTrue(proto.HasField('optional_foreign_message'))
3140    self.assertTrue(proto.HasField('optional_nested_enum'))
3141    self.assertTrue(proto.HasField('optional_foreign_enum'))
3142    self.assertEqual(1, proto.optional_int32)
3143    self.assertEqual('foo', proto.optional_string)
3144    self.assertEqual(True, proto.optional_bool)
3145    self.assertEqual(b'bar', proto.optional_bytes)
3146    self.assertEqual(1, proto.optional_nested_message.bb)
3147    self.assertEqual(1, proto.optional_foreign_message.c)
3148    self.assertEqual(unittest_pb2.TestAllTypes.FOO,
3149                     proto.optional_nested_enum)
3150    self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
3151    self.assertEqual([1, 2, 3], proto.repeated_int32)
3152
3153  def testInitArgsUnknownFieldName(self):
3154    def InitalizeEmptyMessageWithExtraKeywordArg():
3155      unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
3156    self._CheckRaises(
3157        ValueError,
3158        InitalizeEmptyMessageWithExtraKeywordArg,
3159        'Protocol message TestEmptyMessage has no "unknown" field.')
3160
3161  def testInitRequiredKwargs(self):
3162    proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
3163    self.assertTrue(proto.IsInitialized())
3164    self.assertTrue(proto.HasField('a'))
3165    self.assertTrue(proto.HasField('b'))
3166    self.assertTrue(proto.HasField('c'))
3167    self.assertFalse(proto.HasField('dummy2'))
3168    self.assertEqual(1, proto.a)
3169    self.assertEqual(1, proto.b)
3170    self.assertEqual(1, proto.c)
3171
3172  def testInitRequiredForeignKwargs(self):
3173    proto = unittest_pb2.TestRequiredForeign(
3174        optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
3175    self.assertTrue(proto.IsInitialized())
3176    self.assertTrue(proto.HasField('optional_message'))
3177    self.assertTrue(proto.optional_message.IsInitialized())
3178    self.assertTrue(proto.optional_message.HasField('a'))
3179    self.assertTrue(proto.optional_message.HasField('b'))
3180    self.assertTrue(proto.optional_message.HasField('c'))
3181    self.assertFalse(proto.optional_message.HasField('dummy2'))
3182    self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
3183                     proto.optional_message)
3184    self.assertEqual(1, proto.optional_message.a)
3185    self.assertEqual(1, proto.optional_message.b)
3186    self.assertEqual(1, proto.optional_message.c)
3187
3188  def testInitRepeatedKwargs(self):
3189    proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
3190    self.assertTrue(proto.IsInitialized())
3191    self.assertEqual(1, proto.repeated_int32[0])
3192    self.assertEqual(2, proto.repeated_int32[1])
3193    self.assertEqual(3, proto.repeated_int32[2])
3194
3195
3196@testing_refleaks.TestCase
3197class OptionsTest(unittest.TestCase):
3198
3199  def testMessageOptions(self):
3200    proto = message_set_extensions_pb2.TestMessageSet()
3201    self.assertEqual(True,
3202                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)
3203    proto = unittest_pb2.TestAllTypes()
3204    self.assertEqual(False,
3205                     proto.DESCRIPTOR.GetOptions().message_set_wire_format)
3206
3207  def testPackedOptions(self):
3208    proto = unittest_pb2.TestAllTypes()
3209    proto.optional_int32 = 1
3210    proto.optional_double = 3.0
3211    for field_descriptor, _ in proto.ListFields():
3212      self.assertEqual(False, field_descriptor.GetOptions().packed)
3213
3214    proto = unittest_pb2.TestPackedTypes()
3215    proto.packed_int32.append(1)
3216    proto.packed_double.append(3.0)
3217    for field_descriptor, _ in proto.ListFields():
3218      self.assertEqual(True, field_descriptor.GetOptions().packed)
3219      self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED,
3220                       field_descriptor.label)
3221
3222
3223
3224@testing_refleaks.TestCase
3225class ClassAPITest(unittest.TestCase):
3226
3227  @unittest.skipIf(
3228      api_implementation.Type() != 'python',
3229      'C++ implementation requires a call to MakeDescriptor()')
3230  @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
3231  def testMakeClassWithNestedDescriptor(self):
3232    leaf_desc = descriptor.Descriptor(
3233        'leaf', 'package.parent.child.leaf', '',
3234        containing_type=None, fields=[],
3235        nested_types=[], enum_types=[],
3236        extensions=[],
3237        # pylint: disable=protected-access
3238        create_key=descriptor._internal_create_key)
3239    child_desc = descriptor.Descriptor(
3240        'child', 'package.parent.child', '',
3241        containing_type=None, fields=[],
3242        nested_types=[leaf_desc], enum_types=[],
3243        extensions=[],
3244        # pylint: disable=protected-access
3245        create_key=descriptor._internal_create_key)
3246    sibling_desc = descriptor.Descriptor(
3247        'sibling', 'package.parent.sibling',
3248        '', containing_type=None, fields=[],
3249        nested_types=[], enum_types=[],
3250        extensions=[],
3251        # pylint: disable=protected-access
3252        create_key=descriptor._internal_create_key)
3253    parent_desc = descriptor.Descriptor(
3254        'parent', 'package.parent', '',
3255        containing_type=None, fields=[],
3256        nested_types=[child_desc, sibling_desc],
3257        enum_types=[], extensions=[],
3258        # pylint: disable=protected-access
3259        create_key=descriptor._internal_create_key)
3260    reflection.MakeClass(parent_desc)
3261
3262  def _GetSerializedFileDescriptor(self, name):
3263    """Get a serialized representation of a test FileDescriptorProto.
3264
3265    Args:
3266      name: All calls to this must use a unique message name, to avoid
3267          collisions in the cpp descriptor pool.
3268    Returns:
3269      A string containing the serialized form of a test FileDescriptorProto.
3270    """
3271    file_descriptor_str = (
3272        'message_type {'
3273        '  name: "' + name + '"'
3274        '  field {'
3275        '    name: "flat"'
3276        '    number: 1'
3277        '    label: LABEL_REPEATED'
3278        '    type: TYPE_UINT32'
3279        '  }'
3280        '  field {'
3281        '    name: "bar"'
3282        '    number: 2'
3283        '    label: LABEL_OPTIONAL'
3284        '    type: TYPE_MESSAGE'
3285        '    type_name: "Bar"'
3286        '  }'
3287        '  nested_type {'
3288        '    name: "Bar"'
3289        '    field {'
3290        '      name: "baz"'
3291        '      number: 3'
3292        '      label: LABEL_OPTIONAL'
3293        '      type: TYPE_MESSAGE'
3294        '      type_name: "Baz"'
3295        '    }'
3296        '    nested_type {'
3297        '      name: "Baz"'
3298        '      enum_type {'
3299        '        name: "deep_enum"'
3300        '        value {'
3301        '          name: "VALUE_A"'
3302        '          number: 0'
3303        '        }'
3304        '      }'
3305        '      field {'
3306        '        name: "deep"'
3307        '        number: 4'
3308        '        label: LABEL_OPTIONAL'
3309        '        type: TYPE_UINT32'
3310        '      }'
3311        '    }'
3312        '  }'
3313        '}')
3314    file_descriptor = descriptor_pb2.FileDescriptorProto()
3315    text_format.Merge(file_descriptor_str, file_descriptor)
3316    return file_descriptor.SerializeToString()
3317
3318  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
3319  # This test can only run once; the second time, it raises errors about
3320  # conflicting message descriptors.
3321  def testParsingFlatClassWithExplicitClassDeclaration(self):
3322    """Test that the generated class can parse a flat message."""
3323    # TODO(xiaofeng): This test fails with cpp implementation in the call
3324    # of six.with_metaclass(). The other two callsites of with_metaclass
3325    # in this file are both excluded from cpp test, so it might be expected
3326    # to fail. Need someone more familiar with the python code to take a
3327    # look at this.
3328    if api_implementation.Type() != 'python':
3329      return
3330    file_descriptor = descriptor_pb2.FileDescriptorProto()
3331    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A'))
3332    msg_descriptor = descriptor.MakeDescriptor(
3333        file_descriptor.message_type[0])
3334
3335    class MessageClass(
3336        message.Message, metaclass=reflection.GeneratedProtocolMessageType):
3337      DESCRIPTOR = msg_descriptor
3338    msg = MessageClass()
3339    msg_str = (
3340        'flat: 0 '
3341        'flat: 1 '
3342        'flat: 2 ')
3343    text_format.Merge(msg_str, msg)
3344    self.assertEqual(msg.flat, [0, 1, 2])
3345
3346  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
3347  def testParsingFlatClass(self):
3348    """Test that the generated class can parse a flat message."""
3349    file_descriptor = descriptor_pb2.FileDescriptorProto()
3350    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B'))
3351    msg_descriptor = descriptor.MakeDescriptor(
3352        file_descriptor.message_type[0])
3353    msg_class = reflection.MakeClass(msg_descriptor)
3354    msg = msg_class()
3355    msg_str = (
3356        'flat: 0 '
3357        'flat: 1 '
3358        'flat: 2 ')
3359    text_format.Merge(msg_str, msg)
3360    self.assertEqual(msg.flat, [0, 1, 2])
3361
3362  @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
3363  def testParsingNestedClass(self):
3364    """Test that the generated class can parse a nested message."""
3365    file_descriptor = descriptor_pb2.FileDescriptorProto()
3366    file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C'))
3367    msg_descriptor = descriptor.MakeDescriptor(
3368        file_descriptor.message_type[0])
3369    msg_class = reflection.MakeClass(msg_descriptor)
3370    msg = msg_class()
3371    msg_str = (
3372        'bar {'
3373        '  baz {'
3374        '    deep: 4'
3375        '  }'
3376        '}')
3377    text_format.Merge(msg_str, msg)
3378    self.assertEqual(msg.bar.baz.deep, 4)
3379
3380if __name__ == '__main__':
3381  unittest.main()
3382