blob: 87f60666ab64012b8aeb51f7f1a7c76a916d16c2 [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
# This code is meant to work on both Python 2 and Python 3; version
# differences are bridged with the six compatibility library.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
"""Contains a metaclass and helper functions used to create protocol message
classes from Descriptor objects at runtime.

Recall that a metaclass is the "type" of a class: a class is to its metaclass
what an instance is to its class.

Here the GeneratedProtocolMessageType metaclass injects all of the useful
functionality into the classes that the protocol compiler emits.

The upshot is that the real implementation details for ALL pure-Python
protocol buffers live *here, in this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
54import sys
55import struct
56import weakref
57
58import six
59import six.moves.copyreg as copyreg
60
61# We use "as" to avoid name collisions with variables.
62from google.protobuf.internal import containers
63from google.protobuf.internal import decoder
64from google.protobuf.internal import encoder
65from google.protobuf.internal import enum_type_wrapper
66from google.protobuf.internal import message_listener as message_listener_mod
67from google.protobuf.internal import type_checkers
68from google.protobuf.internal import well_known_types
69from google.protobuf.internal import wire_format
70from google.protobuf import descriptor as descriptor_mod
71from google.protobuf import message as message_mod
72from google.protobuf import symbol_database
73from google.protobuf import text_format
74
# Module-wide shorthand for the FieldDescriptor class.
_FieldDescriptor = descriptor_mod.FieldDescriptor
# Fully qualified name of the google.protobuf.Any well-known type.
_AnyFullTypeName = 'google.protobuf.Any'
77
78
class GeneratedProtocolMessageType(type):

  """Metaclass for protocol message classes created at runtime from Descriptors.

  We add implementations for all methods described in the Message class. We
  also create properties to allow getting/setting all fields in the protocol
  message. Finally, we create slots to prevent users from accidentally
  "setting" nonexistent fields in the protocol message, which then wouldn't get
  serialized / deserialized properly.

  The protocol compiler currently uses this metaclass to create protocol
  message classes at runtime. Clients can also manually create their own
  classes at runtime, as in this example (Python 2 syntax; on Python 3 supply
  the metaclass with six.with_metaclass() or the ``metaclass=`` class keyword
  instead of a ``__metaclass__`` attribute):

    mydescriptor = Descriptor(.....)
    class MyProtoClass(Message):
      __metaclass__ = GeneratedProtocolMessageType
      DESCRIPTOR = mydescriptor
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...

  The above example will not work for nested types. If you wish to include them,
  use reflection.MakeClass() instead of manually instantiating the class in
  order to create the appropriate class structure.
  """

  # Must be consistent with the protocol-compiler code in
  # proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Custom allocation for runtime-generated class types.

    __slots__ only takes effect if it is present in the class dictionary when
    the class object is created, so it is injected here rather than in
    __init__. Nested extensions are also exposed as class attributes here.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). If the descriptor's full name is a key
        of well_known_types.WKTBASES, the matching helper base class is
        appended to these bases.
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.

    Returns:
      Newly-allocated class.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    if descriptor.full_name in well_known_types.WKTBASES:
      bases += (well_known_types.WKTBASES[descriptor.full_name],)
    _AddClassAttributesForNestedExtensions(descriptor, dictionary)
    _AddSlots(descriptor, dictionary)

    superclass = super(GeneratedProtocolMessageType, cls)
    new_class = superclass.__new__(cls, name, bases, dictionary)
    return new_class

  def __init__(cls, name, bases, dictionary):
    """Here we perform the majority of our work on the class.

    We add enum getters, an __init__ method, implementations
    of all Message methods, and properties for all fields
    in the protocol type. Per-class lookup tables (_decoders_by_tag,
    _extensions_by_name, _extensions_by_number) are created, and the
    descriptor is linked back to the class via descriptor._concrete_class.

    Args:
      name: Name of the class (ignored, but required by the
        metaclass protocol).
      bases: Base classes of the class we're constructing.
        (Should be message.Message). We ignore this field, but
        it's required by the metaclass protocol
      dictionary: The class dictionary of the class we're
        constructing. dictionary[_DESCRIPTOR_KEY] must contain
        a Descriptor object describing this protocol message
        type.
    """
    descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    cls._decoders_by_tag = {}
    cls._extensions_by_name = {}
    cls._extensions_by_number = {}
    # Messages using the MessageSet wire format decode their items with a
    # dedicated decoder registered under MESSAGE_SET_ITEM_TAG.
    if (descriptor.has_options and
        descriptor.GetOptions().message_set_wire_format):
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(cls._extensions_by_number), None)

    # Attach stuff to each FieldDescriptor for quick lookup later on.
    for field in descriptor.fields:
      _AttachFieldHelpers(cls, field)

    descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(descriptor, cls)
    _AddInitMethod(descriptor, cls)
    _AddPropertiesForFields(descriptor, cls)
    _AddPropertiesForExtensions(descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(descriptor, cls)
    _AddPrivateHelperMethods(descriptor, cls)
    # Pickling support: an instance is rebuilt by calling cls() with no
    # arguments and handing back the state produced by obj.__getstate__().
    copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))

    superclass = super(GeneratedProtocolMessageType, cls)
    superclass.__init__(name, bases, dictionary)
183
184
185# Stateless helpers for GeneratedProtocolMessageType below.
186# Outside clients should not access these directly.
187#
188# I opted not to make any of these methods on the metaclass, to make it more
189# clear that I'm not really using any state there and to keep clients from
190# thinking that they have direct access to these construction helpers.
191
192
def _PropertyName(proto_field_name):
  """Maps a .proto field name to the Python attribute that exposes it.

  The mapping is currently the identity: the property has exactly the same
  name as the field.

  Args:
    proto_field_name: The protocol message field name, exactly
      as it appears (or would appear) in a .proto file.

  Returns:
    The attribute name clients use to get (and, for some fields, set) the
    field's value.
  """
  # Python keywords (e.g. "yield") are deliberately NOT escaped. People rely
  # on getattr()/setattr() to manipulate field values reflectively; renaming
  # properties (say, by appending "_") would force every such caller to apply
  # the same transformation. A field named "yield" is still reachable through
  # getattr/setattr without much trouble.
  # TODO(kenton): Remove this method entirely if/when everyone agrees with
  # this position.
  property_name = proto_field_name
  return property_name
220
221
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that extension_handle is a valid extension of message's type.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: Expected to be an extension FieldDescriptor whose
      containing_type is message.DESCRIPTOR.

  Raises:
    KeyError: If extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a different message type.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  containing_type = extension_handle.containing_type
  if not containing_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  expected_type = message.DESCRIPTOR
  if containing_type is not expected_type:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    containing_type.full_name,
                    expected_type.full_name))
242
243
def _AddSlots(message_descriptor, dictionary):
  """Stores the __slots__ declaration for a message class being created.

  Every message class gets the same slot list, so message_descriptor is
  currently unused.

  Args:
    message_descriptor: A Descriptor instance describing this message type.
    dictionary: Class dictionary to which we'll add a '__slots__' entry.
  """
  # Per-instance state of generated messages. '__weakref__' is listed so that
  # instances stay weak-referenceable even though they have __slots__.
  slot_names = [
      '_cached_byte_size',
      '_cached_byte_size_dirty',
      '_fields',
      '_unknown_fields',
      '_is_present_in_parent',
      '_listener',
      '_listener_for_children',
      '__weakref__',
      '_oneofs',
  ]
  dictionary['__slots__'] = slot_names
261
262
def _IsMessageSetExtension(field):
  """Returns a truthy value if field is an optional message-typed extension
  of a message that uses message_set_wire_format."""
  # The operands are evaluated left to right and stop at the first falsy one,
  # so the container's options are only read for extensions whose container
  # actually has options.
  return (field.is_extension
          and field.containing_type.has_options
          and field.containing_type.GetOptions().message_set_wire_format
          and field.type == _FieldDescriptor.TYPE_MESSAGE
          and field.label == _FieldDescriptor.LABEL_OPTIONAL)
269
270
def _IsMapField(field):
  """Returns a truthy value if field is a map field.

  A map field is a message-typed field whose message type has the map_entry
  option set.
  """
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
275
276
def _IsMessageMapField(field):
  """Returns True if the map field's 'value' entry is message-typed."""
  value_field = field.message_type.fields_by_name['value']
  return value_field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
280
281
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches encoding and decoding helpers for one field.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor, and
  registers a decoder in cls._decoders_by_tag for the field's natural wire
  type. Packable repeated fields get a second, packed decoder.

  Args:
    cls: The message class being constructed; cls._decoders_by_tag must
      already exist.
    field_descriptor: FieldDescriptor of a field (or extension) of cls.
  """
  number = field_descriptor.number
  field_type = field_descriptor.type
  is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
  is_packable = is_repeated and wire_format.IsTypePackable(field_type)

  # Decide whether values are written packed. proto2 packs only on an
  # explicit [packed=true]; any other syntax packs packable fields unless
  # they explicitly say [packed=false].
  if not is_packable:
    is_packed = False
  elif field_descriptor.containing_type.syntax == "proto2":
    is_packed = (field_descriptor.has_options and
                 field_descriptor.GetOptions().packed)
  else:
    explicitly_unpacked = (
        field_descriptor.has_options and
        field_descriptor.GetOptions().HasField("packed") and
        not field_descriptor.GetOptions().packed)
    is_packed = not explicitly_unpacked

  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor)
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(number)
    sizer = encoder.MessageSetItemSizer(number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_type](
        number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_type](
        number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def _RegisterDecoder(wiretype, packed_decoder):
    # Builds a decoder for this field and stores it under the tag bytes for
    # (field number, wiretype).
    tag_bytes = encoder.TagBytes(number, wiretype)

    # Enums for which type_checkers.SupportsOpenEnums() is true are decoded
    # as plain int32 values.
    decode_type = field_type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      decode_type = _FieldDescriptor.TYPE_INT32

    # For members of a oneof, the decoder table entry carries the field's
    # own descriptor (not the OneofDescriptor); otherwise None.
    oneof_descriptor = (field_descriptor
                        if field_descriptor.containing_oneof is not None
                        else None)

    if is_map_entry:
      is_message_map = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          is_message_map)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          number, is_repeated, packed_decoder,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)

  # The field's natural wire type is always registered with an unpacked
  # decoder.
  _RegisterDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type], False)

  if is_packable:
    # To support wire compatibility of adding packed = true, add a decoder for
    # packed values regardless of the field's options.
    _RegisterDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
346
347
def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
  """Exposes the extensions declared inside a message as class attributes.

  Args:
    descriptor: Descriptor of the message type being created.
    dictionary: Class dictionary under construction; each extension's
      FieldDescriptor is stored under its key in
      descriptor.extensions_by_name.
  """
  for name, extension in descriptor.extensions_by_name.items():
    # An extension must never clobber an existing class attribute.
    assert name not in dictionary
    dictionary[name] = extension
353
354
def _AddEnumValues(descriptor, cls):
  """Publishes the enums nested in a message type on its class.

  For each nested enum type, cls receives an EnumTypeWrapper under the enum's
  name, plus one attribute per enum value holding that value's number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value in enum_type.values:
      setattr(cls, value.name, value.number)
368
369
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds the empty container for a map field.

  Args:
    field: FieldDescriptor of a map field; its message type must have 'key'
      and 'value' fields.

  Returns:
    A function taking the owning message and returning a new
    containers.MessageMap (message values) or containers.ScalarMap (scalar
    values), constructed with message._listener_for_children.

  Raises:
    ValueError: If field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker)
  return MakeMessageMapDefault
