xref: /aosp_15_r20/external/cronet/third_party/protobuf/python/google/protobuf/internal/encoder.py (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc.  All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9#     * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11#     * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15#     * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Code for encoding protocol message primitives.
32
33Contains the logic for encoding every logical protocol field type
34into one of the 5 physical wire types.
35
36This code is designed to push the Python interpreter's performance to the
37limits.
38
39The basic idea is that at startup time, for every field (i.e. every
40FieldDescriptor) we construct two functions:  a "sizer" and an "encoder".  The
41sizer takes a value of this field's type and computes its byte size.  The
42encoder takes a writer function and a value.  It encodes the value into byte
43strings and invokes the writer function to write those strings.  Typically the
44writer function is the write() method of a BytesIO.
45
We try to do as much work as possible when constructing the encoder and the
sizer rather than when calling them.  In particular:
48* We copy any needed global functions to local variables, so that we do not need
49  to do costly global table lookups at runtime.
50* Similarly, we try to do any attribute lookups at startup time if possible.
51* Every field's tag is encoded to bytes at startup, since it can't change at
52  runtime.
53* Whatever component of the field size we can compute at startup, we do.
54* We *avoid* sharing code if doing so would make the code slower and not sharing
55  does not burden us too much.  For example, encoders for repeated fields do
56  not just call the encoders for singular fields in a loop because this would
57  add an extra function call overhead for every loop iteration; instead, we
58  manually inline the single-value encoder into the loop.
59* If a Python function lacks a return statement, Python actually generates
60  instructions to pop the result of the last statement off the stack, push
61  None onto the stack, and then return that.  If we really don't care what
62  value is returned, then we can save two instructions by returning the
63  result of the last statement.  It looks funny but it helps.
64* We assume that type and bounds checking has happened at a higher level.
65"""
66
67__author__ = '[email protected] (Kenton Varda)'
68
69import struct
70
71from google.protobuf.internal import wire_format
72
73
74# This will overflow and thus become IEEE-754 "infinity".  We would use
75# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
76_POS_INF = 1e10000
77_NEG_INF = -_POS_INF
78
79
80def _VarintSize(value):
81  """Compute the size of a varint value."""
82  if value <= 0x7f: return 1
83  if value <= 0x3fff: return 2
84  if value <= 0x1fffff: return 3
85  if value <= 0xfffffff: return 4
86  if value <= 0x7ffffffff: return 5
87  if value <= 0x3ffffffffff: return 6
88  if value <= 0x1ffffffffffff: return 7
89  if value <= 0xffffffffffffff: return 8
90  if value <= 0x7fffffffffffffff: return 9
91  return 10
92
93
94def _SignedVarintSize(value):
95  """Compute the size of a signed varint value."""
96  if value < 0: return 10
97  if value <= 0x7f: return 1
98  if value <= 0x3fff: return 2
99  if value <= 0x1fffff: return 3
100  if value <= 0xfffffff: return 4
101  if value <= 0x7ffffffff: return 5
102  if value <= 0x3ffffffffff: return 6
103  if value <= 0x1ffffffffffff: return 7
104  if value <= 0xffffffffffffff: return 8
105  if value <= 0x7fffffffffffffff: return 9
106  return 10
107
108
def _TagSize(field_number):
  """Return the number of bytes needed to serialize a tag for `field_number`."""
  # Wire type 0 is passed because the wire type does not affect the size of
  # the encoded tag.
  packed_tag = wire_format.PackTag(field_number, 0)
  return _VarintSize(packed_tag)
114
115
116# --------------------------------------------------------------------
117# In this section we define some generic sizers.  Each of these functions
118# takes parameters specific to a particular field type, e.g. int32 or fixed64.
119# It returns another function which in turn takes parameters specific to a
120# particular field, e.g. the field number and whether it is repeated or packed.
121# Look at the next section to see how these are used.
122
123
def _SimpleSizer(compute_value_size):
  """Build a sizer constructor around a per-value size function.

  Args:
      compute_value_size:  Function returning the encoded size of a single
        value.  Typically this is _VarintSize.

  Returns:
      A function taking (field_number, is_repeated, is_packed) that returns
      the sizer for that particular field.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        # A packed field is one tag, one length prefix, then all the values.
        payload = sum(map(compute_value_size, value))
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize

    if is_repeated:
      def RepeatedFieldSize(value):
        # Unpacked repeated fields repeat the tag in front of every value.
        return tag_size * len(value) + sum(map(compute_value_size, value))
      return RepeatedFieldSize

    def FieldSize(value):
      return tag_size + compute_value_size(value)
    return FieldSize

  return SpecificSizer
151
152
def _ModifiedSizer(compute_value_size, modify_value):
  """Build a sizer constructor that transforms each value before sizing it.

  Args:
      compute_value_size:  Function returning the encoded size of a single
        (already transformed) value.
      modify_value:  Function applied to every value before it is handed to
        compute_value_size.  Typically this is ZigZagEncode.

  Returns:
      A function taking (field_number, is_repeated, is_packed) that returns
      the sizer for that particular field.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        # A packed field is one tag, one length prefix, then all the values.
        payload = sum(compute_value_size(modify_value(item)) for item in value)
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize

    if is_repeated:
      def RepeatedFieldSize(value):
        # Unpacked repeated fields repeat the tag in front of every value.
        values_size = sum(
            compute_value_size(modify_value(item)) for item in value)
        return tag_size * len(value) + values_size
      return RepeatedFieldSize

    def FieldSize(value):
      return tag_size + compute_value_size(modify_value(value))
    return FieldSize

  return SpecificSizer
180
181
def _FixedSizer(value_size):
  """Build a sizer constructor for fields whose values have a fixed width.

  Args:
      value_size:  The encoded size, in bytes, of one value.

  Returns:
      A function taking (field_number, is_repeated, is_packed) that returns
      the sizer for that particular field.
  """

  def SpecificSizer(field_number, is_repeated, is_packed):
    tag_size = _TagSize(field_number)

    if is_packed:
      local_VarintSize = _VarintSize
      def PackedFieldSize(value):
        # One tag, one length prefix, then len(value) fixed-width values.
        payload = len(value) * value_size
        return payload + local_VarintSize(payload) + tag_size
      return PackedFieldSize

    # Outside of packed encoding every value costs one tag plus its width,
    # so the per-value cost can be computed once up front.
    per_value_size = value_size + tag_size

    if is_repeated:
      def RepeatedFieldSize(value):
        return len(value) * per_value_size
      return RepeatedFieldSize

    def FieldSize(value):
      # The size does not depend on the value at all.
      return per_value_size
    return FieldSize

  return SpecificSizer
206
207
208# ====================================================================
209# Here we declare a sizer constructor for each field type.  Each "sizer
210# constructor" is a function that takes (field_number, is_repeated, is_packed)
211# as parameters and returns a sizer, which in turn takes a field value as
212# a parameter and returns its encoded size.
213
214
# Signed varint types share a single sizer constructor.
Int32Sizer = _SimpleSizer(_SignedVarintSize)
Int64Sizer = Int32Sizer
EnumSizer = Int32Sizer

# Unsigned varint types.
UInt32Sizer = _SimpleSizer(_VarintSize)
UInt64Sizer = UInt32Sizer

# Values are passed through ZigZagEncode before being sized.
SInt32Sizer = _ModifiedSizer(_SignedVarintSize, wire_format.ZigZagEncode)
SInt64Sizer = SInt32Sizer

# Four-byte fixed-width types.
Fixed32Sizer = _FixedSizer(4)
SFixed32Sizer = Fixed32Sizer
FloatSizer = Fixed32Sizer

# Eight-byte fixed-width types.
Fixed64Sizer = _FixedSizer(8)
SFixed64Sizer = Fixed64Sizer
DoubleSizer = Fixed64Sizer

# A bool is sized as a fixed single byte.
BoolSizer = _FixedSizer(1)
226
227
def StringSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a string field.

  Each string is sized by the length of its UTF-8 encoding, plus a varint
  length prefix and the field tag.  String fields cannot be packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value.encode('utf-8'))
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    # Start with one tag per element, then add each element's prefix + body.
    total = tag_size * len(value)
    for text in value:
      byte_count = local_len(text.encode('utf-8'))
      total += local_VarintSize(byte_count) + byte_count
    return total
  return RepeatedFieldSize
248
249
def BytesSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a bytes field.

  Each value is sized by its length, plus a varint length prefix and the
  field tag.  Bytes fields cannot be packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  local_len = len
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      byte_count = local_len(value)
      return tag_size + local_VarintSize(byte_count) + byte_count
    return FieldSize

  def RepeatedFieldSize(value):
    # One tag per element, plus every body and every length prefix.
    total = tag_size * len(value)
    lengths = [local_len(chunk) for chunk in value]
    return total + sum(lengths) + sum(map(local_VarintSize, lengths))
  return RepeatedFieldSize
270
271
def GroupSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a group field.

  A group has no length prefix; its size is the ByteSize() of the value plus
  two tags for the field number.  Group fields cannot be packed.
  """

  # Doubled because a group is written with both a start and an end tag.
  tag_size = 2 * _TagSize(field_number)
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      return tag_size + value.ByteSize()
    return FieldSize

  def RepeatedFieldSize(value):
    return tag_size * len(value) + sum(item.ByteSize() for item in value)
  return RepeatedFieldSize
288
289
def MessageSizer(field_number, is_repeated, is_packed):
  """Returns a sizer for a message field.

  Each message is sized by its ByteSize(), plus a varint length prefix and
  the field tag.  Message fields cannot be packed.
  """

  tag_size = _TagSize(field_number)
  local_VarintSize = _VarintSize
  assert not is_packed

  if not is_repeated:
    def FieldSize(value):
      body_size = value.ByteSize()
      return tag_size + local_VarintSize(body_size) + body_size
    return FieldSize

  def RepeatedFieldSize(value):
    # Start with one tag per element, then add each element's prefix + body.
    total = tag_size * len(value)
    for message in value:
      body_size = message.ByteSize()
      total += local_VarintSize(body_size) + body_size
    return total
  return RepeatedFieldSize
309
310
311# --------------------------------------------------------------------
312# MessageSet is special: it needs custom logic to compute its size properly.
313
314
def MessageSetItemSizer(field_number):
  """Returns a sizer for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Everything except the message payload and its length prefix has a size
  that is known up front, so it is summed once here.
  """
  local_VarintSize = _VarintSize

  item_group_tags_size = _TagSize(1) * 2
  type_id_size = _TagSize(2) + _VarintSize(field_number)
  message_tag_size = _TagSize(3)
  static_size = item_group_tags_size + type_id_size + message_tag_size

  def FieldSize(value):
    body_size = value.ByteSize()
    return static_size + local_VarintSize(body_size) + body_size

  return FieldSize
335
336
337# --------------------------------------------------------------------
338# Map is special: it needs custom logic to compute its size properly.
339
340
def MapSizer(field_descriptor, is_message_map):
  """Returns a sizer for a map field.

  Args:
      field_descriptor:  Descriptor of the map field.  Its message_type is
        used to build one entry message per key, and its number supplies the
        tag.
      is_message_map:  If true, ByteSize() is also called on every map value
        after its entry has been sized.

  Returns:
      FieldSize(map_value), which returns the summed size of all entries.
  """

  # The concrete class is looked up inside FieldSize rather than here because
  # field_descriptor.message_type._concrete_class may not have been
  # initialized yet.
  entry_type = field_descriptor.message_type
  size_entry = MessageSizer(field_descriptor.number, False, False)

  def FieldSize(map_value):
    running_total = 0
    for key in map_value:
      entry_value = map_value[key]
      # It's wasteful to create the entry messages and throw them away right
      # after, since the actual encode does the same again, but there is no
      # obvious way to avoid it in the current design without a lot of code
      # duplication.
      running_total += size_entry(
          entry_type._concrete_class(key=key, value=entry_value))
      if is_message_map:
        # For message maps, value.ByteSize() should be called to update the
        # status.
        entry_value.ByteSize()
    return running_total

  return FieldSize
365
366# ====================================================================
367# Encoders!
368
369
370def _VarintEncoder():
371  """Return an encoder for a basic varint value (does not include tag)."""
372
373  local_int2byte = struct.Struct('>B').pack
374
375  def EncodeVarint(write, value, unused_deterministic=None):
376    bits = value & 0x7f
377    value >>= 7
378    while value:
379      write(local_int2byte(0x80|bits))
380      bits = value & 0x7f
381      value >>= 7
382    return write(local_int2byte(bits))
383
384  return EncodeVarint
385
386
387def _SignedVarintEncoder():
388  """Return an encoder for a basic signed varint value (does not include
389  tag)."""
390
391  local_int2byte = struct.Struct('>B').pack
392
393  def EncodeSignedVarint(write, value, unused_deterministic=None):
394    if value < 0:
395      value += (1 << 64)
396    bits = value & 0x7f
397    value >>= 7
398    while value:
399      write(local_int2byte(0x80|bits))
400      bits = value & 0x7f
401      value >>= 7
402    return write(local_int2byte(bits))
403
404  return EncodeSignedVarint
405
406
# Module-level varint encoder instances, built once at import time.  The field
# encoders in this file bind them to local names instead of rebuilding them.
_EncodeVarint = _VarintEncoder()
_EncodeSignedVarint = _SignedVarintEncoder()
409
410
def _VarintBytes(value):
  """Encode the given integer as a varint and return the bytes.

  This is only called at startup time, so it doesn't need to be fast.
  """

  # _EncodeVarint writes one byte at a time; collect them in a buffer.
  buffer = bytearray()
  _EncodeVarint(buffer.extend, value, True)
  return bytes(buffer)
418
419
def TagBytes(field_number, wire_type):
  """Encode the tag for (field_number, wire_type) and return it as bytes.

  Only called at startup.
  """

  packed_tag = wire_format.PackTag(field_number, wire_type)
  encoded_tag = _VarintBytes(packed_tag)
  return bytes(encoded_tag)
424
425# --------------------------------------------------------------------
426# As with sizers (see above), we have a number of common encoder
427# implementations.
428
429
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
  """Return a constructor for an encoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      encode_value:  A function which encodes an individual value, e.g.
        _EncodeVarint().
      compute_value_size:  A function which computes the size of an individual
        value, e.g. _VarintSize().

  Returns:
      A function taking (field_number, is_repeated, is_packed) that returns
      the encoder for that particular field.
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # The payload length must be written before any of the values.
        payload_size = sum(map(compute_value_size, value))
        local_EncodeVarint(write, payload_size, deterministic)
        for item in value:
          encode_value(write, item, deterministic)
      return EncodePackedField

    # Both unpacked forms tag each value with the field's own wire type.
    tag_bytes = TagBytes(field_number, wire_type)

    if is_repeated:
      def EncodeRepeatedField(write, value, deterministic):
        for item in value:
          write(tag_bytes)
          encode_value(write, item, deterministic)
      return EncodeRepeatedField

    def EncodeField(write, value, deterministic):
      write(tag_bytes)
      return encode_value(write, value, deterministic)
    return EncodeField

  return SpecificEncoder
469
470
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
  """Like _SimpleEncoder, but transforms every value before encoding it.

  Args:
      wire_type:  The field's wire type, for encoding tags.
      encode_value:  A function which encodes an individual (already
        transformed) value.
      compute_value_size:  A function which computes the size of an individual
        (already transformed) value.
      modify_value:  A function applied to every value before it is sized or
        encoded.  Usually this is ZigZagEncode.

  Returns:
      A function taking (field_number, is_repeated, is_packed) that returns
      the encoder for that particular field.
  """

  def SpecificEncoder(field_number, is_repeated, is_packed):
    if is_packed:
      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
      local_EncodeVarint = _EncodeVarint
      def EncodePackedField(write, value, deterministic):
        write(tag_bytes)
        # The payload length must be written before any of the values.
        payload_size = sum(
            compute_value_size(modify_value(item)) for item in value)
        local_EncodeVarint(write, payload_size, deterministic)
        for item in value:
          encode_value(write, modify_value(item), deterministic)
      return EncodePackedField

    # Both unpacked forms tag each value with the field's own wire type.
    tag_bytes = TagBytes(field_number, wire_type)

    if is_repeated:
      def EncodeRepeatedField(write, value, deterministic):
        for item in value:
          write(tag_bytes)
          encode_value(write, modify_value(item), deterministic)
      return EncodeRepeatedField

    def EncodeField(write, value, deterministic):
      write(tag_bytes)
      return encode_value(write, modify_value(value), deterministic)
    return EncodeField

  return SpecificEncoder
503
504
505def _StructPackEncoder(wire_type, format):
506  """Return a constructor for an encoder for a fixed-width field.
507
508  Args:
509      wire_type:  The field's wire type, for encoding tags.
510      format:  The format string to pass to struct.pack().
511  """
512
513  value_size = struct.calcsize(format)
514
515  def SpecificEncoder(field_number, is_repeated, is_packed):
516    local_struct_pack = struct.pack
517    if is_packed:
518      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
519      local_EncodeVarint = _EncodeVarint
520      def EncodePackedField(write, value, deterministic):
521        write(tag_bytes)
522        local_EncodeVarint(write, len(value) * value_size, deterministic)
523        for element in value:
524          write(local_struct_pack(format, element))
525      return EncodePackedField
526    elif is_repeated:
527      tag_bytes = TagBytes(field_number, wire_type)
528      def EncodeRepeatedField(write, value, unused_deterministic=None):
529        for element in value:
530          write(tag_bytes)
531          write(local_struct_pack(format, element))
532      return EncodeRepeatedField
533    else:
534      tag_bytes = TagBytes(field_number, wire_type)
535      def EncodeField(write, value, unused_deterministic=None):
536        write(tag_bytes)
537        return write(local_struct_pack(format, value))
538      return EncodeField
539
540  return SpecificEncoder
541
542
543def _FloatingPointEncoder(wire_type, format):
544  """Return a constructor for an encoder for float fields.
545
546  This is like StructPackEncoder, but catches errors that may be due to
547  passing non-finite floating-point values to struct.pack, and makes a
548  second attempt to encode those values.
549
550  Args:
551      wire_type:  The field's wire type, for encoding tags.
552      format:  The format string to pass to struct.pack().
553  """
554
555  value_size = struct.calcsize(format)
556  if value_size == 4:
557    def EncodeNonFiniteOrRaise(write, value):
558      # Remember that the serialized form uses little-endian byte order.
559      if value == _POS_INF:
560        write(b'\x00\x00\x80\x7F')
561      elif value == _NEG_INF:
562        write(b'\x00\x00\x80\xFF')
563      elif value != value:           # NaN
564        write(b'\x00\x00\xC0\x7F')
565      else:
566        raise
567  elif value_size == 8:
568    def EncodeNonFiniteOrRaise(write, value):
569      if value == _POS_INF:
570        write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F')
571      elif value == _NEG_INF:
572        write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF')
573      elif value != value:                         # NaN
574        write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F')
575      else:
576        raise
577  else:
578    raise ValueError('Can\'t encode floating-point values that are '
579                     '%d bytes long (only 4 or 8)' % value_size)
580
581  def SpecificEncoder(field_number, is_repeated, is_packed):
582    local_struct_pack = struct.pack
583    if is_packed:
584      tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
585      local_EncodeVarint = _EncodeVarint
586      def EncodePackedField(write, value, deterministic):
587        write(tag_bytes)
588        local_EncodeVarint(write, len(value) * value_size, deterministic)
589        for element in value:
590          # This try/except block is going to be faster than any code that
591          # we could write to check whether element is finite.
592          try:
593            write(local_struct_pack(format, element))
594          except SystemError:
595            EncodeNonFiniteOrRaise(write, element)
596      return EncodePackedField
597    elif is_repeated:
598      tag_bytes = TagBytes(field_number, wire_type)
599      def EncodeRepeatedField(write, value, unused_deterministic=None):
600        for element in value:
601          write(tag_bytes)
602          try:
603            write(local_struct_pack(format, element))
604          except SystemError:
605            EncodeNonFiniteOrRaise(write, element)
606      return EncodeRepeatedField
607    else:
608      tag_bytes = TagBytes(field_number, wire_type)
609      def EncodeField(write, value, unused_deterministic=None):
610        write(tag_bytes)
611        try:
612          write(local_struct_pack(format, value))
613        except SystemError:
614          EncodeNonFiniteOrRaise(write, value)
615      return EncodeField
616
617  return SpecificEncoder
618
619
620# ====================================================================
621# Here we declare an encoder constructor for each field type.  These work
622# very similarly to sizer constructors, described earlier.
623
624
# Signed varint types share a single encoder constructor.
Int32Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
Int64Encoder = Int32Encoder
EnumEncoder = Int32Encoder

# Unsigned varint types.
UInt32Encoder = _SimpleEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
UInt64Encoder = UInt32Encoder

# Values are passed through ZigZagEncode before being sized and encoded.
SInt32Encoder = _ModifiedEncoder(
    wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
    wire_format.ZigZagEncode)
SInt64Encoder = SInt32Encoder

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d')
645
646
def BoolEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a boolean field.

  Every value is written as a single byte: b'\\x01' if it is truthy and
  b'\\x00' otherwise.
  """

  false_byte = b'\x00'
  true_byte = b'\x01'

  if is_packed:
    tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
    local_EncodeVarint = _EncodeVarint
    def EncodePackedField(write, value, deterministic):
      write(tag_bytes)
      # One byte per element, so the payload length is the element count.
      local_EncodeVarint(write, len(value), deterministic)
      for flag in value:
        write(true_byte if flag else false_byte)
    return EncodePackedField

  # Both unpacked forms tag each value as a varint.
  tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)

  if is_repeated:
    def EncodeRepeatedField(write, value, unused_deterministic=None):
      for flag in value:
        write(tag_bytes)
        write(true_byte if flag else false_byte)
    return EncodeRepeatedField

  def EncodeField(write, value, unused_deterministic=None):
    write(tag_bytes)
    return write(true_byte if value else false_byte)
  return EncodeField
