blob: 975e3b4d527fd41cd9f74c77a3804aa36b69194f [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31# This code is meant to work on Python 2.4 and above only.
32#
33# TODO(robinson): Helpers for verbose, common checks like seeing if a
34# descriptor's cpp_type is CPPTYPE_MESSAGE.
35
"""Metaclass and helpers that build protocol message classes at runtime.

A metaclass is the "type" of a class: a class relates to its metaclass the
same way an instance relates to its class.

GeneratedProtocolMessageType is the metaclass for the message classes emitted
by the protocol compiler.  When such a class is created, the metaclass reads
the class's Descriptor and injects the slots, properties and Message methods
the class needs.

As a consequence, the actual implementation details of ALL pure-Python
protocol buffers live *in this file*.
"""

__author__ = 'robinson@google.com (Will Robinson)'
52
53from io import BytesIO
Brian Silverman9c614bc2016-02-15 20:20:02 -050054import struct
Austin Schuh40c16522018-10-28 20:27:54 -070055import sys
Brian Silverman9c614bc2016-02-15 20:20:02 -050056import weakref
57
58import six
Brian Silverman9c614bc2016-02-15 20:20:02 -050059
60# We use "as" to avoid name collisions with variables.
Austin Schuh40c16522018-10-28 20:27:54 -070061from google.protobuf.internal import api_implementation
Brian Silverman9c614bc2016-02-15 20:20:02 -050062from google.protobuf.internal import containers
63from google.protobuf.internal import decoder
64from google.protobuf.internal import encoder
65from google.protobuf.internal import enum_type_wrapper
66from google.protobuf.internal import message_listener as message_listener_mod
67from google.protobuf.internal import type_checkers
68from google.protobuf.internal import well_known_types
69from google.protobuf.internal import wire_format
70from google.protobuf import descriptor as descriptor_mod
71from google.protobuf import message as message_mod
Brian Silverman9c614bc2016-02-15 20:20:02 -050072from google.protobuf import text_format
73
# Fully-qualified proto name of the google.protobuf.Any well-known type.
_AnyFullTypeName = 'google.protobuf.Any'
# Module-private alias for the FieldDescriptor class, used throughout this file.
_FieldDescriptor = descriptor_mod.FieldDescriptor
76
77
class GeneratedProtocolMessageType(type):
  """Metaclass for protocol message classes built at runtime from Descriptors.

  Every method described by the Message interface is implemented for the new
  class, and one property is created per field of the message.  __slots__ is
  set as well, so that assigning to a name that is not a field fails instead
  of silently creating an attribute that would never be serialized.

  The protocol compiler relies on this metaclass to create message classes at
  runtime.  Client code can build classes by hand the same way, e.g.:

    mydescriptor = Descriptor(.....)
    factory = symbol_database.Default()
    factory.pool.AddDescriptor(mydescriptor)
    MyProtoClass = factory.GetPrototype(mydescriptor)
    myproto_instance = MyProtoClass()
    myproto_instance.foo_field = 23
    ...
  """

  # Class-dictionary key holding the message Descriptor.  Must stay consistent
  # with the protocol-compiler code in proto2/compiler/internal/generator.*.
  _DESCRIPTOR_KEY = 'DESCRIPTOR'

  def __new__(cls, name, bases, dictionary):
    """Allocates the class after preparing its bases and class dictionary.

    __slots__ only takes effect if it is present in the class dictionary
    before type.__new__ runs, which is why this step lives in __new__.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the new class (normally message.Message).  A
        well-known-type mixin from well_known_types.WKTBASES is appended when
        the message's full name has one.
      dictionary: The class dictionary; dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor of this protocol message type.

    Returns:
      The newly allocated class.
    """
    msg_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    full_name = msg_descriptor.full_name
    if full_name in well_known_types.WKTBASES:
      bases = bases + (well_known_types.WKTBASES[full_name],)
    _AddClassAttributesForNestedExtensions(msg_descriptor, dictionary)
    _AddSlots(msg_descriptor, dictionary)
    return super(GeneratedProtocolMessageType, cls).__new__(
        cls, name, bases, dictionary)

  def __init__(cls, name, bases, dictionary):
    """Populates the freshly allocated message class.

    Registers wire decoders (including the MessageSet item decoder when the
    message uses message_set_wire_format), attaches per-field helpers, records
    cls as the descriptor's concrete class, and then adds enum values, an
    __init__ method, field and extension properties, static methods, Message
    methods and private helper methods.

    Args:
      name: Name of the class (required by the metaclass protocol).
      bases: Base classes of the class being constructed.
      dictionary: The class dictionary; dictionary[_DESCRIPTOR_KEY] must hold
        the Descriptor of this protocol message type.
    """
    msg_descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
    cls._decoders_by_tag = {}
    uses_message_set_format = (
        msg_descriptor.has_options and
        msg_descriptor.GetOptions().message_set_wire_format)
    if uses_message_set_format:
      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
          decoder.MessageSetItemDecoder(msg_descriptor), None)

    # Hang encoders/decoders/default constructors off each FieldDescriptor so
    # later lookups are cheap.
    for field in msg_descriptor.fields:
      _AttachFieldHelpers(cls, field)

    msg_descriptor._concrete_class = cls  # pylint: disable=protected-access
    _AddEnumValues(msg_descriptor, cls)
    _AddInitMethod(msg_descriptor, cls)
    _AddPropertiesForFields(msg_descriptor, cls)
    _AddPropertiesForExtensions(msg_descriptor, cls)
    _AddStaticMethods(cls)
    _AddMessageMethods(msg_descriptor, cls)
    _AddPrivateHelperMethods(msg_descriptor, cls)

    super(GeneratedProtocolMessageType, cls).__init__(name, bases, dictionary)
175
176
177# Stateless helpers for GeneratedProtocolMessageType below.
178# Outside clients should not access these directly.
179#
180# I opted not to make any of these methods on the metaclass, to make it more
181# clear that I'm not really using any state there and to keep clients from
182# thinking that they have direct access to these construction helpers.
183
184
185def _PropertyName(proto_field_name):
186 """Returns the name of the public property attribute which
187 clients can use to get and (in some cases) set the value
188 of a protocol message field.
189
190 Args:
191 proto_field_name: The protocol message field name, exactly
192 as it appears (or would appear) in a .proto file.
193 """
194 # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
195 # nnorwitz makes my day by writing:
196 # """
197 # FYI. See the keyword module in the stdlib. This could be as simple as:
198 #
199 # if keyword.iskeyword(proto_field_name):
200 # return proto_field_name + "_"
201 # return proto_field_name
202 # """
203 # Kenton says: The above is a BAD IDEA. People rely on being able to use
204 # getattr() and setattr() to reflectively manipulate field values. If we
205 # rename the properties, then every such user has to also make sure to apply
206 # the same transformation. Note that currently if you name a field "yield",
207 # you can still access it just fine using getattr/setattr -- it's not even
208 # that cumbersome to do so.
209 # TODO(kenton): Remove this method entirely if/when everyone agrees with my
210 # position.
211 return proto_field_name
212
213
def _VerifyExtensionHandle(message, extension_handle):
  """Checks that extension_handle is a valid extension of message's type.

  Args:
    message: The message instance the extension is being used with.
    extension_handle: The object supplied as an extension handle.

  Raises:
    KeyError: if extension_handle is not a FieldDescriptor, is not an
      extension, has no containing_type, or extends a message type other than
      message.DESCRIPTOR.
  """
  if not isinstance(extension_handle, _FieldDescriptor):
    raise KeyError('HasExtension() expects an extension handle, got: %s' %
                   extension_handle)

  if not extension_handle.is_extension:
    raise KeyError('"%s" is not an extension.' % extension_handle.full_name)

  extended_type = extension_handle.containing_type
  if not extended_type:
    raise KeyError('"%s" is missing a containing_type.'
                   % extension_handle.full_name)

  if extended_type is not message.DESCRIPTOR:
    raise KeyError('Extension "%s" extends message type "%s", but this '
                   'message is of type "%s".' %
                   (extension_handle.full_name,
                    extended_type.full_name,
                    message.DESCRIPTOR.full_name))