389
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field. The default
  value may refer back to |message| via a weak reference.

  Raises:
    ValueError: If a repeated field declares a non-empty default value.
  """

  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # We can't look at _concrete_class yet since it might not have been set
      # (it depends on the order in which we initialize the classes), so only
      # the message descriptor is handed to the container.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is only looked up
    # when a default sub-message is actually built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      # Members of a oneof are wired to a _OneofListener; all other
      # sub-messages share the parent's _listener_for_children.
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
443
444
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception being handled, tagged with the field name.

  Must be called from inside an ``except`` block. A plain TypeError with a
  single argument is replaced by a new TypeError whose message ends with
  " for field <message_name>.<field_name>"; any other exception is re-raised
  as is. In both cases the original traceback is kept.

  Args:
    message_name: Name of the message type that owns the field.
    field_name: Name of the field whose assignment failed.
  """
  _, exc, tb = sys.exc_info()
  if len(exc.args) == 1 and type(exc) is TypeError:
    # Simple TypeError: mention which field the bad value was meant for.
    exc = TypeError('%s for field %s.%s' % (str(exc), message_name, field_name))

  # Re-raise the (possibly amended) exception with the original traceback.
  six.reraise(type(exc), exc, tb)
454
455
def _AddInitMethod(message_descriptor, cls):
  """Installs the generated __init__ on cls.

  The generated initializer resets per-instance bookkeeping and then applies
  each keyword argument to the field of the same name:
    * repeated scalar fields are extended from an iterable (enum labels may
      be given as strings);
    * repeated message fields take message instances or dicts of field values;
    * map fields take a mapping;
    * singular message fields take a message instance or a dict;
    * singular scalar fields go through the normal property setter (enum
      labels may be given as strings).

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class being constructed.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    Strings are looked up by label in enum_type; anything else is returned
    unchanged (no conversion or bounds-checking is done).

    Raises:
      ValueError: If value is a string that is not a label of enum_type.
    """
    if not isinstance(value, six.string_types):
      return value
    try:
      return enum_type.values_by_name[value].number
    except KeyError:
      raise ValueError('Enum type %s: unknown label "%s"' % (
          enum_type.full_name, value))

  def init(self, **kwargs):
    self._cached_byte_size = 0
    # Any field values passed in make the cached byte size stale.
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Maps each oneof descriptor to the descriptor of the field currently
    # set in that oneof.
    self._oneofs = {}

    # _unknown_fields is () while empty, for efficiency; it is turned into a
    # list once fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)

    for field_name, field_value in kwargs.items():
      field = _GetFieldByName(message_descriptor, field_name)
      # NOTE(review): _GetFieldByName raises ValueError for unknown names
      # instead of returning None, so this branch is currently unreachable.
      if field is None:
        raise TypeError("%s() got an unexpected keyword argument '%s'" %
                        (message_descriptor.name, field_name))

      is_composite = field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE

      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if not is_composite:
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          container.extend(field_value)
        elif not _IsMapField(field):
          for val in field_value:
            if isinstance(val, dict):
              container.add(**val)
            else:
              container.add().MergeFrom(val)
        elif _IsMessageMapField(field):
          for key in field_value:
            container[key].MergeFrom(field_value[key])
        else:
          container.update(field_value)
        self._fields[field] = container
      elif is_composite:
        container = field._default_constructor(self)
        source = field_value
        if isinstance(field_value, dict):
          source = field.message_type._concrete_class(**field_value)
        try:
          container.MergeFrom(source)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = container
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
535
536
def _GetFieldByName(message_descriptor, field_name):
  """Looks up one field of a message type by name.

  Args:
    message_descriptor: A Descriptor describing all fields in message.
    field_name: The name of the field to retrieve.

  Returns:
    The FieldDescriptor stored under field_name.

  Raises:
    ValueError: If the message type has no field with that name.
  """
  fields = message_descriptor.fields_by_name
  try:
    found = fields[field_name]
  except KeyError:
    raise ValueError('Protocol message %s has no "%s" field.' %
                     (message_descriptor.name, field_name))
  return found
