blob: 31869e4575ca900440a8ce452ce2af5ed87a19aa [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
"""Code for decoding protocol buffer primitives.

This code is very similar to encoder.py -- read the docs for that module first.

A "decoder" is a function with the signature:
  Decode(buffer, pos, end, message, field_dict)
The arguments are:
  buffer:     The string containing the encoded message.
  pos:        The current position in the string.
  end:        The position in the string where the current message ends.  May be
              less than len(buffer) if we're reading a sub-message.
  message:    The message object into which we're parsing.
  field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position.  A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.

Note that decoders may throw any of the following:
  IndexError:  Indicates a truncated message.
  struct.error:  Unpacking of a fixed-width field failed.
  message.DecodeError:  Other errors.

Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking:  it's fine to read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.

Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.

Decoders are constructed using decoder constructors with the signature:
  MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
  field_number:  The field number of the field we want to decode.
  is_repeated:   Is the field a repeated field? (bool)
  is_packed:     Is the field a packed field? (bool)
  key:           The key to use when looking up the field within field_dict.
                 (This is actually the FieldDescriptor but nothing in this
                 file should depend on that.)
  new_default:   A function which takes a message object as a parameter and
                 returns a new instance of the default value for this field.
                 (This is called for repeated fields and sub-messages, when an
                 instance does not already exist.)

As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
"""
80
# Module author metadata.
__author__ = 'kenton@google.com (Kenton Varda)'
82
83import struct
84
85import six
86
# Python 3 has no separate `long` type.  Alias it to int so that the 64-bit
# varint decoders defined below can pass `long` as their result type on both
# Python 2 and Python 3.
if six.PY3:
  long = int
89
90from google.protobuf.internal import encoder
91from google.protobuf.internal import wire_format
92from google.protobuf import message
93
94
# IEEE-754 special values, returned when decoding non-finite float/double
# fields.  float('inf') and float('nan') work on every platform starting with
# Python 2.6.  This module already requires 2.6+ syntax (e.g. "except ... as e"),
# so the old overflow trick (1e10000 -> inf, inf * 0 -> nan) is unnecessary.
_POS_INF = float('inf')
_NEG_INF = -_POS_INF
_NAN = float('nan')
100
101
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".  (Several decoder closures below take a
# parameter called `message`, which would shadow the imported module.)
_DecodeError = message.DecodeError
105
106
def _VarintDecoder(mask, result_type):
  """Return a decoder for a basic varint value (does not include tag).

  Decoded values will be bitwise-anded with the given mask before being
  returned, e.g. to limit them to 32 bits.  The returned decoder does not
  take the usual "end" parameter -- the caller is expected to do bounds checking
  after the fact (often the caller can defer such checking until later).  The
  decoder returns a (value, new_pos) pair.

  Args:
    mask: Bit mask applied to the accumulated value, e.g. (1 << 64) - 1.
    result_type: Callable used to convert the masked value (e.g. int or long).

  Returns:
    A function DecodeVarint(buffer, pos) -> (value, new_pos).  It raises
    _DecodeError if the varint has not terminated within 10 bytes.  It lets
    through whatever six.indexbytes raises when pos runs off the end of
    buffer (IndexError per the module docstring).
  """

  def DecodeVarint(buffer, pos):
    result = 0
    shift = 0
    # `while 1` rather than `while True`: on Python 2, True is a global name
    # lookup, and this is a hot loop.
    while 1:
      b = six.indexbytes(buffer, pos)
      # Each byte contributes its low 7 bits, least-significant group first.
      result |= ((b & 0x7f) << shift)
      pos += 1
      if not (b & 0x80):
        # High bit clear: this was the last byte of the varint.
        result &= mask
        result = result_type(result)
        return (result, pos)
      shift += 7
      if shift >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')
  return DecodeVarint
132
133
def _SignedVarintDecoder(mask, result_type):
  """Like _VarintDecoder(), but interprets the result as a signed value.

  Once the final byte has been read, a value above 0x7fffffffffffffff is
  treated as a negative 64-bit two's-complement number.  It is shifted down
  by 2**64 and then has every bit outside `mask` set, so it stays negative.
  Any other value is simply bitwise-anded with `mask`.  The masked value is
  converted with `result_type` and returned together with the new position.
  """

  def DecodeVarint(buffer, pos):
    value = 0
    bit_offset = 0
    while 1:
      byte = six.indexbytes(buffer, pos)
      pos += 1
      value |= (byte & 0x7f) << bit_offset
      if not (byte & 0x80):
        # Final byte of the varint: apply the sign handling and finish.
        if value > 0x7fffffffffffffff:
          value = (value - (1 << 64)) | ~mask
        else:
          value &= mask
        return (result_type(value), pos)
      bit_offset += 7
      if bit_offset >= 64:
        raise _DecodeError('Too many bytes when decoding varint.')

  return DecodeVarint
156
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
# (e.g. the C++ implementation) simpler.

# Full-width (64-bit) varint decoders.  `long` is aliased to int on Python 3
# near the top of this module.
_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)

# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
167
168
def ReadTag(buffer, pos):
  """Return the raw bytes of the tag starting at `pos`, plus the position after it.

  The tag is a varint, but it is deliberately not decoded.  Its raw bytes are
  used directly as the key for finding the right field decoder.  In Python,
  looking up a byte string in a hash table (done in C) is cheaper than
  decoding the varint in pure Python.  A lower-level language would decode
  the number instead.

  Returns:
    A (tag_bytes, new_pos) tuple.
  """
  tag_end = pos
  # Continuation bytes have their high bit set; the first byte without it
  # is the last byte of the tag.
  while six.indexbytes(buffer, tag_end) & 0x80:
    tag_end += 1
  tag_end += 1
  return (buffer[pos:tag_end], tag_end)
185
186
187# --------------------------------------------------------------------
188
189
def _SimpleDecoder(wire_type, decode_value):
  """Return a constructor for a decoder for fields of a particular type.

  Args:
      wire_type:  The field's wire type.
      decode_value:  A function which decodes an individual value, e.g.
        _DecodeVarint()

  Returns:
    A function SpecificDecoder(field_number, is_repeated, is_packed, key,
    new_default).  It returns a packed, repeated or singular field decoder,
    depending on its flags.
  """

  def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    if is_packed:
      local_DecodeVarint = _DecodeVarint
      def DecodePackedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        # Packed payload: a varint byte length, then the elements back to back.
        (endpoint, pos) = local_DecodeVarint(buffer, pos)
        endpoint += pos
        if endpoint > end:
          raise _DecodeError('Truncated message.')
        while pos < endpoint:
          (element, pos) = decode_value(buffer, pos)
          value.append(element)
        # The last element ran past the declared payload length.
        if pos > endpoint:
          del value[-1]  # Discard corrupt value.
          raise _DecodeError('Packed element was truncated.')
        return pos
      return DecodePackedField
    elif is_repeated:
      tag_bytes = encoder.TagBytes(field_number, wire_type)
      tag_len = len(tag_bytes)
      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
        value = field_dict.get(key)
        if value is None:
          value = field_dict.setdefault(key, new_default(message))
        while 1:
          (element, new_pos) = decode_value(buffer, pos)
          value.append(element)
          # Predict that the next tag is another copy of the same repeated
          # field.
          pos = new_pos + tag_len
          if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
            # Prediction failed.  Return.
            # decode_value does no bounds checking itself, so an element
            # that overran `end` is reported here.
            if new_pos > end:
              raise _DecodeError('Truncated message.')
            return new_pos
      return DecodeRepeatedField
    else:
      def DecodeField(buffer, pos, end, message, field_dict):
        (field_dict[key], pos) = decode_value(buffer, pos)
        if pos > end:
          # NOTE(review): this removes the key entirely, so any value the
          # field held before this occurrence is discarded as well.
          del field_dict[key]  # Discard corrupt value.
          raise _DecodeError('Truncated message.')
        return pos
      return DecodeField

  return SpecificDecoder