234
235
236def _AddSlots(message_descriptor, dictionary):
237 """Adds a __slots__ entry to dictionary, containing the names of all valid
238 attributes for this message type.
239
240 Args:
241 message_descriptor: A Descriptor instance describing this message type.
242 dictionary: Class dictionary to which we'll add a '__slots__' entry.
243 """
244 dictionary['__slots__'] = ['_cached_byte_size',
245 '_cached_byte_size_dirty',
246 '_fields',
247 '_unknown_fields',
248 '_is_present_in_parent',
249 '_listener',
250 '_listener_for_children',
251 '__weakref__',
252 '_oneofs']
253
254
def _IsMessageSetExtension(field):
  """Returns whether field is an optional message extension of a MessageSet.

  A MessageSet is a message type whose options set message_set_wire_format.
  """
  if not (field.is_extension and field.containing_type.has_options):
    return False
  if not field.containing_type.GetOptions().message_set_wire_format:
    return False
  return (field.type == _FieldDescriptor.TYPE_MESSAGE and
          field.label == _FieldDescriptor.LABEL_OPTIONAL)
261
262
def _IsMapField(field):
  """Returns whether field is a map field (its message type is a map entry)."""
  if field.type != _FieldDescriptor.TYPE_MESSAGE:
    return False
  entry_type = field.message_type
  return entry_type.has_options and entry_type.GetOptions().map_entry
267
268
def _IsMessageMapField(field):
  """Returns whether a map field's values are messages rather than scalars.

  Args:
    field: FieldDescriptor of a map field; its map-entry message type must
      have a field named 'value'.
  """
  entry_fields = field.message_type.fields_by_name
  value_cpp_type = entry_fields['value'].cpp_type
  return value_cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
272
273
def _AttachFieldHelpers(cls, field_descriptor):
  """Attaches codec helpers to field_descriptor and registers its decoders.

  Sets field_descriptor._encoder, ._sizer and ._default_constructor.  It also
  adds entries to cls._decoders_by_tag mapping tag bytes to
  (decoder, oneof_member), where oneof_member is field_descriptor itself when
  the field belongs to a oneof and None otherwise.  Packable repeated fields
  get a second decoder for the packed (length-delimited) encoding.

  Args:
    cls: The message class being built; cls._decoders_by_tag must exist.
    field_descriptor: FieldDescriptor of a field or extension of cls.
  """
  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
  is_packable = (is_repeated and
                 wire_format.IsTypePackable(field_descriptor.type))

  is_packed = False
  if is_packable:
    if field_descriptor.containing_type.syntax == "proto2":
      # proto2: packed only when the option explicitly asks for it.
      is_packed = (field_descriptor.has_options and
                   field_descriptor.GetOptions().packed)
    else:
      # proto3: packed by default unless explicitly disabled (packed=false).
      explicitly_unpacked = (
          field_descriptor.has_options and
          field_descriptor.GetOptions().HasField("packed") and
          not field_descriptor.GetOptions().packed)
      is_packed = not explicitly_unpacked

  is_map_entry = _IsMapField(field_descriptor)

  if is_map_entry:
    field_encoder = encoder.MapEncoder(field_descriptor)
    sizer = encoder.MapSizer(field_descriptor,
                             _IsMessageMapField(field_descriptor))
  elif _IsMessageSetExtension(field_descriptor):
    field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    sizer = encoder.MessageSetItemSizer(field_descriptor.number)
  else:
    field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)
    sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
        field_descriptor.number, is_repeated, is_packed)

  field_descriptor._encoder = field_encoder
  field_descriptor._sizer = sizer
  field_descriptor._default_constructor = _DefaultValueConstructorForField(
      field_descriptor)

  def AddDecoder(wiretype, packed):
    """Registers a decoder for the tag (field number, wiretype)."""
    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)

    decode_type = field_descriptor.type
    if (decode_type == _FieldDescriptor.TYPE_ENUM and
        type_checkers.SupportsOpenEnums(field_descriptor)):
      # Open enums are decoded as plain int32 values.
      decode_type = _FieldDescriptor.TYPE_INT32

    # The second tuple element is the field itself for oneof members and
    # None otherwise.
    if field_descriptor.containing_oneof is not None:
      oneof_member = field_descriptor
    else:
      oneof_member = None

    if is_map_entry:
      message_valued = _IsMessageMapField(field_descriptor)
      field_decoder = decoder.MapDecoder(
          field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
          message_valued)
    else:
      field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
          field_descriptor.number, is_repeated, packed,
          field_descriptor, field_descriptor._default_constructor)

    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_member)

  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
             False)

  if is_packable:
    # Always accept the packed encoding as well, so adding packed=true to a
    # field stays wire compatible regardless of the field's options.
    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
339
340
341def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
342 extension_dict = descriptor.extensions_by_name
343 for extension_name, extension_field in extension_dict.items():
344 assert extension_name not in dictionary
345 dictionary[extension_name] = extension_field
346
347
def _AddEnumValues(descriptor, cls):
  """Adds class attributes for every enum type nested in this message.

  For each nested enum type, cls receives an attribute named after the enum
  that holds an EnumTypeWrapper for it.  It also receives one attribute per
  enum value, mapping the value's name to its number.

  Args:
    descriptor: Descriptor object for this message type.
    cls: Class we're constructing for this message type.
  """
  for enum_type in descriptor.enum_types:
    wrapper = enum_type_wrapper.EnumTypeWrapper(enum_type)
    setattr(cls, enum_type.name, wrapper)
    for value in enum_type.values:
      setattr(cls, value.name, value.number)