551
552
def _AddPropertiesForFields(descriptor, cls):
  """Adds one property per field, plus an Extensions property if the message
  type is extendable."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if not descriptor.is_extendable:
    return

  def _GetExtensionDict(self):
    # _ExtensionDict is a stateless adaptor, so a fresh one is built on every
    # access instead of being cached.
    return _ExtensionDict(self)

  cls.Extensions = property(_GetExtensionDict)
562
563
def _AddPropertiesForField(field, cls):
  """Adds the public property (and FIELD_NUMBER constant) for one field.

  The property lets clients read the field, and, for non-repeated scalar
  fields only, assign it directly. cls also gets a <NAME>_FIELD_NUMBER class
  attribute holding the field number.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, field.name.upper() + "_FIELD_NUMBER", field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
587
588
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a read-only property for a repeated field.

  Reading the property returns the field's container, built on first access
  by field._default_constructor. Containers created by this module include
  repeated scalar, repeated composite and map containers. Any type-checking
  and "has"-bit bookkeeping happens inside the container when values are
  added. Assigning to the property raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    current = self._fields.get(field)
    if current is not None:
      return current
    # Build the container lazily. setdefault() stores our new object only if
    # no other thread stored one first; otherwise the existing object wins
    # and ours is discarded.
    # WARNING: This relies on dict.setdefault() being atomic, which holds in
    # CPython but has not been investigated for other implementations. The
    # same assumption is made in several other places in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to replace Python's generic error with a clearer message.
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
631
632
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the field.
  Setting a value type-checks it (type_checkers.GetTypeChecker(field)) and
  marks the message as modified. For proto3 fields outside a oneof, assigning
  a falsy (default) value removes the field from _fields instead of storing
  it. For fields inside a oneof, the oneof state is updated after every
  assignment.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # proto3 singular scalars outside a oneof have no presence, so storing the
  # default value is equivalent to clearing the field.
  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, False, and empty strings/bytes).
    new_value = type_checker.CheckValue(new_value)
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
687
688
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a read-only property for a singular "group" or "message" field.

  Reading the property returns the sub-message, built on first access by
  field._default_constructor; assigning to it raises AttributeError.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    current = self._fields.get(field)
    if current is not None:
      return current
    # Build the sub-message lazily. setdefault() stores our new object only
    # if no other thread stored one first; otherwise the existing object wins
    # and ours is discarded.
    # WARNING: This relies on dict.setdefault() being atomic, which holds in
    # CPython but has not been investigated for other implementations. The
    # same assumption is made in several other places in this file.
    return self._fields.setdefault(field, field._default_constructor(self))
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to replace Python's generic error with a clearer message.
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
731
732
def _AddPropertiesForExtensions(descriptor, cls):
  """Adds a <NAME>_FIELD_NUMBER class constant for each nested extension.

  Despite its name, no properties are created. For every extension declared
  inside this message type, cls gets a class attribute named after the
  upper-cased extension name plus "_FIELD_NUMBER", holding the extension's
  field number.

  Args:
    descriptor: Descriptor of the message type; its extensions_by_name maps
      extension names to extension FieldDescriptors.
    cls: The class being constructed.
  """
  for extension_name, extension_field in (
      descriptor.extensions_by_name.items()):
    constant_name = extension_name.upper() + "_FIELD_NUMBER"
    setattr(cls, constant_name, extension_field.number)