247
248
def _ModifiedDecoder(wire_type, decode_value, modify_value):
  """Build a decoder constructor that post-processes each decoded value.

  Works exactly like _SimpleDecoder(wire_type, decode_value), except that
  every value is passed through modify_value before it is stored.  In this
  module modify_value is wire_format.ZigZagDecode (sint32/sint64) or bool.
  """

  # Wrapping _SimpleDecoder is a little slower than duplicating its code,
  # but the difference is not significant.
  def InnerDecode(buffer, pos):
    raw, next_pos = decode_value(buffer, pos)
    return (modify_value(raw), next_pos)

  return _SimpleDecoder(wire_type, InnerDecode)
261
262
def _StructPackDecoder(wire_type, format):
  """Build a decoder constructor for a fixed-width field.

  Args:
    wire_type: The field's wire type.
    format: The struct format string (e.g. '<I') describing one value.

  Returns:
    The _SimpleDecoder constructor for values unpacked with `format`.
  """

  # Compile the format once; its size gives the number of bytes per value.
  fixed_struct = struct.Struct(format)
  width = fixed_struct.size
  unpack = fixed_struct.unpack

  # struct.error from a short slice is deliberately not caught here.
  # Something up-stack converts it to _DecodeError, which avoids setting up
  # an exception handler for every value parsed.
  def InnerDecode(buffer, pos):
    stop = pos + width
    return (unpack(buffer[pos:stop])[0], stop)

  return _SimpleDecoder(wire_type, InnerDecode)
286
287
def _FloatDecoder():
  """Returns a decoder for a float field.

  This code works around a bug in struct.unpack for non-finite 32-bit
  floating-point values.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    # (Bits here are counted from the most significant end.  So float_bytes[3]
    # holds the sign and the top 7 exponent bits, and the high bit of
    # float_bytes[2] is the lowest exponent bit.)
    new_pos = pos + 4
    float_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set, then it's non-finite.
    # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    # To avoid that, we parse it specially.
    # NOTE(review): when fewer than 4 bytes remain, float_bytes[3:4] is b''.
    # An empty byte string is "in" any byte string, so the first test passes.
    # new_pos still points past the data, and presumably the caller's
    # pos > end check reports the truncation.
    if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
      # If at least one significand bit is set...
      # (b'\x00\x00\x80' is the only all-exponent pattern in bytes 0-2 with a
      # zero significand.)
      if float_bytes[0:3] != b'\x00\x00\x80':
        return (_NAN, new_pos)
      # If sign bit is set...
      if float_bytes[3:4] == b'\xFF':
        return (_NEG_INF, new_pos)
      return (_POS_INF, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<f', float_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
321
322
def _DoubleDecoder():
  """Returns a decoder for a double field.

  This code works around a bug in struct.unpack for not-a-number.
  """

  local_unpack = struct.unpack

  def InnerDecode(buffer, pos):
    # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    # (Counting from the most significant end: double_bytes[7] holds the sign
    # and the top 7 exponent bits, and the high nibble of double_bytes[6]
    # holds the remaining 4 exponent bits.)
    new_pos = pos + 8
    double_bytes = buffer[pos:new_pos]

    # If this value has all its exponent bits set and at least one significand
    # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    # as inf or -inf.  To avoid that, we treat it specially.
    # Infinities (zero significand) fall through to struct.unpack below.
    # NOTE(review): as in _FloatDecoder, a truncated slice makes
    # double_bytes[7:8] == b'', which passes the `in` test.
    if ((double_bytes[7:8] in b'\x7F\xFF')
        and (double_bytes[6:7] >= b'\xF0')
        and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
      return (_NAN, new_pos)

    # Note that we expect someone up-stack to catch struct.error and convert
    # it to _DecodeError -- this way we don't have to set up exception-
    # handling blocks every time we parse one value.
    result = local_unpack('<d', double_bytes)[0]
    return (result, new_pos)
  return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
351
352
def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for an enum field.

  Enum values are read as signed 32-bit varints.  A value found in
  key.enum_type.values_by_number is stored in the field.  Any other value is
  kept in message._unknown_fields as a (tag_bytes, raw_value_bytes) pair
  rather than being dropped.

  Args:
    field_number: The field number of the enum field.
    is_repeated: Whether the field is repeated.
    is_packed: Whether the field uses packed encoding.
    key: The field descriptor; its enum_type identifies the known values.
    new_default: Callable taking the message and returning a new default
      value (the container for repeated fields).

  Returns:
    A decoder (buffer, pos, end, message, field_dict) -> new_pos.
  """
  enum_type = key.enum_type
  if is_packed:
    local_DecodeVarint = _DecodeVarint
    def DecodePackedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Packed payload: a varint byte length, then the values back to back.
      (endpoint, pos) = local_DecodeVarint(buffer, pos)
      endpoint += pos
      if endpoint > end:
        raise _DecodeError('Truncated message.')
      while pos < endpoint:
        value_start_pos = pos
        (element, pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          # Unknown values are recorded as individual (unpacked) varint
          # fields with this field number.
          tag_bytes = encoder.TagBytes(field_number,
                                       wire_format.WIRETYPE_VARINT)
          message._unknown_fields.append(
              (tag_bytes, buffer[value_start_pos:pos]))
      if pos > endpoint:
        # The last value ran past the packed payload.  Undo it in whichever
        # list it was appended to.
        if element in enum_type.values_by_number:
          del value[-1]  # Discard corrupt value.
        else:
          del message._unknown_fields[-1]
        raise _DecodeError('Packed element was truncated.')
      return pos
    return DecodePackedField
  elif is_repeated:
    tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
        if element in enum_type.values_by_number:
          value.append(element)
        else:
          if not message._unknown_fields:
            message._unknown_fields = []
          message._unknown_fields.append(
              (tag_bytes, buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated
        # field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
          # Prediction failed.  Return.
          if new_pos > end:
            raise _DecodeError('Truncated message.')
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value_start_pos = pos
      (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
      # Bounds are checked before storing, so a truncated value never
      # reaches field_dict.
      if pos > end:
        raise _DecodeError('Truncated message.')
      if enum_value in enum_type.values_by_number:
        field_dict[key] = enum_value
      else:
        if not message._unknown_fields:
          message._unknown_fields = []
        tag_bytes = encoder.TagBytes(field_number,
                                     wire_format.WIRETYPE_VARINT)
        message._unknown_fields.append(
            (tag_bytes, buffer[value_start_pos:pos]))
      return pos
    return DecodeField
427
428
429# --------------------------------------------------------------------
430
431
# Decoder constructors for the scalar field types.  Each one has the
# MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
# signature described in the module docstring.

# int32/int64 use the signed varint decoders: a decoded value above
# 0x7fffffffffffffff is interpreted as negative.
Int32Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)

Int64Decoder = _SimpleDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)

UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)

# sint32/sint64: read an unsigned varint, then undo the ZigZag encoding.
SInt32Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)

# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
# float/double need special handling of NaN/inf (see their factories).
FloatDecoder = _FloatDecoder()
DoubleDecoder = _DoubleDecoder()

# bool is a varint passed through bool(), so any nonzero value is True.
BoolDecoder = _ModifiedDecoder(
    wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
459
460
def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a string field.

  Each value on the wire is a varint byte length followed by that many bytes.
  The bytes are decoded as UTF-8 into six.text_type.

  The returned decoder raises:
    UnicodeDecodeError: if the bytes are not valid UTF-8.  Its `reason` is
      rewritten to include the field's full name.
    _DecodeError: if a string extends past `end`.
  """

  local_DecodeVarint = _DecodeVarint
  local_unicode = six.text_type

  def _ConvertToUnicode(byte_str):
    try:
      return local_unicode(byte_str, 'utf-8')
    except UnicodeDecodeError as e:
      # add more information to the error message and re-raise it.
      e.reason = '%s in field: %s' % (e, key.full_name)
      raise

  # Length-delimited fields can never be packed.
  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated string.')
        value.append(_ConvertToUnicode(buffer[pos:new_pos]))
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
      return new_pos
    return DecodeField
505
506
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a bytes field.

  Each value on the wire is a varint byte length followed by that many raw
  bytes.  The bytes are stored as a slice of `buffer`, without decoding.
  The returned decoder raises _DecodeError('Truncated string.') if a value
  extends past `end`.
  """

  local_DecodeVarint = _DecodeVarint

  # Length-delimited fields can never be packed.
  assert not is_packed

  if not is_repeated:
    def DecodeField(buffer, pos, end, message, field_dict):
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      field_dict[key] = buffer[start:stop]
      return stop
    return DecodeField

  tag_bytes = encoder.TagBytes(field_number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)

  def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    items = field_dict.get(key)
    if items is None:
      items = field_dict.setdefault(key, new_default(message))
    while 1:
      (length, start) = local_DecodeVarint(buffer, pos)
      stop = start + length
      if stop > end:
        raise _DecodeError('Truncated string.')
      items.append(buffer[start:stop])
      # Guess that the next element of this same field follows immediately.
      # Stop at the end of the message, or when the next tag differs.
      pos = stop + tag_len
      if stop == end or buffer[stop:pos] != tag_bytes:
        return stop
  return DecodeRepeatedField
542
543
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a group field.

  A group is a sub-message delimited by a START_GROUP tag (already consumed
  by the caller) and a matching END_GROUP tag for the same field number.

  Args:
    field_number: The group's field number; it determines the end tag.
    is_repeated: Whether the field is repeated.
    is_packed: Must be False; groups cannot be packed.
    key: The key for the field within field_dict.
    new_default: Callable taking the message and returning a new default
      value (the container for repeated fields, the sub-message otherwise).

  Returns:
    A decoder (buffer, pos, end, message, field_dict) -> new_pos.  It raises
    _DecodeError('Missing group end tag.') if the matching end tag does not
    follow the sub-message within `end`.
  """

  end_tag_bytes = encoder.TagBytes(field_number,
                                   wire_format.WIRETYPE_END_GROUP)
  end_tag_len = len(end_tag_bytes)

  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_START_GROUP)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      # Fetch (or create) the repeated container once.  The loop below only
      # parses into newly added sub-messages and never rebinds
      # field_dict[key], so looking it up again per element is unnecessary.
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read sub-message.
        pos = value.add()._InternalParse(buffer, pos, end)
        # Read end tag.
        new_pos = pos + end_tag_len
        if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
          raise _DecodeError('Missing group end tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read sub-message.
      pos = value._InternalParse(buffer, pos, end)
      # Read end tag.
      new_pos = pos + end_tag_len
      if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
        raise _DecodeError('Missing group end tag.')
      return new_pos
    return DecodeField
589
590
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
  """Returns a decoder for a message field.

  Each value on the wire is a varint byte length followed by the serialized
  sub-message, which is parsed with the sub-message's _InternalParse.  For
  a singular field, a sub-message already present in field_dict is parsed
  into (merged) rather than replaced.

  The returned decoder raises _DecodeError if a sub-message extends past
  `end`, or if parsing stops before the declared length (an unexpected
  end-group tag).
  """

  local_DecodeVarint = _DecodeVarint

  # Length-delimited fields can never be packed.
  assert not is_packed
  if is_repeated:
    tag_bytes = encoder.TagBytes(field_number,
                                 wire_format.WIRETYPE_LENGTH_DELIMITED)
    tag_len = len(tag_bytes)
    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      while 1:
        # Read length.
        (size, pos) = local_DecodeVarint(buffer, pos)
        new_pos = pos + size
        if new_pos > end:
          raise _DecodeError('Truncated message.')
        # Read sub-message.
        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
          # The only reason _InternalParse would return early is if it
          # encountered an end-group tag.
          raise _DecodeError('Unexpected end-group tag.')
        # Predict that the next tag is another copy of the same repeated field.
        pos = new_pos + tag_len
        if buffer[new_pos:pos] != tag_bytes or new_pos == end:
          # Prediction failed.  Return.
          return new_pos
    return DecodeRepeatedField
  else:
    def DecodeField(buffer, pos, end, message, field_dict):
      value = field_dict.get(key)
      if value is None:
        value = field_dict.setdefault(key, new_default(message))
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      if value._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
      return new_pos
    return DecodeField
639
640
641# --------------------------------------------------------------------
642
# Tag that opens each Item group (field 1, START_GROUP) in a MessageSet.
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)

def MessageSetItemDecoder(extensions_by_number):
  """Returns a decoder for a MessageSet item.

  The parameter is the _extensions_by_number map for the message class.

  The message set message looks like this:
    message MessageSet {
      repeated group Item = 1 {
        required int32 type_id = 2;
        required string message = 3;
      }
    }

  Args:
    extensions_by_number: Mapping from type_id to the registered extension.
      The extension must have a message_type attribute.

  Returns:
    A decoder DecodeItem(buffer, pos, end, message, field_dict) -> new_pos.
    An item whose type_id has a registered extension is parsed into that
    extension's entry in field_dict.  Any other item is kept verbatim in
    message._unknown_fields.  The decoder raises _DecodeError for a missing
    type_id or payload, a missing end tag, or truncation.
  """

  type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
  message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
  item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)

  # Bind module-level helpers to locals so DecodeItem avoids global lookups.
  local_ReadTag = ReadTag
  local_DecodeVarint = _DecodeVarint
  local_SkipField = SkipField

  def DecodeItem(buffer, pos, end, message, field_dict):
    message_set_item_start = pos
    type_id = -1
    message_start = -1
    message_end = -1

    # Technically, type_id and message can appear in any order, so we need
    # a little loop here.
    while 1:
      (tag_bytes, pos) = local_ReadTag(buffer, pos)
      if tag_bytes == type_id_tag_bytes:
        (type_id, pos) = local_DecodeVarint(buffer, pos)
      elif tag_bytes == message_tag_bytes:
        # Only record where the payload lies.  It is parsed once the
        # type_id is known.
        (size, message_start) = local_DecodeVarint(buffer, pos)
        pos = message_end = message_start + size
      elif tag_bytes == item_end_tag_bytes:
        break
      else:
        pos = local_SkipField(buffer, pos, end, tag_bytes)
        if pos == -1:
          raise _DecodeError('Missing group end tag.')

    if pos > end:
      raise _DecodeError('Truncated message.')

    if type_id == -1:
      raise _DecodeError('MessageSet item missing type_id.')
    if message_start == -1:
      raise _DecodeError('MessageSet item missing message.')

    extension = extensions_by_number.get(type_id)
    if extension is not None:
      value = field_dict.get(extension)
      if value is None:
        value = field_dict.setdefault(
            extension, extension.message_type._concrete_class())
      if value._InternalParse(buffer, message_start, message_end) != message_end:
        # The only reason _InternalParse would return early is if it encountered
        # an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')
    else:
      # Unknown type_id: keep the raw bytes of the whole item.
      if not message._unknown_fields:
        message._unknown_fields = []
      message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
                                      buffer[message_set_item_start:pos]))

    return pos

  return DecodeItem
716
717# --------------------------------------------------------------------
718
def MapDecoder(field_descriptor, new_default, is_message_map):
  """Returns a decoder for a map field.

  Each map entry arrives as a length-delimited entry message with `key` and
  `value` fields.  Consecutive entries for the same field are decoded in a
  single call.

  Args:
    field_descriptor: The map field's descriptor.  Its number gives the tag,
      and its message_type is the entry message type.
    new_default: Callable taking the message and returning a new map container.
    is_message_map: True if map values are messages.  Such values are merged
      into value[key] with MergeFrom; otherwise value[key] is assigned.

  Returns:
    A decoder (buffer, pos, end, message, field_dict) -> new_pos.
  """

  key = field_descriptor
  tag_bytes = encoder.TagBytes(field_descriptor.number,
                               wire_format.WIRETYPE_LENGTH_DELIMITED)
  tag_len = len(tag_bytes)
  local_DecodeVarint = _DecodeVarint
  # Can't read _concrete_class yet; might not be initialized.
  message_type = field_descriptor.message_type

  def DecodeMap(buffer, pos, end, message, field_dict):
    # One scratch entry message, cleared and reused for every entry.
    submsg = message_type._concrete_class()
    value = field_dict.get(key)
    if value is None:
      value = field_dict.setdefault(key, new_default(message))
    while 1:
      # Read length.
      (size, pos) = local_DecodeVarint(buffer, pos)
      new_pos = pos + size
      if new_pos > end:
        raise _DecodeError('Truncated message.')
      # Read sub-message.
      submsg.Clear()
      if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise _DecodeError('Unexpected end-group tag.')

      if is_message_map:
        value[submsg.key].MergeFrom(submsg.value)
      else:
        value[submsg.key] = submsg.value

      # Predict that the next tag is another copy of the same repeated field.
      pos = new_pos + tag_len
      if buffer[new_pos:pos] != tag_bytes or new_pos == end:
        # Prediction failed.  Return.
        return new_pos

  return DecodeMap
760
761# --------------------------------------------------------------------
762# Optimization is not as heavy here because calls to SkipField() are rare,
763# except for handling end-group tags.
764
def _SkipVarint(buffer, pos, end):
  """Advance past one varint starting at `pos` and return the position after it."""
  # One-byte slices are used instead of indexing.  Past the end of buffer
  # the slice is empty, and ord(b'') raises TypeError (indexing would raise
  # IndexError).  python_message.py turns either one into a
  # 'Truncated message' error.
  cursor = pos
  while ord(buffer[cursor:cursor + 1]) & 0x80:
    cursor += 1
  cursor += 1
  if cursor <= end:
    return cursor
  raise _DecodeError('Truncated message.')
776
def _SkipFixed64(buffer, pos, end):
  """Advance past an 8-byte fixed-width value and return the new position."""
  after = pos + 8
  if after <= end:
    return after
  raise _DecodeError('Truncated message.')
784
def _SkipLengthDelimited(buffer, pos, end):
  """Advance past a varint length prefix and its payload; return the new position."""
  (length, payload_start) = _DecodeVarint(buffer, pos)
  after = payload_start + length
  if after <= end:
    return after
  raise _DecodeError('Truncated message.')
793
def _SkipGroup(buffer, pos, end):
  """Skip the fields of a group up to and including its end-group tag.

  Returns:
    The position just after the end-group tag.
  """
  cursor = pos
  while 1:
    (tag, cursor) = ReadTag(buffer, cursor)
    after_field = SkipField(buffer, cursor, end, tag)
    if after_field == -1:
      # SkipField signals an end-group tag with -1.  ReadTag has already
      # consumed that tag, so `cursor` is just past it.
      return cursor
    cursor = after_field
803
def _EndGroup(buffer, pos, end):
  """Skipper for END_GROUP tags: return -1 so the enclosing loop stops.

  The arguments are ignored; they exist only so the signature matches the
  other skip functions.
  """
  end_of_group_marker = -1
  return end_of_group_marker
808
def _SkipFixed32(buffer, pos, end):
  """Advance past a 4-byte fixed-width value and return the new position."""
  after = pos + 4
  if after <= end:
    return after
  raise _DecodeError('Truncated message.')
816
def _RaiseInvalidWireType(buffer, pos, end):
  """Skipper for wire types with no defined encoding; always raises.

  The arguments are unused and are accepted only so the signature matches
  the other skip functions.

  Raises:
    _DecodeError: unconditionally.
  """
  error_text = 'Tag had invalid wire type.'
  raise _DecodeError(error_text)
821
def _FieldSkipper():
  """Build the SkipField function, which dispatches on a tag's wire type."""

  # Position i holds the skipper for wire type i.  The last two entries have
  # no defined encoding and raise.
  skippers_by_wire_type = (
      _SkipVarint,            # 0: varint
      _SkipFixed64,           # 1: 64-bit fixed
      _SkipLengthDelimited,   # 2: length-delimited
      _SkipGroup,             # 3: start group
      _EndGroup,              # 4: end group
      _SkipFixed32,           # 5: 32-bit fixed
      _RaiseInvalidWireType,  # 6: invalid
      _RaiseInvalidWireType,  # 7: invalid
  )
  type_mask = wire_format.TAG_TYPE_MASK

  def SkipField(buffer, pos, end, tag_bytes):
    """Skip the value of the field whose tag is `tag_bytes`.

    |pos| must point to the byte immediately after the tag.

    Returns:
      The position after the field's value, or -1 if the tag is an end-group
      tag (the calling loop should then break).
    """
    # Varints are little-endian, so the wire type bits are in the first byte.
    first_byte = ord(tag_bytes[0:1])
    return skippers_by_wire_type[first_byte & type_mask](buffer, pos, end)

  return SkipField

SkipField = _FieldSkipper()