361
362
def _GetInitializeDefaultForMap(field):
  """Returns a factory that builds an empty map container for a map field.

  Args:
    field: FieldDescriptor of a map field (a repeated map-entry message whose
      type has 'key' and 'value' fields).

  Returns:
    A function taking the owning message and returning a new
    containers.MessageMap (message values) or containers.ScalarMap (scalar
    values) wired to the message's _listener_for_children.

  Raises:
    ValueError: if field is not repeated.
  """
  if field.label != _FieldDescriptor.LABEL_REPEATED:
    raise ValueError('map_entry set on non-repeated field %s' % (
        field.name))
  entry_fields = field.message_type.fields_by_name
  key_checker = type_checkers.GetTypeChecker(entry_fields['key'])
  value_field = entry_fields['value']

  if not _IsMessageMapField(field):
    value_checker = type_checkers.GetTypeChecker(value_field)

    def MakePrimitiveMapDefault(message):
      return containers.ScalarMap(
          message._listener_for_children, key_checker, value_checker,
          field.message_type)
    return MakePrimitiveMapDefault

  def MakeMessageMapDefault(message):
    return containers.MessageMap(
        message._listener_for_children, value_field.message_type, key_checker,
        field.message_type)
  return MakeMessageMapDefault
384
def _DefaultValueConstructorForField(field):
  """Returns a function which returns a default value for a field.

  Args:
    field: FieldDescriptor object for this field.

  The returned function has one argument:
    message: Message instance containing this field, or a weakref proxy
      of same.

  That function in turn returns a default value for this field:
    * map fields: an empty map container (see _GetInitializeDefaultForMap);
    * repeated message fields: an empty RepeatedCompositeFieldContainer;
    * repeated scalar fields: an empty RepeatedScalarFieldContainer;
    * singular message fields: a new, empty instance of the sub-message class,
      whose listener is a _OneofListener when the field is in a oneof and
      message._listener_for_children otherwise;
    * singular scalar fields: field.default_value.
  Containers are created with message._listener_for_children.

  Raises:
    ValueError: if a repeated field declares a non-empty default value.
  """
  if _IsMapField(field):
    return _GetInitializeDefaultForMap(field)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    if field.has_default_value and field.default_value != []:
      raise ValueError('Repeated field default value not empty list: %s' % (
          field.default_value))
    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      # Only the descriptor is captured here: _concrete_class may not be set
      # yet, depending on the order in which the classes are initialized.
      def MakeRepeatedMessageDefault(message):
        return containers.RepeatedCompositeFieldContainer(
            message._listener_for_children, field.message_type)
      return MakeRepeatedMessageDefault
    else:
      type_checker = type_checkers.GetTypeChecker(field)
      def MakeRepeatedScalarDefault(message):
        return containers.RepeatedScalarFieldContainer(
            message._listener_for_children, type_checker)
      return MakeRepeatedScalarDefault

  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    # _concrete_class may not yet be initialized, so it is looked up lazily
    # each time a default is built.
    message_type = field.message_type
    def MakeSubMessageDefault(message):
      result = message_type._concrete_class()
      result._SetListener(
          _OneofListener(message, field)
          if field.containing_oneof is not None
          else message._listener_for_children)
      return result
    return MakeSubMessageDefault

  def MakeScalarDefault(message):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return field.default_value
  return MakeScalarDefault
438
439
def _ReraiseTypeErrorWithFieldName(message_name, field_name):
  """Re-raises the exception currently being handled, naming the field.

  Intended to be called from inside an `except TypeError:` clause.  If the
  exception is exactly a TypeError with a single argument, it is replaced by a
  TypeError whose message ends with "for field <message_name>.<field_name>".
  Otherwise it is re-raised unchanged.  Either way the original traceback is
  kept.

  Args:
    message_name: Name of the message type, used in the amended message.
    field_name: Name of the field being set, used in the amended message.
  """
  _, current_exc, tb = sys.exc_info()
  is_simple_type_error = (len(current_exc.args) == 1 and
                          type(current_exc) is TypeError)
  if is_simple_type_error:
    current_exc = TypeError('%s for field %s.%s' % (
        str(current_exc), message_name, field_name))
  six.reraise(type(current_exc), current_exc, tb)
449
450
def _AddInitMethod(message_descriptor, cls):
  """Adds an __init__ method to cls.

  The generated __init__ resets the per-instance bookkeeping and then applies
  each keyword argument to the field of the same name:
    * repeated fields take an iterable.  Map fields take a mapping.  Elements
      of repeated message fields may be message instances or dicts of kwargs.
    * singular message fields take a message instance or a dict of kwargs;
    * enum fields (singular or repeated) also accept value names as strings;
    * a value of None is treated as if the argument had been omitted.

  An unknown keyword name raises ValueError (from _GetFieldByName).  A
  TypeError raised while merging a singular sub-message or assigning a scalar
  is passed to _ReraiseTypeErrorWithFieldName, which adds the field name.

  Args:
    message_descriptor: Descriptor of the message type.
    cls: The class being constructed.
  """

  def _GetIntegerEnumValue(enum_type, value):
    """Convert a string or integer enum value to an integer.

    If the value is a string, it is converted to the enum value in
    enum_type with the same name. If the value is not a string, it's
    returned as-is. (No conversion or bounds-checking is done.)
    """
    if isinstance(value, six.string_types):
      try:
        return enum_type.values_by_name[value].number
      except KeyError:
        raise ValueError('Enum type %s: unknown label "%s"' % (
            enum_type.full_name, value))
    return value

  def init(self, **kwargs):
    self._cached_byte_size = 0
    self._cached_byte_size_dirty = len(kwargs) > 0
    self._fields = {}
    # Contains a mapping from oneof field descriptors to the descriptor
    # of the currently set field in that oneof field.
    self._oneofs = {}

    # _unknown_fields is () when empty for efficiency, and will be turned into
    # a list if fields are added.
    self._unknown_fields = ()
    self._is_present_in_parent = False
    self._listener = message_listener_mod.NullMessageListener()
    self._listener_for_children = _Listener(self)
    for field_name, field_value in kwargs.items():
      # _GetFieldByName never returns None: an unknown name raises ValueError
      # right here, so no separate "unexpected keyword argument" check exists.
      field = _GetFieldByName(message_descriptor, field_name)
      if field_value is None:
        # field=None is the same as no field at all.
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        container = field._default_constructor(self)
        if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
          if _IsMapField(field):
            if _IsMessageMapField(field):
              for key in field_value:
                container[key].MergeFrom(field_value[key])
            else:
              container.update(field_value)
          else:
            for val in field_value:
              if isinstance(val, dict):
                container.add(**val)
              else:
                container.add().MergeFrom(val)
        else:  # Scalar
          if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
            field_value = [_GetIntegerEnumValue(field.enum_type, val)
                           for val in field_value]
          container.extend(field_value)
        self._fields[field] = container
      elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
        sub_message = field._default_constructor(self)
        new_val = field_value
        if isinstance(field_value, dict):
          new_val = field.message_type._concrete_class(**field_value)
        try:
          sub_message.MergeFrom(new_val)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
        self._fields[field] = sub_message
      else:
        if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
          field_value = _GetIntegerEnumValue(field.enum_type, field_value)
        try:
          setattr(self, field_name, field_value)
        except TypeError:
          _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)

  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init