739
740
def _AddStaticMethods(cls):
  """Installs the RegisterExtension() and FromString() static methods."""

  # TODO(robinson): This probably needs to be thread-safe(?)
  def RegisterExtension(extension_handle):
    # Binds the extension to cls's message type and indexes it by number and
    # by full name (MessageSet extensions also by their message type name).
    extension_handle.containing_type = cls.DESCRIPTOR
    _AttachFieldHelpers(cls, extension_handle)

    # setdefault() inserts the handle only if its number is still free; if a
    # different extension already owns that number, registration fails.
    registered = cls._extensions_by_number.setdefault(
        extension_handle.number, extension_handle)
    if registered is not extension_handle:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension_handle.full_name, registered.full_name,
           cls.DESCRIPTOR.full_name, extension_handle.number))

    cls._extensions_by_name[extension_handle.full_name] = extension_handle

    if _IsMessageSetExtension(extension_handle):
      # MessageSet extension. Also register under type name.
      cls._extensions_by_name[
          extension_handle.message_type.full_name] = extension_handle

  def FromString(s):
    # Parses the serialized bytes s into a brand-new instance of cls.
    message = cls()
    message.MergeFromString(s)
    return message

  cls.RegisterExtension = staticmethod(RegisterExtension)
  cls.FromString = staticmethod(FromString)
773
774
def _IsPresent(item):
  """Decides whether a (FieldDescriptor, value) pair from _fields is listed
  by ListFields().

  Repeated fields count when non-empty; singular message fields count when
  the sub-message is marked present in its parent; any other stored value
  always counts.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
785
786
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs ListFields() on cls."""

  def ListFields(self):
    # Present (field, value) pairs, ordered by field number.
    present = (item for item in self._fields.items() if _IsPresent(item))
    return sorted(present, key=lambda item: item[0].number)

  cls.ListFields = ListFields