682
683
def StringEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a string field.

  Each string is written as the field tag, a varint length, and then its
  UTF-8 encoding.  String fields cannot be packed.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      # Encode first so the length prefix reflects the UTF-8 byte count.
      utf8_bytes = value.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8_bytes), deterministic)
      return write(utf8_bytes)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for text in value:
      # Encode first so the length prefix reflects the UTF-8 byte count.
      utf8_bytes = text.encode('utf-8')
      write(tag)
      local_EncodeVarint(write, local_len(utf8_bytes), deterministic)
      write(utf8_bytes)
  return EncodeRepeatedField
706
707
def BytesEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a bytes field.

  Each value is written as the field tag, a varint length, and then the
  value itself.  Bytes fields cannot be packed.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  local_len = len
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, local_len(value), deterministic)
      return write(value)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for chunk in value:
      write(tag)
      local_EncodeVarint(write, local_len(chunk), deterministic)
      write(chunk)
  return EncodeRepeatedField
728
729
def GroupEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a group field.

  Each group is written as a start-group tag, the value's own serialization
  (via _InternalSerialize), and an end-group tag.  There is no length
  prefix.  Group fields cannot be packed.
  """

  start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
  end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(start_tag)
      value._InternalSerialize(write, deterministic)
      return write(end_tag)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for group in value:
      write(start_tag)
      group._InternalSerialize(write, deterministic)
      write(end_tag)
  return EncodeRepeatedField
749
750
def MessageEncoder(field_number, is_repeated, is_packed):
  """Returns an encoder for a message field.

  Each message is written as the field tag, its ByteSize() as a varint, and
  then its own serialization (via _InternalSerialize).  Message fields
  cannot be packed.
  """

  tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
  local_EncodeVarint = _EncodeVarint
  assert not is_packed

  if not is_repeated:
    def EncodeField(write, value, deterministic):
      write(tag)
      local_EncodeVarint(write, value.ByteSize(), deterministic)
      return value._InternalSerialize(write, deterministic)
    return EncodeField

  def EncodeRepeatedField(write, value, deterministic):
    for message in value:
      write(tag)
      local_EncodeVarint(write, message.ByteSize(), deterministic)
      message._InternalSerialize(write, deterministic)
  return EncodeRepeatedField
770
771
772# --------------------------------------------------------------------
773# As before, MessageSet is special.
774
775
def MessageSetItemEncoder(field_number):
  """Encoder for extensions of MessageSet.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Everything that precedes the message's length prefix is fixed for a given
  field number, so it is assembled once here.
  """
  local_EncodeVarint = _EncodeVarint

  # Item start-group tag, the type_id field carrying field_number, and the
  # tag of the message field.
  item_start = TagBytes(1, wire_format.WIRETYPE_START_GROUP)
  type_id_field = (
      TagBytes(2, wire_format.WIRETYPE_VARINT) + _VarintBytes(field_number))
  message_tag = TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  start_bytes = item_start + type_id_field + message_tag

  end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  def EncodeField(write, value, deterministic):
    write(start_bytes)
    local_EncodeVarint(write, value.ByteSize(), deterministic)
    value._InternalSerialize(write, deterministic)
    return write(end_bytes)

  return EncodeField
802
803
804# --------------------------------------------------------------------
805# As before, Map is special.
806
807
def MapEncoder(field_descriptor):
  """Returns an encoder for a map field.

  Maps always have a wire format like this:
    message MapEntry {
      key_type key = 1;
      value_type value = 2;
    }
    repeated MapEntry map = N;

  Args:
      field_descriptor:  Descriptor of the map field.  Its message_type is
        used to build one entry message per key, and its number supplies the
        tag.

  Returns:
      EncodeField(write, value, deterministic), which encodes one entry
      message per key.  When `deterministic` is true the keys are visited in
      sorted order; otherwise the map is iterated directly.
  """
  # The concrete class is looked up inside EncodeField rather than here
  # because field_descriptor.message_type._concrete_class may not have been
  # initialized yet.
  entry_type = field_descriptor.message_type
  encode_entry = MessageEncoder(field_descriptor.number, False, False)

  def EncodeField(write, value, deterministic):
    if deterministic:
      keys = sorted(value.keys())
    else:
      keys = value
    for key in keys:
      entry = entry_type._concrete_class(key=key, value=value[key])
      encode_entry(write, entry, deterministic)

  return EncodeField
830