533
534
535def _GetFieldByName(message_descriptor, field_name):
536 """Returns a field descriptor by field name.
537
538 Args:
539 message_descriptor: A Descriptor describing all fields in message.
540 field_name: The name of the field to retrieve.
541 Returns:
542 The field descriptor associated with the field name.
543 """
544 try:
545 return message_descriptor.fields_by_name[field_name]
546 except KeyError:
547 raise ValueError('Protocol message %s has no "%s" field.' %
548 (message_descriptor.name, field_name))
549
550
def _AddPropertiesForFields(descriptor, cls):
  """Adds a property per field, plus an Extensions property if extendable.

  Args:
    descriptor: Descriptor of the message type.
    cls: The class being constructed.
  """
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)

  if not descriptor.is_extendable:
    return

  # _ExtensionDict is a stateless adaptor, so a fresh one is created on every
  # access rather than cached on the instance.
  cls.Extensions = property(lambda self: _ExtensionDict(self))
560
561
def _AddPropertiesForField(field, cls):
  """Adds the public property and FIELD_NUMBER constant for one field.

  cls gets a <FIELD_NAME>_FIELD_NUMBER attribute holding the field number.
  The property is created by the helper matching the field kind: repeated,
  singular message, or singular scalar.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Trip loudly if new C++ types appear that may need special handling here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  setattr(cls, '%s_FIELD_NUMBER' % field.name.upper(), field.number)

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    add_property = _AddPropertiesForRepeatedField
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    add_property = _AddPropertiesForNonRepeatedCompositeField
  else:
    add_property = _AddPropertiesForNonRepeatedScalarField
  add_property(field, cls)
585
586
def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a repeated protocol message field.

  Reading the property returns the container stored in self._fields.  On first
  access it is created with field._default_constructor.  Assigning to the
  property raises AttributeError; callers must mutate the container instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    container = self._fields.get(field)
    if container is not None:
      return container
    container = field._default_constructor(self)
    # Relies on dict.setdefault() being atomic (true in CPython, not verified
    # for other interpreters): if another thread stored a container first,
    # theirs wins and ours is discarded.
    return self._fields.setdefault(field, container)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to give a clearer error than the default one.
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
629
630
def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.

  Clients can use this property to get and directly set the value of the
  field.  Reading returns the stored value, or field.default_value when unset.
  Assigning runs the value through the field's type checker, stores it in
  self._fields and calls self._Modified() unless the byte size is already
  dirty.  For fields inside a oneof, the oneof state is updated as well.

  For proto3 fields outside a oneof, assigning a falsy value removes the field
  from self._fields instead of storing it.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)
  type_checker = type_checkers.GetTypeChecker(field)
  default_value = field.default_value
  is_proto3 = field.containing_type.syntax == "proto3"

  def getter(self):
    # TODO(protobuf-team): This may be broken since there may not be
    # default_value. Combine with has_default_value somehow.
    return self._fields.get(field, default_value)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  clear_when_set_to_default = is_proto3 and not field.containing_oneof

  def field_setter(self, new_value):
    # pylint: disable=protected-access
    # Testing the value for truthiness captures all of the proto3 defaults
    # (0, 0.0, enum 0, and False).
    new_value = type_checker.CheckValue(new_value)
    if clear_when_set_to_default and not new_value:
      self._fields.pop(field, None)
    else:
      self._fields[field] = new_value
    # Check _cached_byte_size_dirty inline to improve performance, since scalar
    # setters are called frequently.
    if not self._cached_byte_size_dirty:
      self._Modified()

  if field.containing_oneof:
    def setter(self, new_value):
      field_setter(self, new_value)
      self._UpdateOneofState(field)
  else:
    setter = field_setter

  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
685
686
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a singular composite (message or group) field.

  Reading the property returns the sub-message stored in self._fields.  On
  first access it is created with field._default_constructor.  Assigning to
  the property raises AttributeError; callers must mutate the returned
  sub-message instead.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    sub_message = self._fields.get(field)
    if sub_message is not None:
      return sub_message
    sub_message = field._default_constructor(self)
    # Relies on dict.setdefault() being atomic (true in CPython, not verified
    # for other interpreters): if another thread stored a value first, theirs
    # wins and ours is discarded.
    return self._fields.setdefault(field, sub_message)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  def setter(self, new_value):
    # Exists only to give a clearer error than the default one.
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))
729
730
731def _AddPropertiesForExtensions(descriptor, cls):
732 """Adds properties for all fields in this protocol message type."""
733 extension_dict = descriptor.extensions_by_name
734 for extension_name, extension_field in extension_dict.items():
735 constant_name = extension_name.upper() + "_FIELD_NUMBER"
736 setattr(cls, constant_name, extension_field.number)
737
Austin Schuh40c16522018-10-28 20:27:54 -0700738 # TODO(amauryfa): Migrate all users of these attributes to functions like
739 # pool.FindExtensionByNumber(descriptor).
740 if descriptor.file is not None:
741 # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
742 pool = descriptor.file.pool
743 cls._extensions_by_number = pool._extensions_by_number[descriptor]
744 cls._extensions_by_name = pool._extensions_by_name[descriptor]
Brian Silverman9c614bc2016-02-15 20:20:02 -0500745
746def _AddStaticMethods(cls):
747 # TODO(robinson): This probably needs to be thread-safe(?)
748 def RegisterExtension(extension_handle):
749 extension_handle.containing_type = cls.DESCRIPTOR
Austin Schuh40c16522018-10-28 20:27:54 -0700750 # TODO(amauryfa): Use cls.MESSAGE_FACTORY.pool when available.
751 cls.DESCRIPTOR.file.pool.AddExtensionDescriptor(extension_handle)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500752 _AttachFieldHelpers(cls, extension_handle)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500753 cls.RegisterExtension = staticmethod(RegisterExtension)
754
755 def FromString(s):
756 message = cls()
757 message.MergeFromString(s)
758 return message
759 cls.FromString = staticmethod(FromString)
760
761
def _IsPresent(item):
  """Returns whether a (FieldDescriptor, value) pair from _fields is present.

  Present pairs are the ones ListFields() reports: non-empty repeated fields,
  sub-messages marked _is_present_in_parent, and every stored scalar.
  """
  field, value = item
  if field.label == _FieldDescriptor.LABEL_REPEATED:
    return bool(value)
  if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    return value._is_present_in_parent
  return True
772
773
def _AddListFieldsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): adds ListFields() to cls."""

  def ListFields(self):
    # Returns the present (FieldDescriptor, value) pairs as a list sorted by
    # field number.
    present_items = filter(_IsPresent, self._fields.items())
    return sorted(present_items, key=lambda item: item[0].number)

  cls.ListFields = ListFields
783
# Error templates HasField() uses for names that do not support presence.
_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"'
_Proto2HasError = 'Protocol message has no non-repeated field "%s"'