796
# Error-message templates for HasField() on a name it does not accept; the
# "%s" placeholder receives the requested field name.
_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
_Proto2HasError = 'Protocol message has no non-repeated field "%s"'
799
def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs HasField() on cls.

  HasField() accepts only names that track presence. In proto2 that is every
  non-repeated field plus every oneof name. In proto3 it is only sub-message
  fields and fields inside a oneof. Any other name raises ValueError.
  """
  is_proto3 = (message_descriptor.syntax == "proto3")
  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError

  def _TracksPresence(field):
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      return False
    # For proto3, only submessages and fields inside a oneof have presence.
    return (not is_proto3 or
            field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
            field.containing_oneof)

  hassable_fields = {field.name: field
                     for field in message_descriptor.fields
                     if _TracksPresence(field)}

  if not is_proto3:
    # Fields inside oneofs are never repeated (enforced by the compiler).
    for oneof in message_descriptor.oneofs:
      hassable_fields[oneof.name] = oneof

  def HasField(self, field_name):
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # Delegate to whichever member _oneofs records as set; no entry means
      # no member of the oneof is set.
      try:
        return HasField(self, self._oneofs[field].name)
      except KeyError:
        return False

    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return field in self._fields
    value = self._fields.get(field)
    return value is not None and value._is_present_in_parent

  cls.HasField = HasField
840
841
842def _AddClearFieldMethod(message_descriptor, cls):
843 """Helper for _AddMessageMethods()."""
844 def ClearField(self, field_name):
845 try:
846 field = message_descriptor.fields_by_name[field_name]
847 except KeyError:
848 try:
849 field = message_descriptor.oneofs_by_name[field_name]
850 if field in self._oneofs:
851 field = self._oneofs[field]
852 else:
853 return
854 except KeyError:
855 raise ValueError('Protocol message %s() has no "%s" field.' %
856 (message_descriptor.name, field_name))
857
858 if field in self._fields:
859 # To match the C++ implementation, we need to invalidate iterators
860 # for map fields when ClearField() happens.
861 if hasattr(self._fields[field], 'InvalidateIterators'):
862 self._fields[field].InvalidateIterators()
863
864 # Note: If the field is a sub-message, its listener will still point
865 # at us. That's fine, because the worst than can happen is that it
866 # will call _Modified() and invalidate our byte size. Big deal.
867 del self._fields[field]
868
869 if self._oneofs.get(field.containing_oneof, None) is field:
870 del self._oneofs[field.containing_oneof]
871
872 # Always call _Modified() -- even if nothing was changed, this is
873 # a mutating method, and thus calling it should cause the field to become
874 # present in the parent message.
875 self._Modified()
876
877 cls.ClearField = ClearField
878
879
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs ClearExtension() on cls."""

  def ClearExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    # This follows ClearField(): drop any stored value, then always mark the
    # message as modified.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
890
891
892def _AddClearMethod(message_descriptor, cls):
893 """Helper for _AddMessageMethods()."""
894 def Clear(self):
895 # Clear fields.
896 self._fields = {}
897 self._unknown_fields = ()
898 self._oneofs = {}
899 self._Modified()
900 cls.Clear = Clear
901
902
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs HasExtension() on cls."""

  def HasExtension(self, extension_handle):
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      # Repeated extensions are rejected rather than reported on.
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
916
def _InternalUnpackAny(msg):
  """Unpacks an Any message and returns the message it wraps.

  This internal helper differs from the public Any Unpack method, which takes
  the target message as an argument. _InternalUnpackAny has no target message
  type, so it looks the type up in the default descriptor pool.

  Args:
    msg: An Any message to be unpacked.

  Returns:
    The unpacked message. Returns None if msg has no type_url or its type
    cannot be found in the default descriptor pool.
  """
  type_url = msg.type_url
  if not type_url:
    return None

  db = symbol_database.Default()

  # TODO(haberman): For now we just strip the hostname.  Better logic will be
  # required.
  type_name = type_url.split("/")[-1]
  try:
    descriptor = db.pool.FindMessageTypeByName(type_name)
  except KeyError:
    # NOTE(review): DescriptorPool.FindMessageTypeByName is documented to
    # raise KeyError for unknown names. Map that to None so callers such as
    # __eq__ can fall back to field-by-field comparison.
    descriptor = None

  if descriptor is None:
    return None

  message_class = db.GetPrototype(descriptor)
  message = message_class()
  message.ParseFromString(msg.value)
  return message
949
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __eq__ on cls."""

  def __eq__(self, other):
    # Only messages sharing the same descriptor can compare equal.
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False
    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      # Any messages compare by payload when both sides can be unpacked.
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if not self.ListFields() == other.ListFields():
      return False

    # Unknown fields are compared after sorting, so their order does not
    # affect equality.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
978
979
def _AddStrMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __str__ on cls.

  str(message) returns text_format.MessageToString(message).
  """

  def __str__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__str__ = __str__
985
986
def _AddReprMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __repr__ on cls.

  repr(message) returns text_format.MessageToString(message).
  """

  def __repr__(self):
    rendered = text_format.MessageToString(self)
    return rendered

  cls.__repr__ = __repr__
992
993
def _AddUnicodeMethod(unused_message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs __unicode__ on cls.

  The text format is rendered with as_utf8=True and then decoded from UTF-8.
  """

  def __unicode__(self):
    utf8_text = text_format.MessageToString(self, as_utf8=True)
    return utf8_text.decode('utf-8')

  cls.__unicode__ = __unicode__
1000
1001
1002def _AddSetListenerMethod(cls):
1003 """Helper for _AddMessageMethods()."""
1004 def SetListener(self, listener):
1005 if listener is None:
1006 self._listener = message_listener_mod.NullMessageListener()
1007 else:
1008 self._listener = listener
1009 cls._SetListener = SetListener
1010
1011
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any other
  additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value. (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field. One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The result of the size function that type_checkers.TYPE_TO_BYTE_SIZE_FN
    maps field_type to, called as fn(field_number, value).

  Raises:
    message_mod.EncodeError: if field_type has no registered size function.
  """
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  # Call outside the try, so a KeyError raised inside the size function is
  # not misreported as an unrecognized field type.
  return fn(field_number, value)
1030
1031
1032def _AddByteSizeMethod(message_descriptor, cls):
1033 """Helper for _AddMessageMethods()."""
1034
1035 def ByteSize(self):
1036 if not self._cached_byte_size_dirty:
1037 return self._cached_byte_size
1038
1039 size = 0
1040 for field_descriptor, field_value in self.ListFields():
1041 size += field_descriptor._sizer(field_value)
1042
1043 for tag_bytes, value_bytes in self._unknown_fields:
1044 size += len(tag_bytes) + len(value_bytes)
1045
1046 self._cached_byte_size = size
1047 self._cached_byte_size_dirty = False
1048 self._listener_for_children.dirty = False
1049 return size
1050
1051 cls.ByteSize = ByteSize
1052
1053
1054def _AddSerializeToStringMethod(message_descriptor, cls):
1055 """Helper for _AddMessageMethods()."""
1056
1057 def SerializeToString(self):
1058 # Check if the message has all of its required fields set.
1059 errors = []
1060 if not self.IsInitialized():
1061 raise message_mod.EncodeError(
1062 'Message %s is missing required fields: %s' % (
1063 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
1064 return self.SerializePartialToString()
1065 cls.SerializeToString = SerializeToString
1066
1067
1068def _AddSerializePartialToStringMethod(message_descriptor, cls):
1069 """Helper for _AddMessageMethods()."""
1070
1071 def SerializePartialToString(self):
1072 out = BytesIO()
1073 self._InternalSerialize(out.write)
1074 return out.getvalue()
1075 cls.SerializePartialToString = SerializePartialToString
1076
1077 def InternalSerialize(self, write_bytes):
1078 for field_descriptor, field_value in self.ListFields():
1079 field_descriptor._encoder(write_bytes, field_value)
1080 for tag_bytes, value_bytes in self._unknown_fields:
1081 write_bytes(tag_bytes)
1082 write_bytes(value_bytes)
1083 cls._InternalSerialize = InternalSerialize
1084
1085
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods().

  Installs MergeFromString() and the _InternalParse() routine it delegates to.
  """
  def MergeFromString(self, serialized):
    """Parses serialized bytes and merges the decoded fields into self.

    Returns:
      len(serialized). This is kept for legacy callers.

    Raises:
      message_mod.DecodeError: on an unexpected end-group tag, a truncated
        buffer (IndexError/TypeError while parsing), or a struct.error.
    """
    length = len(serialized)
    try:
      if self._InternalParse(serialized, 0, length) != length:
        # The only reason _InternalParse would return early is if it
        # encountered an end-group tag.
        raise message_mod.DecodeError('Unexpected end-group tag.')
    except (IndexError, TypeError):
      # Reading past the end of the buffer surfaces as IndexError, or as
      # TypeError: ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    return length   # Return this for legacy reasons.
  cls.MergeFromString = MergeFromString

  # Bind the decoder helpers and the per-class tag table once, at
  # class-construction time, so the parse loop below uses closure lookups.
  local_ReadTag = decoder.ReadTag
  local_SkipField = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag
  is_proto3 = message_descriptor.syntax == "proto3"

  def InternalParse(self, buffer, pos, end):
    """Parses buffer[pos:end] into self.

    Returns:
      The position where parsing stopped. This is end on success. Otherwise
      it is the position of a tag for which local_SkipField() returned -1
      (presumably an end-group tag, per the note in MergeFromString).
    """
    self._Modified()
    field_dict = self._fields
    unknown_field_list = self._unknown_fields
    while pos != end:
      (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is None:
        # Tag not known to this message class: skip over its value.
        value_start_pos = new_pos
        new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
        if new_pos == -1:
          return pos
        # proto2 keeps unknown fields as raw (tag, value) bytes.
        # proto3 silently drops them.
        if not is_proto3:
          # _unknown_fields may be an empty tuple (as set by Clear()), so a
          # real list is allocated lazily on the first unknown field.
          if not unknown_field_list:
            unknown_field_list = self._unknown_fields = []
          unknown_field_list.append(
              (tag_bytes, buffer[value_start_pos:new_pos]))
        pos = new_pos
      else:
        pos = field_decoder(buffer, new_pos, end, self, field_dict)
        # field_desc is set for tags of oneof members; record the member
        # that was just decoded as the active one.
        if field_desc:
          self._UpdateOneofState(field_desc)
    return pos
  cls._InternalParse = InternalParse
1132
1133
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationErrors methods to the
  protocol message class."""

  # Computed once per class: the descriptors of all LABEL_REQUIRED fields.
  required_fields = [field for field in message_descriptor.fields
                           if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields. On failure it is
               extended with the output of FindInitializationErrors().

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      if (field not in self._fields or
          (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
           not self._fields[field]._is_present_in_parent)):
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    # Recurse into set sub-messages. Iterate over a snapshot of the items.
    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        if field.label == _FieldDescriptor.LABEL_REPEATED:
          # NOTE(review): map fields (repeated map_entry messages) are skipped
          # entirely here, so message-valued map entries are not checked,
          # although FindInitializationErrors() below does inspect them.
          # Confirm this is intended.
          if (field.message_type.has_options and
              field.message_type.GetOptions().map_entry):
            continue
          for element in value:
            if not element.IsInitialized():
              if errors is not None:
                errors.extend(self.FindInitializationErrors())
              return False
        elif value._is_present_in_parent and not value.IsInitialized():
          if errors is not None:
            errors.extend(self.FindInitializationErrors())
          return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """

    errors = []  # simplify things

    for field in required_fields:
      if not self.HasField(field.name):
        errors.append(field.name)

    for field, value in self.ListFields():
      if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        # Extensions appear in paths as "(full.name)"; regular fields use
        # their short name.
        if field.is_extension:
          name = "(%s)" % field.full_name
        else:
          name = field.name

        if _IsMapField(field):
          if _IsMessageMapField(field):
            for key in value:
              element = value[key]
              prefix = "%s[%s]." % (name, key)
              sub_errors = element.FindInitializationErrors()
              errors += [prefix + error for error in sub_errors]
          else:
            # ScalarMaps can't have any initialization errors.
            pass
        elif field.label == _FieldDescriptor.LABEL_REPEATED:
          for i in range(len(value)):
            element = value[i]
            prefix = "%s[%d]." % (name, i)
            sub_errors = element.FindInitializationErrors()
            errors += [prefix + error for error in sub_errors]
        else:
          prefix = name + "."
          sub_errors = value.FindInitializationErrors()
          errors += [prefix + error for error in sub_errors]

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1227
1228
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs MergeFrom() on cls."""
  # Cached as closure locals for the merge loop below.
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def MergeFrom(self, msg):
    """Merges the fields of msg into this message.

    Repeated fields are merged through the container's MergeFrom(). Present
    sub-messages are merged recursively. Singular scalars are overwritten.
    msg's unknown fields are appended to ours.

    Raises:
      TypeError: if msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          "expected %s got %s." % (cls.__name__, type(msg).__name__))

    # NOTE(review): self-merge is rejected only by this assert, which is
    # stripped under python -O.
    assert msg is not self
    self._Modified()

    fields = self._fields

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        field_value = fields.get(field)
        if field_value is None:
          # Construct a new object to represent this field.
          field_value = field._default_constructor(self)
          fields[field] = field_value
        field_value.MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        # Sub-messages that are stored but not marked present are skipped.
        if value._is_present_in_parent:
          field_value = fields.get(field)
          if field_value is None:
            # Construct a new object to represent this field.
            field_value = field._default_constructor(self)
            fields[field] = field_value
          field_value.MergeFrom(value)
      else:
        self._fields[field] = value
        # NOTE(review): oneof state is updated explicitly only for scalars.
        # Sub-message oneof members presumably rely on their listener
        # (see _OneofListener) to do this when modified.
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      # _unknown_fields may be an empty tuple, so replace it with a list
      # before extending.
      if not self._unknown_fields:
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1271
1272
1273def _AddWhichOneofMethod(message_descriptor, cls):
1274 def WhichOneof(self, oneof_name):
1275 """Returns the name of the currently set field inside a oneof, or None."""
1276 try:
1277 field = message_descriptor.oneofs_by_name[oneof_name]
1278 except KeyError:
1279 raise ValueError(
1280 'Protocol message has no oneof "%s" field.' % oneof_name)
1281
1282 nested_field = self._oneofs.get(field, None)
1283 if nested_field is not None and self.HasField(nested_field.name):
1284 return nested_field.name
1285 else:
1286 return None
1287
1288 cls.WhichOneof = WhichOneof
1289
1290
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order. The extension-only methods are added
  only when the message type is extendable.
  """
  for add_method in (_AddListFieldsMethod,
                     _AddHasFieldMethod,
                     _AddClearFieldMethod):
    add_method(message_descriptor, cls)

  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)

  for add_method in (_AddClearMethod,
                     _AddEqualsMethod,
                     _AddStrMethod,
                     _AddReprMethod,
                     _AddUnicodeMethod):
    add_method(message_descriptor, cls)

  _AddSetListenerMethod(cls)

  for add_method in (_AddByteSizeMethod,
                     _AddSerializeToStringMethod,
                     _AddSerializePartialToStringMethod,
                     _AddMergeFromStringMethod,
                     _AddIsInitializedMethod):
    add_method(message_descriptor, cls)

  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
1312
1313
1314def _AddPrivateHelperMethods(message_descriptor, cls):
1315 """Adds implementation of private helper methods to cls."""
1316
1317 def Modified(self):
1318 """Sets the _cached_byte_size_dirty bit to true,
1319 and propagates this to our listener iff this was a state change.
1320 """
1321
1322 # Note: Some callers check _cached_byte_size_dirty before calling
1323 # _Modified() as an extra optimization. So, if this method is ever
1324 # changed such that it does stuff even when _cached_byte_size_dirty is
1325 # already true, the callers need to be updated.
1326 if not self._cached_byte_size_dirty:
1327 self._cached_byte_size_dirty = True
1328 self._listener_for_children.dirty = True
1329 self._is_present_in_parent = True
1330 self._listener.Modified()
1331
1332 def _UpdateOneofState(self, field):
1333 """Sets field as the active field in its containing oneof.
1334
1335 Will also delete currently active field in the oneof, if it is different
1336 from the argument. Does not mark the message as modified.
1337 """
1338 other_field = self._oneofs.setdefault(field.containing_oneof, field)
1339 if other_field is not field:
1340 del self._fields[other_field]
1341 self._oneofs[field.containing_oneof] = field
1342
1343 cls._Modified = Modified
1344 cls.SetInParent = Modified
1345 cls._UpdateOneofState = _UpdateOneofState
1346
1347
1348class _Listener(object):
1349
1350 """MessageListener implementation that a parent message registers with its
1351 child message.
1352
1353 In order to support semantics like:
1354
1355 foo.bar.baz.qux = 23
1356 assert foo.HasField('bar')
1357
1358 ...child objects must have back references to their parents.
1359 This helper class is at the heart of this support.
1360 """
1361
1362 def __init__(self, parent_message):
1363 """Args:
1364 parent_message: The message whose _Modified() method we should call when
1365 we receive Modified() messages.
1366 """
1367 # This listener establishes a back reference from a child (contained) object
1368 # to its parent (containing) object. We make this a weak reference to avoid
1369 # creating cyclic garbage when the client finishes with the 'parent' object
1370 # in the tree.
1371 if isinstance(parent_message, weakref.ProxyType):
1372 self._parent_message_weakref = parent_message
1373 else:
1374 self._parent_message_weakref = weakref.proxy(parent_message)
1375
1376 # As an optimization, we also indicate directly on the listener whether
1377 # or not the parent message is dirty. This way we can avoid traversing
1378 # up the tree in the common case.
1379 self.dirty = False
1380
1381 def Modified(self):
1382 if self.dirty:
1383 return
1384 try:
1385 # Propagate the signal to our parents iff this is the first field set.
1386 self._parent_message_weakref._Modified()
1387 except ReferenceError:
1388 # We can get here if a client has kept a reference to a child object,
1389 # and is now setting a field on it, but the child's parent has been
1390 # garbage-collected. This is not an error.
1391 pass
1392
1393
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Also updates the state of the containing oneof in the parent message."""
    try:
      parent = self._parent_message_weakref
      # Make our field the active oneof member before propagating the
      # ordinary modification signal.
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent has been garbage-collected; there is nothing to update.
      pass
1413
1414
1415# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
1416# TODO(robinson): Unify error handling of "unknown extension" crap.
1417# TODO(robinson): Support iteritems()-style iteration over all
1418# extensions with the "has" bits turned on?
1419class _ExtensionDict(object):
1420
1421 """Dict-like container for supporting an indexable "Extensions"
1422 field on proto instances.
1423
1424 Note that in all cases we expect extension handles to be
1425 FieldDescriptors.
1426 """
1427
1428 def __init__(self, extended_message):
1429 """extended_message: Message instance for which we are the Extensions dict.
1430 """
1431
1432 self._extended_message = extended_message
1433
1434 def __getitem__(self, extension_handle):
1435 """Returns the current value of the given extension handle."""
1436
1437 _VerifyExtensionHandle(self._extended_message, extension_handle)
1438
1439 result = self._extended_message._fields.get(extension_handle)
1440 if result is not None:
1441 return result
1442
1443 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1444 result = extension_handle._default_constructor(self._extended_message)
1445 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1446 result = extension_handle.message_type._concrete_class()
1447 try:
1448 result._SetListener(self._extended_message._listener_for_children)
1449 except ReferenceError:
1450 pass
1451 else:
1452 # Singular scalar -- just return the default without inserting into the
1453 # dict.
1454 return extension_handle.default_value
1455
1456 # Atomically check if another thread has preempted us and, if not, swap
1457 # in the new object we just created. If someone has preempted us, we
1458 # take that object and discard ours.
1459 # WARNING: We are relying on setdefault() being atomic. This is true
1460 # in CPython but we haven't investigated others. This warning appears
1461 # in several other locations in this file.
1462 result = self._extended_message._fields.setdefault(
1463 extension_handle, result)
1464
1465 return result
1466
1467 def __eq__(self, other):
1468 if not isinstance(other, self.__class__):
1469 return False
1470
1471 my_fields = self._extended_message.ListFields()
1472 other_fields = other._extended_message.ListFields()
1473
1474 # Get rid of non-extension fields.
1475 my_fields = [ field for field in my_fields if field.is_extension ]
1476 other_fields = [ field for field in other_fields if field.is_extension ]
1477
1478 return my_fields == other_fields
1479
1480 def __ne__(self, other):
1481 return not self == other
1482
1483 def __hash__(self):
1484 raise TypeError('unhashable object')
1485
1486 # Note that this is only meaningful for non-repeated, scalar extension
1487 # fields. Note also that we may have to call _Modified() when we do
1488 # successfully set a field this way, to set any necssary "has" bits in the
1489 # ancestors of the extended message.
1490 def __setitem__(self, extension_handle, value):
1491 """If extension_handle specifies a non-repeated, scalar extension
1492 field, sets the value of that field.
1493 """
1494
1495 _VerifyExtensionHandle(self._extended_message, extension_handle)
1496
1497 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1498 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1499 raise TypeError(
1500 'Cannot assign to extension "%s" because it is a repeated or '
1501 'composite type.' % extension_handle.full_name)
1502
1503 # It's slightly wasteful to lookup the type checker each time,
1504 # but we expect this to be a vanishingly uncommon case anyway.
1505 type_checker = type_checkers.GetTypeChecker(extension_handle)
1506 # pylint: disable=protected-access
1507 self._extended_message._fields[extension_handle] = (
1508 type_checker.CheckValue(value))
1509 self._extended_message._Modified()
1510
1511 def _FindExtensionByName(self, name):
1512 """Tries to find a known extension with the specified name.
1513
1514 Args:
1515 name: Extension full name.
1516
1517 Returns:
1518 Extension field descriptor.
1519 """
1520 return self._extended_message._extensions_by_name.get(name, None)