def _AddHasFieldMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.HasField."""

  is_proto3 = message_descriptor.syntax == "proto3"
  error_msg = _Proto3HasError if is_proto3 else _Proto2HasError

  def _SupportsPresence(field):
    # Repeated fields never have presence.  For proto3, only submessages and
    # fields inside a oneof have presence.
    if field.label == _FieldDescriptor.LABEL_REPEATED:
      return False
    if not is_proto3:
      return True
    return (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE or
            bool(field.containing_oneof))

  hassable_fields = dict((field.name, field)
                         for field in message_descriptor.fields
                         if _SupportsPresence(field))

  if not is_proto3:
    # Proto2 also accepts oneof names.  Fields inside oneofs are never
    # repeated (enforced by the compiler).
    hassable_fields.update(
        (oneof.name, oneof) for oneof in message_descriptor.oneofs)

  def HasField(self, field_name):
    """Returns whether field_name (a field or, in proto2, a oneof) is set.

    Raises:
      ValueError: field_name does not name a field that supports presence.
    """
    try:
      field = hassable_fields[field_name]
    except KeyError:
      raise ValueError(error_msg % field_name)

    if isinstance(field, descriptor_mod.OneofDescriptor):
      # A oneof is "set" when its currently active member is set.
      active = self._oneofs.get(field)
      if active is None:
        return False
      return HasField(self, active.name)

    if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
      submessage = self._fields.get(field)
      return submessage is not None and submessage._is_present_in_parent
    return field in self._fields

  cls.HasField = HasField
827
828
829def _AddClearFieldMethod(message_descriptor, cls):
830 """Helper for _AddMessageMethods()."""
831 def ClearField(self, field_name):
832 try:
833 field = message_descriptor.fields_by_name[field_name]
834 except KeyError:
835 try:
836 field = message_descriptor.oneofs_by_name[field_name]
837 if field in self._oneofs:
838 field = self._oneofs[field]
839 else:
840 return
841 except KeyError:
842 raise ValueError('Protocol message %s() has no "%s" field.' %
843 (message_descriptor.name, field_name))
844
845 if field in self._fields:
846 # To match the C++ implementation, we need to invalidate iterators
847 # for map fields when ClearField() happens.
848 if hasattr(self._fields[field], 'InvalidateIterators'):
849 self._fields[field].InvalidateIterators()
850
851 # Note: If the field is a sub-message, its listener will still point
852 # at us. That's fine, because the worst than can happen is that it
853 # will call _Modified() and invalidate our byte size. Big deal.
854 del self._fields[field]
855
856 if self._oneofs.get(field.containing_oneof, None) is field:
857 del self._oneofs[field.containing_oneof]
858
859 # Always call _Modified() -- even if nothing was changed, this is
860 # a mutating method, and thus calling it should cause the field to become
861 # present in the parent message.
862 self._Modified()
863
864 cls.ClearField = ClearField
865
866
def _AddClearExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.ClearExtension."""

  def ClearExtension(self, extension_handle):
    """Removes any stored value for extension_handle and marks self modified."""
    _VerifyExtensionHandle(self, extension_handle)
    # Similar to ClearField(): drop the value if present, then always call
    # _Modified() because this is a mutating method.
    self._fields.pop(extension_handle, None)
    self._Modified()

  cls.ClearExtension = ClearExtension
877
878
def _AddHasExtensionMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.HasExtension."""

  def HasExtension(self, extension_handle):
    """Returns whether the singular extension extension_handle is set.

    Raises:
      KeyError: extension_handle is a repeated extension.
    """
    _VerifyExtensionHandle(self, extension_handle)
    if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
      raise KeyError('"%s" is repeated.' % extension_handle.full_name)

    if extension_handle.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      return extension_handle in self._fields
    submessage = self._fields.get(extension_handle)
    return submessage is not None and submessage._is_present_in_parent

  cls.HasExtension = HasExtension
892
893def _InternalUnpackAny(msg):
894 """Unpacks Any message and returns the unpacked message.
895
Austin Schuh40c16522018-10-28 20:27:54 -0700896 This internal method is different from public Any Unpack method which takes
Brian Silverman9c614bc2016-02-15 20:20:02 -0500897 the target message as argument. _InternalUnpackAny method does not have
898 target message type and need to find the message type in descriptor pool.
899
900 Args:
901 msg: An Any message to be unpacked.
902
903 Returns:
904 The unpacked message.
905 """
Austin Schuh40c16522018-10-28 20:27:54 -0700906 # TODO(amauryfa): Don't use the factory of generated messages.
907 # To make Any work with custom factories, use the message factory of the
908 # parent message.
909 # pylint: disable=g-import-not-at-top
910 from google.protobuf import symbol_database
911 factory = symbol_database.Default()
912
Brian Silverman9c614bc2016-02-15 20:20:02 -0500913 type_url = msg.type_url
Brian Silverman9c614bc2016-02-15 20:20:02 -0500914
915 if not type_url:
916 return None
917
918 # TODO(haberman): For now we just strip the hostname. Better logic will be
919 # required.
Austin Schuh40c16522018-10-28 20:27:54 -0700920 type_name = type_url.split('/')[-1]
921 descriptor = factory.pool.FindMessageTypeByName(type_name)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500922
923 if descriptor is None:
924 return None
925
Austin Schuh40c16522018-10-28 20:27:54 -0700926 message_class = factory.GetPrototype(descriptor)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500927 message = message_class()
928
929 message.ParseFromString(msg.value)
930 return message
931
Austin Schuh40c16522018-10-28 20:27:54 -0700932
def _AddEqualsMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs cls.__eq__."""

  def __eq__(self, other):
    """Compares two messages of the same type field by field.

    Any messages whose payloads can both be unpacked are compared by payload.
    Unknown fields are compared without regard to their order.
    """
    if not isinstance(other, message_mod.Message):
      return False
    if other.DESCRIPTOR != self.DESCRIPTOR:
      return False

    if self is other:
      return True

    if self.DESCRIPTOR.full_name == _AnyFullTypeName:
      unpacked_self = _InternalUnpackAny(self)
      unpacked_other = _InternalUnpackAny(other)
      if unpacked_self and unpacked_other:
        return unpacked_self == unpacked_other

    if self.ListFields() != other.ListFields():
      return False

    # The order of unknown fields must not affect equality, so compare
    # sorted copies.
    return sorted(self._unknown_fields) == sorted(other._unknown_fields)

  cls.__eq__ = __eq__
961
962
963def _AddStrMethod(message_descriptor, cls):
964 """Helper for _AddMessageMethods()."""
965 def __str__(self):
966 return text_format.MessageToString(self)
967 cls.__str__ = __str__
968
969
970def _AddReprMethod(message_descriptor, cls):
971 """Helper for _AddMessageMethods()."""
972 def __repr__(self):
973 return text_format.MessageToString(self)
974 cls.__repr__ = __repr__
975
976
977def _AddUnicodeMethod(unused_message_descriptor, cls):
978 """Helper for _AddMessageMethods()."""
979
980 def __unicode__(self):
981 return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
982 cls.__unicode__ = __unicode__
983
984
def _BytesForNonRepeatedElement(value, field_number, field_type):
  """Returns the number of bytes needed to serialize a non-repeated element.

  The returned byte count includes space for tag information and any
  other additional space associated with serializing value.

  Args:
    value: Value we're serializing.
    field_number: Field number of this value.  (Since the field number
      is stored as part of a varint-encoded tag, this has an impact
      on the total bytes required to serialize the value).
    field_type: The type of the field.  One of the TYPE_* constants
      within FieldDescriptor.

  Returns:
    The byte count computed by the sizer registered for field_type in
    type_checkers.TYPE_TO_BYTE_SIZE_FN.

  Raises:
    message_mod.EncodeError: no sizer is registered for field_type.
  """
  # Guard only the table lookup.  A KeyError raised *inside* the sizer is a
  # genuine bug and must not be reported as an unrecognized field type.
  try:
    fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
  except KeyError:
    raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
  return fn(field_number, value)
1003
1004
1005def _AddByteSizeMethod(message_descriptor, cls):
1006 """Helper for _AddMessageMethods()."""
1007
1008 def ByteSize(self):
1009 if not self._cached_byte_size_dirty:
1010 return self._cached_byte_size
1011
1012 size = 0
Austin Schuh40c16522018-10-28 20:27:54 -07001013 descriptor = self.DESCRIPTOR
1014 if descriptor.GetOptions().map_entry:
1015 # Fields of map entry should always be serialized.
1016 size = descriptor.fields_by_name['key']._sizer(self.key)
1017 size += descriptor.fields_by_name['value']._sizer(self.value)
1018 else:
1019 for field_descriptor, field_value in self.ListFields():
1020 size += field_descriptor._sizer(field_value)
1021 for tag_bytes, value_bytes in self._unknown_fields:
1022 size += len(tag_bytes) + len(value_bytes)
Brian Silverman9c614bc2016-02-15 20:20:02 -05001023
1024 self._cached_byte_size = size
1025 self._cached_byte_size_dirty = False
1026 self._listener_for_children.dirty = False
1027 return size
1028
1029 cls.ByteSize = ByteSize
1030
1031
1032def _AddSerializeToStringMethod(message_descriptor, cls):
1033 """Helper for _AddMessageMethods()."""
1034
Austin Schuh40c16522018-10-28 20:27:54 -07001035 def SerializeToString(self, **kwargs):
Brian Silverman9c614bc2016-02-15 20:20:02 -05001036 # Check if the message has all of its required fields set.
1037 errors = []
1038 if not self.IsInitialized():
1039 raise message_mod.EncodeError(
1040 'Message %s is missing required fields: %s' % (
1041 self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
Austin Schuh40c16522018-10-28 20:27:54 -07001042 return self.SerializePartialToString(**kwargs)
Brian Silverman9c614bc2016-02-15 20:20:02 -05001043 cls.SerializeToString = SerializeToString
1044
1045
1046def _AddSerializePartialToStringMethod(message_descriptor, cls):
1047 """Helper for _AddMessageMethods()."""
1048
Austin Schuh40c16522018-10-28 20:27:54 -07001049 def SerializePartialToString(self, **kwargs):
Brian Silverman9c614bc2016-02-15 20:20:02 -05001050 out = BytesIO()
Austin Schuh40c16522018-10-28 20:27:54 -07001051 self._InternalSerialize(out.write, **kwargs)
Brian Silverman9c614bc2016-02-15 20:20:02 -05001052 return out.getvalue()
1053 cls.SerializePartialToString = SerializePartialToString
1054
Austin Schuh40c16522018-10-28 20:27:54 -07001055 def InternalSerialize(self, write_bytes, deterministic=None):
1056 if deterministic is None:
1057 deterministic = (
1058 api_implementation.IsPythonDefaultSerializationDeterministic())
1059 else:
1060 deterministic = bool(deterministic)
1061
1062 descriptor = self.DESCRIPTOR
1063 if descriptor.GetOptions().map_entry:
1064 # Fields of map entry should always be serialized.
1065 descriptor.fields_by_name['key']._encoder(
1066 write_bytes, self.key, deterministic)
1067 descriptor.fields_by_name['value']._encoder(
1068 write_bytes, self.value, deterministic)
1069 else:
1070 for field_descriptor, field_value in self.ListFields():
1071 field_descriptor._encoder(write_bytes, field_value, deterministic)
1072 for tag_bytes, value_bytes in self._unknown_fields:
1073 write_bytes(tag_bytes)
1074 write_bytes(value_bytes)
Brian Silverman9c614bc2016-02-15 20:20:02 -05001075 cls._InternalSerialize = InternalSerialize
1076
1077
def _AddMergeFromStringMethod(message_descriptor, cls):
  """Helper for _AddMessageMethods(): installs MergeFromString and the
  _InternalParse loop it drives."""

  def MergeFromString(self, serialized):
    """Parses serialized and merges it into self.

    Returns:
      len(serialized), kept for legacy reasons.

    Raises:
      message_mod.DecodeError: the input is truncated, contains an unexpected
        end-group tag, or triggers a struct.error while decoding.
    """
    total = len(serialized)
    try:
      consumed = self._InternalParse(serialized, 0, total)
    except (IndexError, TypeError):
      # Now ord(buf[p:p+1]) == ord('') gets TypeError.
      raise message_mod.DecodeError('Truncated message.')
    except struct.error as e:
      raise message_mod.DecodeError(e)
    if consumed != total:
      # The only reason _InternalParse would return early is if it
      # encountered an end-group tag.
      raise message_mod.DecodeError('Unexpected end-group tag.')
    return total  # Return this for legacy reasons.

  cls.MergeFromString = MergeFromString

  # Bind these once, outside the per-message parse loop.
  read_tag = decoder.ReadTag
  skip_field = decoder.SkipField
  decoders_by_tag = cls._decoders_by_tag
  is_proto3 = message_descriptor.syntax == "proto3"

  def InternalParse(self, buffer, pos, end):
    """Decodes fields from buffer[pos:end] into self.

    Returns:
      The position where parsing stopped: end when everything was consumed,
      or the offset of a tag for which SkipField returned -1.
    """
    self._Modified()
    field_dict = self._fields
    unknown_fields = self._unknown_fields
    while pos != end:
      tag_bytes, value_pos = read_tag(buffer, pos)
      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
      if field_decoder is not None:
        pos = field_decoder(buffer, value_pos, end, self, field_dict)
        if field_desc:
          self._UpdateOneofState(field_desc)
        continue

      # Unrecognized tag: skip over its payload.
      next_pos = skip_field(buffer, value_pos, end, tag_bytes)
      if next_pos == -1:
        return pos
      if (not is_proto3 or
          api_implementation.GetPythonProto3PreserveUnknownsDefault()):
        if not unknown_fields:
          # _unknown_fields may be the empty tuple installed by Clear().
          unknown_fields = self._unknown_fields = []
        unknown_fields.append((tag_bytes, buffer[value_pos:next_pos]))
      pos = next_pos
    return pos

  cls._InternalParse = InternalParse
1125
1126
def _AddIsInitializedMethod(message_descriptor, cls):
  """Adds the IsInitialized and FindInitializationError methods to the
  protocol message class."""

  required_fields = [field for field in message_descriptor.fields
                     if field.label == _FieldDescriptor.LABEL_REQUIRED]

  def IsInitialized(self, errors=None):
    """Checks if all required fields of a message are set.

    Args:
      errors:  A list which, if provided, will be populated with the field
               paths of all missing required fields.

    Returns:
      True iff the specified message has all required fields set.
    """

    # Performance is critical so we avoid HasField() and ListFields().

    for field in required_fields:
      missing = (field not in self._fields or
                 (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
                  not self._fields[field]._is_present_in_parent))
      if missing:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    for field, value in list(self._fields.items()):  # dict can change size!
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.label == _FieldDescriptor.LABEL_REPEATED:
        if (field.message_type.has_options and
            field.message_type.GetOptions().map_entry):
          continue
        uninitialized = any(not element.IsInitialized() for element in value)
      else:
        uninitialized = (value._is_present_in_parent and
                         not value.IsInitialized())
      if uninitialized:
        if errors is not None:
          errors.extend(self.FindInitializationErrors())
        return False

    return True

  cls.IsInitialized = IsInitialized

  def FindInitializationErrors(self):
    """Finds required fields which are not initialized.

    Returns:
      A list of strings.  Each string is a path to an uninitialized field from
      the top-level message, e.g. "foo.bar[5].baz".
    """
    errors = [field.name for field in required_fields
              if not self.HasField(field.name)]

    for field, value in self.ListFields():
      if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
        continue
      if field.is_extension:
        name = "(%s)" % field.full_name
      else:
        name = field.name

      if _IsMapField(field):
        # ScalarMaps can't have any initialization errors; only message maps
        # need their values inspected.
        if _IsMessageMapField(field):
          for key in value:
            prefix = "%s[%s]." % (name, key)
            errors.extend(prefix + error
                          for error in value[key].FindInitializationErrors())
      elif field.label == _FieldDescriptor.LABEL_REPEATED:
        for index, element in enumerate(value):
          prefix = "%s[%d]." % (name, index)
          errors.extend(prefix + error
                        for error in element.FindInitializationErrors())
      else:
        prefix = name + "."
        errors.extend(prefix + error
                      for error in value.FindInitializationErrors())

    return errors

  cls.FindInitializationErrors = FindInitializationErrors
1220
1221
def _AddMergeFromMethod(cls):
  """Helper for _AddMessageMethods(): installs cls.MergeFrom."""
  LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
  CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE

  def _GetOrCreate(self, field):
    # Returns self's value object for field, constructing and storing a new
    # one through field._default_constructor when none exists yet.
    current = self._fields.get(field)
    if current is None:
      current = field._default_constructor(self)
      self._fields[field] = current
    return current

  def MergeFrom(self, msg):
    """Merges the set fields of msg (an instance of cls) into self.

    Repeated fields and present sub-messages are merged through their own
    MergeFrom(), scalar fields are overwritten, and unknown fields are
    appended.

    Raises:
      TypeError: msg is not an instance of cls.
    """
    if not isinstance(msg, cls):
      raise TypeError(
          "Parameter to MergeFrom() must be instance of same class: "
          'expected %s got %s.' % (cls.__name__, msg.__class__.__name__))

    assert msg is not self
    self._Modified()

    for field, value in msg._fields.items():
      if field.label == LABEL_REPEATED:
        _GetOrCreate(self, field).MergeFrom(value)
      elif field.cpp_type == CPPTYPE_MESSAGE:
        if value._is_present_in_parent:
          _GetOrCreate(self, field).MergeFrom(value)
      else:
        self._fields[field] = value
        if field.containing_oneof:
          self._UpdateOneofState(field)

    if msg._unknown_fields:
      if not self._unknown_fields:
        # May still be the empty tuple installed by Clear().
        self._unknown_fields = []
      self._unknown_fields.extend(msg._unknown_fields)

  cls.MergeFrom = MergeFrom
1264
1265
1266def _AddWhichOneofMethod(message_descriptor, cls):
1267 def WhichOneof(self, oneof_name):
1268 """Returns the name of the currently set field inside a oneof, or None."""
1269 try:
1270 field = message_descriptor.oneofs_by_name[oneof_name]
1271 except KeyError:
1272 raise ValueError(
1273 'Protocol message has no oneof "%s" field.' % oneof_name)
1274
1275 nested_field = self._oneofs.get(field, None)
1276 if nested_field is not None and self.HasField(nested_field.name):
1277 return nested_field.name
1278 else:
1279 return None
1280
1281 cls.WhichOneof = WhichOneof
1282
1283
Austin Schuh40c16522018-10-28 20:27:54 -07001284def _AddReduceMethod(cls):
1285 def __reduce__(self): # pylint: disable=invalid-name
1286 return (type(self), (), self.__getstate__())
1287 cls.__reduce__ = __reduce__
1288
1289
1290def _Clear(self):
1291 # Clear fields.
1292 self._fields = {}
1293 self._unknown_fields = ()
1294 self._oneofs = {}
1295 self._Modified()
1296
1297
def _DiscardUnknownFields(self):
  """Drops unknown fields from this message and, recursively, from every
  present sub-message, including message values stored in map fields."""
  self._unknown_fields = []
  for field, value in self.ListFields():
    if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
      continue
    if _IsMapField(field):
      # Iterating a map yields its keys rather than messages, so maps need
      # their own branch.  Only message-valued maps can hold unknown fields.
      if _IsMessageMapField(field):
        for key in value:
          value[key].DiscardUnknownFields()
    elif field.label == _FieldDescriptor.LABEL_REPEATED:
      for sub_message in value:
        sub_message.DiscardUnknownFields()
    else:
      value.DiscardUnknownFields()
1307
1308
1309def _SetListener(self, listener):
1310 if listener is None:
1311 self._listener = message_listener_mod.NullMessageListener()
1312 else:
1313 self._listener = listener
1314
1315
def _AddMessageMethods(message_descriptor, cls):
  """Adds implementations of all Message methods to cls.

  The installers run in a fixed order.  ClearExtension and HasExtension are
  only added when the descriptor is extendable.
  """
  for install in (_AddListFieldsMethod, _AddHasFieldMethod,
                  _AddClearFieldMethod):
    install(message_descriptor, cls)
  if message_descriptor.is_extendable:
    _AddClearExtensionMethod(cls)
    _AddHasExtensionMethod(cls)
  for install in (_AddEqualsMethod, _AddStrMethod, _AddReprMethod,
                  _AddUnicodeMethod, _AddByteSizeMethod,
                  _AddSerializeToStringMethod,
                  _AddSerializePartialToStringMethod,
                  _AddMergeFromStringMethod, _AddIsInitializedMethod):
    install(message_descriptor, cls)
  _AddMergeFromMethod(cls)
  _AddWhichOneofMethod(message_descriptor, cls)
  _AddReduceMethod(cls)
  # Methods that do not depend on cls are shared module-level functions.
  cls.Clear = _Clear
  cls.DiscardUnknownFields = _DiscardUnknownFields
  cls._SetListener = _SetListener
Brian Silverman9c614bc2016-02-15 20:20:02 -05001340
1341
1342def _AddPrivateHelperMethods(message_descriptor, cls):
1343 """Adds implementation of private helper methods to cls."""
1344
1345 def Modified(self):
1346 """Sets the _cached_byte_size_dirty bit to true,
1347 and propagates this to our listener iff this was a state change.
1348 """
1349
1350 # Note: Some callers check _cached_byte_size_dirty before calling
1351 # _Modified() as an extra optimization. So, if this method is ever
1352 # changed such that it does stuff even when _cached_byte_size_dirty is
1353 # already true, the callers need to be updated.
1354 if not self._cached_byte_size_dirty:
1355 self._cached_byte_size_dirty = True
1356 self._listener_for_children.dirty = True
1357 self._is_present_in_parent = True
1358 self._listener.Modified()
1359
1360 def _UpdateOneofState(self, field):
1361 """Sets field as the active field in its containing oneof.
1362
1363 Will also delete currently active field in the oneof, if it is different
1364 from the argument. Does not mark the message as modified.
1365 """
1366 other_field = self._oneofs.setdefault(field.containing_oneof, field)
1367 if other_field is not field:
1368 del self._fields[other_field]
1369 self._oneofs[field.containing_oneof] = field
1370
1371 cls._Modified = Modified
1372 cls.SetInParent = Modified
1373 cls._UpdateOneofState = _UpdateOneofState
1374
1375
1376class _Listener(object):
1377
1378 """MessageListener implementation that a parent message registers with its
1379 child message.
1380
1381 In order to support semantics like:
1382
1383 foo.bar.baz.qux = 23
1384 assert foo.HasField('bar')
1385
1386 ...child objects must have back references to their parents.
1387 This helper class is at the heart of this support.
1388 """
1389
1390 def __init__(self, parent_message):
1391 """Args:
1392 parent_message: The message whose _Modified() method we should call when
1393 we receive Modified() messages.
1394 """
1395 # This listener establishes a back reference from a child (contained) object
1396 # to its parent (containing) object. We make this a weak reference to avoid
1397 # creating cyclic garbage when the client finishes with the 'parent' object
1398 # in the tree.
1399 if isinstance(parent_message, weakref.ProxyType):
1400 self._parent_message_weakref = parent_message
1401 else:
1402 self._parent_message_weakref = weakref.proxy(parent_message)
1403
1404 # As an optimization, we also indicate directly on the listener whether
1405 # or not the parent message is dirty. This way we can avoid traversing
1406 # up the tree in the common case.
1407 self.dirty = False
1408
1409 def Modified(self):
1410 if self.dirty:
1411 return
1412 try:
1413 # Propagate the signal to our parents iff this is the first field set.
1414 self._parent_message_weakref._Modified()
1415 except ReferenceError:
1416 # We can get here if a client has kept a reference to a child object,
1417 # and is now setting a field on it, but the child's parent has been
1418 # garbage-collected. This is not an error.
1419 pass
1420
1421
class _OneofListener(_Listener):
  """Special listener implementation for setting composite oneof fields."""

  def __init__(self, parent_message, field):
    """Args:
      parent_message: The message whose _Modified() method we should call when
        we receive Modified() messages.
      field: The descriptor of the field being set in the parent message.
    """
    super(_OneofListener, self).__init__(parent_message)
    self._field = field

  def Modified(self):
    """Makes self._field the active oneof member, then propagates upward."""
    parent = self._parent_message_weakref
    try:
      parent._UpdateOneofState(self._field)
      super(_OneofListener, self).Modified()
    except ReferenceError:
      # The parent message has already been garbage-collected.
      pass
1441
1442
1443# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
# TODO(robinson): Unify error handling for unknown extensions.
1445# TODO(robinson): Support iteritems()-style iteration over all
1446# extensions with the "has" bits turned on?
1447class _ExtensionDict(object):
1448
1449 """Dict-like container for supporting an indexable "Extensions"
1450 field on proto instances.
1451
1452 Note that in all cases we expect extension handles to be
1453 FieldDescriptors.
1454 """
1455
1456 def __init__(self, extended_message):
1457 """extended_message: Message instance for which we are the Extensions dict.
1458 """
1459
1460 self._extended_message = extended_message
1461
1462 def __getitem__(self, extension_handle):
1463 """Returns the current value of the given extension handle."""
1464
1465 _VerifyExtensionHandle(self._extended_message, extension_handle)
1466
1467 result = self._extended_message._fields.get(extension_handle)
1468 if result is not None:
1469 return result
1470
1471 if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
1472 result = extension_handle._default_constructor(self._extended_message)
1473 elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
1474 result = extension_handle.message_type._concrete_class()
1475 try:
1476 result._SetListener(self._extended_message._listener_for_children)
1477 except ReferenceError:
1478 pass
1479 else:
1480 # Singular scalar -- just return the default without inserting into the
1481 # dict.
1482 return extension_handle.default_value
1483
1484 # Atomically check if another thread has preempted us and, if not, swap
1485 # in the new object we just created. If someone has preempted us, we
1486 # take that object and discard ours.
1487 # WARNING: We are relying on setdefault() being atomic. This is true
1488 # in CPython but we haven't investigated others. This warning appears
1489 # in several other locations in this file.
1490 result = self._extended_message._fields.setdefault(
1491 extension_handle, result)
1492
1493 return result
1494
1495 def __eq__(self, other):
1496 if not isinstance(other, self.__class__):
1497 return False
1498
1499 my_fields = self._extended_message.ListFields()
1500 other_fields = other._extended_message.ListFields()
1501
1502 # Get rid of non-extension fields.
1503 my_fields = [ field for field in my_fields if field.is_extension ]
1504 other_fields = [ field for field in other_fields if field.is_extension ]
1505
1506 return my_fields == other_fields
1507
1508 def __ne__(self, other):
1509 return not self == other
1510
1511 def __hash__(self):
1512 raise TypeError('unhashable object')
1513
1514 # Note that this is only meaningful for non-repeated, scalar extension
1515 # fields. Note also that we may have to call _Modified() when we do
1516 # successfully set a field this way, to set any necssary "has" bits in the
1517 # ancestors of the extended message.
1518 def __setitem__(self, extension_handle, value):
1519 """If extension_handle specifies a non-repeated, scalar extension
1520 field, sets the value of that field.
1521 """
1522
1523 _VerifyExtensionHandle(self._extended_message, extension_handle)
1524
1525 if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
1526 extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
1527 raise TypeError(
1528 'Cannot assign to extension "%s" because it is a repeated or '
1529 'composite type.' % extension_handle.full_name)
1530
1531 # It's slightly wasteful to lookup the type checker each time,
1532 # but we expect this to be a vanishingly uncommon case anyway.
1533 type_checker = type_checkers.GetTypeChecker(extension_handle)
1534 # pylint: disable=protected-access
1535 self._extended_message._fields[extension_handle] = (
1536 type_checker.CheckValue(value))
1537 self._extended_message._Modified()
1538
1539 def _FindExtensionByName(self, name):
1540 """Tries to find a known extension with the specified name.
1541
1542 Args:
1543 name: Extension full name.
1544
1545 Returns:
1546 Extension field descriptor.
1547 """
1548 return self._extended_message._extensions_by_name.get(name, None)
Austin Schuh40c16522018-10-28 20:27:54 -07001549
1550 def _FindExtensionByNumber(self, number):
1551 """Tries to find a known extension with the field number.
1552
1553 Args:
1554 number: Extension field number.
1555
1556 Returns:
1557 Extension field descriptor.
1558 """
1559 return self._extended_message._extensions_by_number.get(number, None)