blob: 8983f76f01ec899ec21802350c25d9a7b6937c77 [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""Provides DescriptorPool to use as a container for proto2 descriptors.
32
33The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
34a collection of protocol buffer descriptors for use when dynamically creating
35message types at runtime.
36
37For most applications protocol buffers should be used via modules generated by
38the protocol buffer compiler tool. This should only be used when the type of
39protocol buffers used in an application or library cannot be predetermined.
40
41Below is a straightforward example of how to use this class:
42
43 pool = DescriptorPool()
44 file_descriptor_protos = [ ... ]
45 for file_descriptor_proto in file_descriptor_protos:
46 pool.Add(file_descriptor_proto)
47 my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
48
49The message descriptor can be used in conjunction with the message_factory
50module in order to create a protocol buffer class that can be encoded and
51decoded.
52
53If you want to get a Python class for the specified proto, use the
54helper functions inside google.protobuf.message_factory
55directly instead of this class.
56"""
57
__author__ = 'matthewtoia@google.com (Matt Toia)'

import collections
import warnings

from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import text_encoding


# Mirrors descriptor._USE_C_DESCRIPTORS (presumably True when the C++
# implementation is active — TODO confirm). When truthy, DescriptorPool below
# defines __new__ so that constructing it returns
# descriptor._message.DescriptorPool instead of a pure-Python pool.
_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS  # pylint: disable=protected-access
Brian Silverman9c614bc2016-02-15 20:20:02 -050069
70
71def _NormalizeFullyQualifiedName(name):
72 """Remove leading period from fully-qualified type name.
73
74 Due to b/13860351 in descriptor_database.py, types in the root namespace are
75 generated with a leading period. This function removes that prefix.
76
77 Args:
78 name: A str, the fully-qualified symbol name.
79
80 Returns:
81 A str, the normalized fully-qualified symbol name.
82 """
83 return name.lstrip('.')
84
85
Austin Schuh40c16522018-10-28 20:27:54 -070086def _OptionsOrNone(descriptor_proto):
87 """Returns the value of the field `options`, or None if it is not set."""
88 if descriptor_proto.HasField('options'):
89 return descriptor_proto.options
90 else:
91 return None
92
93
def _IsMessageSetExtension(field):
  """Tells whether `field` is an optional message extension of a MessageSet.

  A field qualifies when it is an extension whose extended message has
  options with `message_set_wire_format` set, and the field itself is an
  optional field of message type. AddExtensionDescriptor additionally indexes
  such extensions by their message type's full name.

  Args:
    field: A FieldDescriptor.

  Returns:
    A truthy value if `field` is a MessageSet extension, a falsy one otherwise.
  """
  is_extension = field.is_extension
  if not is_extension:
    return is_extension
  extendee = field.containing_type
  has_options = extendee.has_options
  if not has_options:
    return has_options
  message_set_format = extendee.GetOptions().message_set_wire_format
  if not message_set_format:
    return message_set_format
  field_cls = descriptor.FieldDescriptor
  return (field.type == field_cls.TYPE_MESSAGE and
          field.label == field_cls.LABEL_OPTIONAL)
100
101
Brian Silverman9c614bc2016-02-15 20:20:02 -0500102class DescriptorPool(object):
103 """A collection of protobufs dynamically constructed by descriptor protos."""
104
105 if _USE_C_DESCRIPTORS:
106
107 def __new__(cls, descriptor_db=None):
108 # pylint: disable=protected-access
109 return descriptor._message.DescriptorPool(descriptor_db)
110
111 def __init__(self, descriptor_db=None):
112 """Initializes a Pool of proto buffs.
113
114 The descriptor_db argument to the constructor is provided to allow
115 specialized file descriptor proto lookup code to be triggered on demand. An
116 example would be an implementation which will read and compile a file
117 specified in a call to FindFileByName() and not require the call to Add()
118 at all. Results from this database will be cached internally here as well.
119
120 Args:
121 descriptor_db: A secondary source of file descriptors.
122 """
123
124 self._internal_db = descriptor_database.DescriptorDatabase()
125 self._descriptor_db = descriptor_db
126 self._descriptors = {}
127 self._enum_descriptors = {}
Austin Schuh40c16522018-10-28 20:27:54 -0700128 self._service_descriptors = {}
Brian Silverman9c614bc2016-02-15 20:20:02 -0500129 self._file_descriptors = {}
Austin Schuh40c16522018-10-28 20:27:54 -0700130 self._toplevel_extensions = {}
131 # TODO(jieluo): Remove _file_desc_by_toplevel_extension after
132 # maybe year 2020 for compatibility issue (with 3.4.1 only).
133 self._file_desc_by_toplevel_extension = {}
134 # We store extensions in two two-level mappings: The first key is the
135 # descriptor of the message being extended, the second key is the extension
136 # full name or its tag number.
137 self._extensions_by_name = collections.defaultdict(dict)
138 self._extensions_by_number = collections.defaultdict(dict)
139
140 def _CheckConflictRegister(self, desc):
141 """Check if the descriptor name conflicts with another of the same name.
142
143 Args:
144 desc: Descriptor of a message, enum, service or extension.
145 """
146 desc_name = desc.full_name
147 for register, descriptor_type in [
148 (self._descriptors, descriptor.Descriptor),
149 (self._enum_descriptors, descriptor.EnumDescriptor),
150 (self._service_descriptors, descriptor.ServiceDescriptor),
151 (self._toplevel_extensions, descriptor.FieldDescriptor)]:
152 if desc_name in register:
153 file_name = register[desc_name].file.name
154 if not isinstance(desc, descriptor_type) or (
155 file_name != desc.file.name):
156 warn_msg = ('Conflict register for file "' + desc.file.name +
157 '": ' + desc_name +
158 ' is already defined in file "' +
159 file_name + '"')
160 warnings.warn(warn_msg, RuntimeWarning)
161 return
Brian Silverman9c614bc2016-02-15 20:20:02 -0500162
163 def Add(self, file_desc_proto):
164 """Adds the FileDescriptorProto and its types to this pool.
165
166 Args:
167 file_desc_proto: The FileDescriptorProto to add.
168 """
169
170 self._internal_db.Add(file_desc_proto)
171
172 def AddSerializedFile(self, serialized_file_desc_proto):
173 """Adds the FileDescriptorProto and its types to this pool.
174
175 Args:
176 serialized_file_desc_proto: A bytes string, serialization of the
177 FileDescriptorProto to add.
178 """
179
180 # pylint: disable=g-import-not-at-top
181 from google.protobuf import descriptor_pb2
182 file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
183 serialized_file_desc_proto)
184 self.Add(file_desc_proto)
185
186 def AddDescriptor(self, desc):
187 """Adds a Descriptor to the pool, non-recursively.
188
189 If the Descriptor contains nested messages or enums, the caller must
190 explicitly register them. This method also registers the FileDescriptor
191 associated with the message.
192
193 Args:
194 desc: A Descriptor.
195 """
196 if not isinstance(desc, descriptor.Descriptor):
197 raise TypeError('Expected instance of descriptor.Descriptor.')
198
Austin Schuh40c16522018-10-28 20:27:54 -0700199 self._CheckConflictRegister(desc)
200
Brian Silverman9c614bc2016-02-15 20:20:02 -0500201 self._descriptors[desc.full_name] = desc
Austin Schuh40c16522018-10-28 20:27:54 -0700202 self._AddFileDescriptor(desc.file)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500203
204 def AddEnumDescriptor(self, enum_desc):
205 """Adds an EnumDescriptor to the pool.
206
Austin Schuh40c16522018-10-28 20:27:54 -0700207 This method also registers the FileDescriptor associated with the enum.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500208
209 Args:
210 enum_desc: An EnumDescriptor.
211 """
212
213 if not isinstance(enum_desc, descriptor.EnumDescriptor):
214 raise TypeError('Expected instance of descriptor.EnumDescriptor.')
215
Austin Schuh40c16522018-10-28 20:27:54 -0700216 self._CheckConflictRegister(enum_desc)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500217 self._enum_descriptors[enum_desc.full_name] = enum_desc
Austin Schuh40c16522018-10-28 20:27:54 -0700218 self._AddFileDescriptor(enum_desc.file)
219
220 def AddServiceDescriptor(self, service_desc):
221 """Adds a ServiceDescriptor to the pool.
222
223 Args:
224 service_desc: A ServiceDescriptor.
225 """
226
227 if not isinstance(service_desc, descriptor.ServiceDescriptor):
228 raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
229
230 self._CheckConflictRegister(service_desc)
231 self._service_descriptors[service_desc.full_name] = service_desc
232
233 def AddExtensionDescriptor(self, extension):
234 """Adds a FieldDescriptor describing an extension to the pool.
235
236 Args:
237 extension: A FieldDescriptor.
238
239 Raises:
240 AssertionError: when another extension with the same number extends the
241 same message.
242 TypeError: when the specified extension is not a
243 descriptor.FieldDescriptor.
244 """
245 if not (isinstance(extension, descriptor.FieldDescriptor) and
246 extension.is_extension):
247 raise TypeError('Expected an extension descriptor.')
248
249 if extension.extension_scope is None:
250 self._CheckConflictRegister(extension)
251 self._toplevel_extensions[extension.full_name] = extension
252
253 try:
254 existing_desc = self._extensions_by_number[
255 extension.containing_type][extension.number]
256 except KeyError:
257 pass
258 else:
259 if extension is not existing_desc:
260 raise AssertionError(
261 'Extensions "%s" and "%s" both try to extend message type "%s" '
262 'with field number %d.' %
263 (extension.full_name, existing_desc.full_name,
264 extension.containing_type.full_name, extension.number))
265
266 self._extensions_by_number[extension.containing_type][
267 extension.number] = extension
268 self._extensions_by_name[extension.containing_type][
269 extension.full_name] = extension
270
271 # Also register MessageSet extensions with the type name.
272 if _IsMessageSetExtension(extension):
273 self._extensions_by_name[extension.containing_type][
274 extension.message_type.full_name] = extension
Brian Silverman9c614bc2016-02-15 20:20:02 -0500275
276 def AddFileDescriptor(self, file_desc):
277 """Adds a FileDescriptor to the pool, non-recursively.
278
279 If the FileDescriptor contains messages or enums, the caller must explicitly
280 register them.
281
282 Args:
283 file_desc: A FileDescriptor.
284 """
285
Austin Schuh40c16522018-10-28 20:27:54 -0700286 self._AddFileDescriptor(file_desc)
287 # TODO(jieluo): This is a temporary solution for FieldDescriptor.file.
288 # FieldDescriptor.file is added in code gen. Remove this solution after
289 # maybe 2020 for compatibility reason (with 3.4.1 only).
290 for extension in file_desc.extensions_by_name.values():
291 self._file_desc_by_toplevel_extension[
292 extension.full_name] = file_desc
293
294 def _AddFileDescriptor(self, file_desc):
295 """Adds a FileDescriptor to the pool, non-recursively.
296
297 If the FileDescriptor contains messages or enums, the caller must explicitly
298 register them.
299
300 Args:
301 file_desc: A FileDescriptor.
302 """
303
Brian Silverman9c614bc2016-02-15 20:20:02 -0500304 if not isinstance(file_desc, descriptor.FileDescriptor):
305 raise TypeError('Expected instance of descriptor.FileDescriptor.')
306 self._file_descriptors[file_desc.name] = file_desc
307
308 def FindFileByName(self, file_name):
309 """Gets a FileDescriptor by file name.
310
311 Args:
312 file_name: The path to the file to get a descriptor for.
313
314 Returns:
315 A FileDescriptor for the named file.
316
317 Raises:
Austin Schuh40c16522018-10-28 20:27:54 -0700318 KeyError: if the file cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500319 """
320
321 try:
322 return self._file_descriptors[file_name]
323 except KeyError:
324 pass
325
326 try:
327 file_proto = self._internal_db.FindFileByName(file_name)
328 except KeyError as error:
329 if self._descriptor_db:
330 file_proto = self._descriptor_db.FindFileByName(file_name)
331 else:
332 raise error
333 if not file_proto:
334 raise KeyError('Cannot find a file named %s' % file_name)
335 return self._ConvertFileProtoToFileDescriptor(file_proto)
336
337 def FindFileContainingSymbol(self, symbol):
338 """Gets the FileDescriptor for the file containing the specified symbol.
339
340 Args:
341 symbol: The name of the symbol to search for.
342
343 Returns:
344 A FileDescriptor that contains the specified symbol.
345
346 Raises:
Austin Schuh40c16522018-10-28 20:27:54 -0700347 KeyError: if the file cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500348 """
349
350 symbol = _NormalizeFullyQualifiedName(symbol)
351 try:
352 return self._descriptors[symbol].file
353 except KeyError:
354 pass
355
356 try:
357 return self._enum_descriptors[symbol].file
358 except KeyError:
359 pass
360
361 try:
Austin Schuh40c16522018-10-28 20:27:54 -0700362 return self._service_descriptors[symbol].file
363 except KeyError:
364 pass
365
366 try:
367 return self._FindFileContainingSymbolInDb(symbol)
368 except KeyError:
369 pass
370
371 try:
372 return self._file_desc_by_toplevel_extension[symbol]
373 except KeyError:
374 pass
375
376 # Try nested extensions inside a message.
377 message_name, _, extension_name = symbol.rpartition('.')
378 try:
379 message = self.FindMessageTypeByName(message_name)
380 assert message.extensions_by_name[extension_name]
381 return message.file
382 except KeyError:
Brian Silverman9c614bc2016-02-15 20:20:02 -0500383 raise KeyError('Cannot find a file containing %s' % symbol)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500384
385 def FindMessageTypeByName(self, full_name):
386 """Loads the named descriptor from the pool.
387
388 Args:
389 full_name: The full name of the descriptor to load.
390
391 Returns:
392 The descriptor for the named type.
Austin Schuh40c16522018-10-28 20:27:54 -0700393
394 Raises:
395 KeyError: if the message cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500396 """
397
398 full_name = _NormalizeFullyQualifiedName(full_name)
399 if full_name not in self._descriptors:
Austin Schuh40c16522018-10-28 20:27:54 -0700400 self._FindFileContainingSymbolInDb(full_name)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500401 return self._descriptors[full_name]
402
403 def FindEnumTypeByName(self, full_name):
404 """Loads the named enum descriptor from the pool.
405
406 Args:
407 full_name: The full name of the enum descriptor to load.
408
409 Returns:
410 The enum descriptor for the named type.
Austin Schuh40c16522018-10-28 20:27:54 -0700411
412 Raises:
413 KeyError: if the enum cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500414 """
415
416 full_name = _NormalizeFullyQualifiedName(full_name)
417 if full_name not in self._enum_descriptors:
Austin Schuh40c16522018-10-28 20:27:54 -0700418 self._FindFileContainingSymbolInDb(full_name)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500419 return self._enum_descriptors[full_name]
420
421 def FindFieldByName(self, full_name):
422 """Loads the named field descriptor from the pool.
423
424 Args:
425 full_name: The full name of the field descriptor to load.
426
427 Returns:
428 The field descriptor for the named field.
Austin Schuh40c16522018-10-28 20:27:54 -0700429
430 Raises:
431 KeyError: if the field cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500432 """
433 full_name = _NormalizeFullyQualifiedName(full_name)
434 message_name, _, field_name = full_name.rpartition('.')
435 message_descriptor = self.FindMessageTypeByName(message_name)
436 return message_descriptor.fields_by_name[field_name]
437
Austin Schuh40c16522018-10-28 20:27:54 -0700438 def FindOneofByName(self, full_name):
439 """Loads the named oneof descriptor from the pool.
440
441 Args:
442 full_name: The full name of the oneof descriptor to load.
443
444 Returns:
445 The oneof descriptor for the named oneof.
446
447 Raises:
448 KeyError: if the oneof cannot be found in the pool.
449 """
450 full_name = _NormalizeFullyQualifiedName(full_name)
451 message_name, _, oneof_name = full_name.rpartition('.')
452 message_descriptor = self.FindMessageTypeByName(message_name)
453 return message_descriptor.oneofs_by_name[oneof_name]
454
Brian Silverman9c614bc2016-02-15 20:20:02 -0500455 def FindExtensionByName(self, full_name):
456 """Loads the named extension descriptor from the pool.
457
458 Args:
459 full_name: The full name of the extension descriptor to load.
460
461 Returns:
462 A FieldDescriptor, describing the named extension.
Austin Schuh40c16522018-10-28 20:27:54 -0700463
464 Raises:
465 KeyError: if the extension cannot be found in the pool.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500466 """
467 full_name = _NormalizeFullyQualifiedName(full_name)
Austin Schuh40c16522018-10-28 20:27:54 -0700468 try:
469 # The proto compiler does not give any link between the FileDescriptor
470 # and top-level extensions unless the FileDescriptorProto is added to
471 # the DescriptorDatabase, but this can impact memory usage.
472 # So we registered these extensions by name explicitly.
473 return self._toplevel_extensions[full_name]
474 except KeyError:
475 pass
Brian Silverman9c614bc2016-02-15 20:20:02 -0500476 message_name, _, extension_name = full_name.rpartition('.')
477 try:
478 # Most extensions are nested inside a message.
479 scope = self.FindMessageTypeByName(message_name)
480 except KeyError:
481 # Some extensions are defined at file scope.
Austin Schuh40c16522018-10-28 20:27:54 -0700482 scope = self._FindFileContainingSymbolInDb(full_name)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500483 return scope.extensions_by_name[extension_name]
484
Austin Schuh40c16522018-10-28 20:27:54 -0700485 def FindExtensionByNumber(self, message_descriptor, number):
486 """Gets the extension of the specified message with the specified number.
487
488 Extensions have to be registered to this pool by calling
489 AddExtensionDescriptor.
490
491 Args:
492 message_descriptor: descriptor of the extended message.
493 number: integer, number of the extension field.
494
495 Returns:
496 A FieldDescriptor describing the extension.
497
498 Raises:
499 KeyError: when no extension with the given number is known for the
500 specified message.
501 """
502 return self._extensions_by_number[message_descriptor][number]
503
504 def FindAllExtensions(self, message_descriptor):
505 """Gets all the known extension of a given message.
506
507 Extensions have to be registered to this pool by calling
508 AddExtensionDescriptor.
509
510 Args:
511 message_descriptor: descriptor of the extended message.
512
513 Returns:
514 A list of FieldDescriptor describing the extensions.
515 """
516 return list(self._extensions_by_number[message_descriptor].values())
517
518 def FindServiceByName(self, full_name):
519 """Loads the named service descriptor from the pool.
520
521 Args:
522 full_name: The full name of the service descriptor to load.
523
524 Returns:
525 The service descriptor for the named service.
526
527 Raises:
528 KeyError: if the service cannot be found in the pool.
529 """
530 full_name = _NormalizeFullyQualifiedName(full_name)
531 if full_name not in self._service_descriptors:
532 self._FindFileContainingSymbolInDb(full_name)
533 return self._service_descriptors[full_name]
534
535 def _FindFileContainingSymbolInDb(self, symbol):
536 """Finds the file in descriptor DB containing the specified symbol.
537
538 Args:
539 symbol: The name of the symbol to search for.
540
541 Returns:
542 A FileDescriptor that contains the specified symbol.
543
544 Raises:
545 KeyError: if the file cannot be found in the descriptor database.
546 """
547 try:
548 file_proto = self._internal_db.FindFileContainingSymbol(symbol)
549 except KeyError as error:
550 if self._descriptor_db:
551 file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
552 else:
553 raise error
554 if not file_proto:
555 raise KeyError('Cannot find a file containing %s' % symbol)
556 return self._ConvertFileProtoToFileDescriptor(file_proto)
557
  def _ConvertFileProtoToFileDescriptor(self, file_proto):
    """Creates a FileDescriptor from a proto or returns a cached copy.

    On a cache miss this builds the file's messages, enums, file-scope
    extensions and services. It resolves field types against a scope made
    from the file's dependencies plus the file's own types, registers the
    proto in the internal database, and caches the result by file name.

    Args:
      file_proto: The proto to convert.

    Returns:
      A FileDescriptor matching the passed in proto.
    """

    if file_proto.name not in self._file_descriptors:
      # NOTE(review): _GetDeps is defined outside this view; presumably it
      # yields dependency FileDescriptors (possibly transitively) — confirm.
      built_deps = list(self._GetDeps(file_proto.dependency))
      direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
      # public_dependency holds indexes into the dependency list.
      public_deps = [direct_deps[i] for i in file_proto.public_dependency]

      file_descriptor = descriptor.FileDescriptor(
          pool=self,
          name=file_proto.name,
          package=file_proto.package,
          syntax=file_proto.syntax,
          options=_OptionsOrNone(file_proto),
          serialized_pb=file_proto.SerializeToString(),
          dependencies=direct_deps,
          public_dependencies=public_deps)
      # Maps dot-prefixed full names to message/enum descriptors visible to
      # this file; filled from dependencies, then from the file's own types.
      scope = {}

      # This loop extracts all the message and enum types from all the
      # dependencies of the file_proto. This is necessary to create the
      # scope of available message types when defining the passed in
      # file proto.
      for dependency in built_deps:
        scope.update(self._ExtractSymbols(
            dependency.message_types_by_name.values()))
        scope.update((_PrefixWithDot(enum.full_name), enum)
                     for enum in dependency.enum_types_by_name.values())

      # Build the file's messages (recursively, with nested types); field
      # types are not resolved yet at this point.
      for message_type in file_proto.message_type:
        message_desc = self._ConvertMessageDescriptor(
            message_type, file_proto.package, file_descriptor, scope,
            file_proto.syntax)
        file_descriptor.message_types_by_name[message_desc.name] = (
            message_desc)

      for enum_type in file_proto.enum_type:
        file_descriptor.enum_types_by_name[enum_type.name] = (
            self._ConvertEnumDescriptor(enum_type, file_proto.package,
                                        file_descriptor, None, scope))

      # File-scope extensions: resolve extendee and field type immediately,
      # since scope already contains every type this file can see.
      for index, extension_proto in enumerate(file_proto.extension):
        extension_desc = self._MakeFieldDescriptor(
            extension_proto, file_proto.package, index, file_descriptor,
            is_extension=True)
        extension_desc.containing_type = self._GetTypeFromScope(
            file_descriptor.package, extension_proto.extendee, scope)
        self._SetFieldType(extension_proto, extension_desc,
                           file_descriptor.package, scope)
        file_descriptor.extensions_by_name[extension_desc.name] = (
            extension_desc)

      # Second pass: now that all types exist, resolve field types.
      for desc_proto in file_proto.message_type:
        self._SetAllFieldTypes(file_proto.package, desc_proto, scope)

      if file_proto.package:
        desc_proto_prefix = _PrefixWithDot(file_proto.package)
      else:
        desc_proto_prefix = ''

      # Re-read top-level messages from scope.
      # NOTE(review): these appear to be the same objects stored in the first
      # loop above; confirm whether this pass is still needed.
      for desc_proto in file_proto.message_type:
        desc = self._GetTypeFromScope(
            desc_proto_prefix, desc_proto.name, scope)
        file_descriptor.message_types_by_name[desc_proto.name] = desc

      for index, service_proto in enumerate(file_proto.service):
        file_descriptor.services_by_name[service_proto.name] = (
            self._MakeServiceDescriptor(service_proto, index, scope,
                                        file_proto.package, file_descriptor))

      # Record the proto so later symbol lookups in the internal DB find it,
      # then cache the built descriptor.
      self.Add(file_proto)
      self._file_descriptors[file_proto.name] = file_descriptor

    return self._file_descriptors[file_proto.name]
642
  def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
                                scope=None, syntax=None):
    """Builds a Descriptor from a DescriptorProto and registers it in the pool.

    Nested messages and enums are converted recursively first. Fields and
    extensions are created with message_type/enum_type unset; their types are
    resolved later by _SetAllFieldTypes.

    Args:
      desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
      package: The package (or enclosing message full name) the proto is
        located in; joined with desc_proto.name to form the full name.
      file_desc: The file containing this message.
      scope: Dict mapping dot-prefixed full names to message and enum
        descriptors; updated in place with the new descriptors. A fresh dict
        is used if None.
      syntax: string indicating syntax of the file ("proto2" or "proto3")

    Returns:
      The added descriptor.
    """

    if package:
      desc_name = '.'.join((package, desc_proto.name))
    else:
      desc_name = desc_proto.name

    if file_desc is None:
      file_name = None
    else:
      file_name = file_desc.name

    if scope is None:
      scope = {}

    # Nested types use this message's full name as their "package".
    nested = [
        self._ConvertMessageDescriptor(
            nested, desc_name, file_desc, scope, syntax)
        for nested in desc_proto.nested_type]
    enums = [
        self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
        for enum in desc_proto.enum_type]
    fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
              for index, field in enumerate(desc_proto.field)]
    extensions = [
        self._MakeFieldDescriptor(extension, desc_name, index, file_desc,
                                  is_extension=True)
        for index, extension in enumerate(desc_proto.extension)]
    # Oneofs start with empty field lists; they are filled in below.
    oneofs = [
        descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)),
                                   index, None, [], desc.options)
        for index, desc in enumerate(desc_proto.oneof_decl)]
    extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    if extension_ranges:
      is_extendable = True
    else:
      is_extendable = False
    desc = descriptor.Descriptor(
        name=desc_proto.name,
        full_name=desc_name,
        filename=file_name,
        containing_type=None,
        fields=fields,
        oneofs=oneofs,
        nested_types=nested,
        enum_types=enums,
        extensions=extensions,
        options=_OptionsOrNone(desc_proto),
        is_extendable=is_extendable,
        extension_ranges=extension_ranges,
        file=file_desc,
        serialized_start=None,
        serialized_end=None,
        syntax=syntax)
    # Back-link nested messages/enums to this message.
    for nested in desc.nested_types:
      nested.containing_type = desc
    for enum in desc.enum_types:
      enum.containing_type = desc
    # Link each field declaring oneof_index to its OneofDescriptor, both ways.
    for field_index, field_desc in enumerate(desc_proto.field):
      if field_desc.HasField('oneof_index'):
        oneof_index = field_desc.oneof_index
        oneofs[oneof_index].fields.append(fields[field_index])
        fields[field_index].containing_oneof = oneofs[oneof_index]

    scope[_PrefixWithDot(desc_name)] = desc
    # Warn (do not raise) on clashes before overwriting the registry entry.
    self._CheckConflictRegister(desc)
    self._descriptors[desc_name] = desc
    return desc
724
725 def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
726 containing_type=None, scope=None):
727 """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
728
729 Args:
730 enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
731 package: Optional package name for the new message EnumDescriptor.
732 file_desc: The file containing the enum descriptor.
733 containing_type: The type containing this enum.
734 scope: Scope containing available types.
735
736 Returns:
737 The added descriptor
738 """
739
740 if package:
741 enum_name = '.'.join((package, enum_proto.name))
742 else:
743 enum_name = enum_proto.name
744
745 if file_desc is None:
746 file_name = None
747 else:
748 file_name = file_desc.name
749
750 values = [self._MakeEnumValueDescriptor(value, index)
751 for index, value in enumerate(enum_proto.value)]
752 desc = descriptor.EnumDescriptor(name=enum_proto.name,
753 full_name=enum_name,
754 filename=file_name,
755 file=file_desc,
756 values=values,
757 containing_type=containing_type,
Austin Schuh40c16522018-10-28 20:27:54 -0700758 options=_OptionsOrNone(enum_proto))
Brian Silverman9c614bc2016-02-15 20:20:02 -0500759 scope['.%s' % enum_name] = desc
Austin Schuh40c16522018-10-28 20:27:54 -0700760 self._CheckConflictRegister(desc)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500761 self._enum_descriptors[enum_name] = desc
762 return desc
763
764 def _MakeFieldDescriptor(self, field_proto, message_name, index,
Austin Schuh40c16522018-10-28 20:27:54 -0700765 file_desc, is_extension=False):
Brian Silverman9c614bc2016-02-15 20:20:02 -0500766 """Creates a field descriptor from a FieldDescriptorProto.
767
768 For message and enum type fields, this method will do a look up
769 in the pool for the appropriate descriptor for that type. If it
770 is unavailable, it will fall back to the _source function to
771 create it. If this type is still unavailable, construction will
772 fail.
773
774 Args:
775 field_proto: The proto describing the field.
776 message_name: The name of the containing message.
777 index: Index of the field
Austin Schuh40c16522018-10-28 20:27:54 -0700778 file_desc: The file containing the field descriptor.
Brian Silverman9c614bc2016-02-15 20:20:02 -0500779 is_extension: Indication that this field is for an extension.
780
781 Returns:
782 An initialized FieldDescriptor object
783 """
784
785 if message_name:
786 full_name = '.'.join((message_name, field_proto.name))
787 else:
788 full_name = field_proto.name
789
790 return descriptor.FieldDescriptor(
791 name=field_proto.name,
792 full_name=full_name,
793 index=index,
794 number=field_proto.number,
795 type=field_proto.type,
796 cpp_type=None,
797 message_type=None,
798 enum_type=None,
799 containing_type=None,
800 label=field_proto.label,
801 has_default_value=False,
802 default_value=None,
803 is_extension=is_extension,
804 extension_scope=None,
Austin Schuh40c16522018-10-28 20:27:54 -0700805 options=_OptionsOrNone(field_proto),
806 file=file_desc)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500807
def _SetAllFieldTypes(self, package, desc_proto, scope):
  """Resolves the field types of a message and of all its nested messages.

  Every extension declared inside the message also gets its containing_type
  (the extendee) resolved here.

  Args:
    package: The package of desc_proto (or, when recursing, the dotted name
      of the enclosing message). A leading dot is optional.
    desc_proto: The message descriptor proto whose built descriptor is
      updated.
    scope: Enclosing scope of available types.
  """
  dotted_package = _PrefixWithDot(package)
  message_desc = self._GetTypeFromScope(dotted_package, desc_proto.name, scope)

  # Names declared inside this message resolve relative to the message's own
  # dotted name, e.g. '.pkg.Outer' for the fields of message Outer.
  if dotted_package != '.':
    child_prefix = '.'.join([dotted_package, desc_proto.name])
  else:
    child_prefix = _PrefixWithDot(desc_proto.name)

  for proto, desc in zip(desc_proto.field, message_desc.fields):
    self._SetFieldType(proto, desc, child_prefix, scope)

  for ext_proto, ext_desc in zip(desc_proto.extension,
                                 message_desc.extensions):
    ext_desc.containing_type = self._GetTypeFromScope(
        child_prefix, ext_proto.extendee, scope)
    self._SetFieldType(ext_proto, ext_desc, child_prefix, scope)

  for child_proto in desc_proto.nested_type:
    self._SetAllFieldTypes(child_prefix, child_proto, scope)
839
def _SetFieldType(self, field_proto, field_desc, package, scope):
  """Sets the field's type, cpp_type, message_type and enum_type.

  Also fills in the field's has_default_value and default_value.

  Args:
    field_proto: Data about the field in proto format. If its 'type' is
      unset, this method writes TYPE_MESSAGE or TYPE_ENUM into it (depending
      on what type_name resolves to), so the proto is modified in place.
    field_desc: The descriptor to modify.
    package: The package the field's container is in; type_name is resolved
      relative to it.
    scope: Enclosing scope of available types.
  """
  field_types = descriptor.FieldDescriptor

  if field_proto.type_name:
    desc = self._GetTypeFromScope(package, field_proto.type_name, scope)
  else:
    desc = None

  if not field_proto.HasField('type'):
    # Only type_name was given: infer message vs. enum from what it names.
    if isinstance(desc, descriptor.Descriptor):
      field_proto.type = field_types.TYPE_MESSAGE
    else:
      field_proto.type = field_types.TYPE_ENUM

  field_desc.cpp_type = field_types.ProtoTypeToCppProtoType(field_proto.type)

  if field_proto.type in (field_types.TYPE_MESSAGE, field_types.TYPE_GROUP):
    field_desc.message_type = desc

  if field_proto.type == field_types.TYPE_ENUM:
    field_desc.enum_type = desc

  if field_proto.label == field_types.LABEL_REPEATED:
    field_desc.has_default_value = False
    field_desc.default_value = []
  elif field_proto.HasField('default_value'):
    field_desc.has_default_value = True
    if field_proto.type in (field_types.TYPE_DOUBLE, field_types.TYPE_FLOAT):
      field_desc.default_value = float(field_proto.default_value)
    elif field_proto.type == field_types.TYPE_STRING:
      field_desc.default_value = field_proto.default_value
    elif field_proto.type == field_types.TYPE_BOOL:
      field_desc.default_value = field_proto.default_value.lower() == 'true'
    elif field_proto.type == field_types.TYPE_ENUM:
      # The default is given as an enum value name; store its number.
      field_desc.default_value = field_desc.enum_type.values_by_name[
          field_proto.default_value].number
    elif field_proto.type == field_types.TYPE_BYTES:
      # Bytes defaults are C-escaped text in the proto.
      field_desc.default_value = text_encoding.CUnescape(
          field_proto.default_value)
    else:
      # All other types are of the "int" type.
      field_desc.default_value = int(field_proto.default_value)
  else:
    field_desc.has_default_value = False
    if field_proto.type in (field_types.TYPE_DOUBLE, field_types.TYPE_FLOAT):
      field_desc.default_value = 0.0
    elif field_proto.type == field_types.TYPE_STRING:
      field_desc.default_value = u''
    elif field_proto.type == field_types.TYPE_BOOL:
      field_desc.default_value = False
    elif field_proto.type == field_types.TYPE_ENUM:
      # Without an explicit default, the enum's first value is used.
      field_desc.default_value = field_desc.enum_type.values[0].number
    elif field_proto.type == field_types.TYPE_BYTES:
      field_desc.default_value = b''
    elif field_proto.type in (field_types.TYPE_MESSAGE,
                              field_types.TYPE_GROUP):
      # Singular message/group fields have no scalar default. Use None
      # rather than falling through to the integer 0 below, matching the
      # descriptors produced by protoc-generated code.
      field_desc.default_value = None
    else:
      # All other types are of the "int" type.
      field_desc.default_value = 0

  field_desc.type = field_proto.type
909
def _MakeEnumValueDescriptor(self, value_proto, index):
  """Builds an EnumValueDescriptor from an enum value proto.

  The descriptor is created without an owning enum (type=None).

  Args:
    value_proto: The proto describing the enum value.
    index: The index of the enum value within its enum.

  Returns:
    A new EnumValueDescriptor object.
  """
  value_kwargs = dict(
      name=value_proto.name,
      index=index,
      number=value_proto.number,
      options=_OptionsOrNone(value_proto),
      type=None)
  return descriptor.EnumValueDescriptor(**value_kwargs)
927
def _MakeServiceDescriptor(self, service_proto, service_index, scope,
                           package, file_desc):
  """Builds a ServiceDescriptor and records it in this pool.

  Args:
    service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message.
    service_index: The index of the service in the File.
    scope: Dict mapping short and full symbols to message and enum types.
    package: Optional package name. It qualifies the service's full name and
      is used to resolve the methods' input/output types.
    file_desc: The file containing the service descriptor.

  Returns:
    The added descriptor.
  """
  service_name = ('.'.join((package, service_proto.name)) if package
                  else service_proto.name)

  methods = []
  for method_index, method_proto in enumerate(service_proto.method):
    methods.append(self._MakeMethodDescriptor(
        method_proto, service_name, package, scope, method_index))

  service_desc = descriptor.ServiceDescriptor(
      name=service_proto.name,
      full_name=service_name,
      index=service_index,
      methods=methods,
      options=_OptionsOrNone(service_proto),
      file=file_desc)
  # Run the pool's conflict check before storing the service by full name.
  self._CheckConflictRegister(service_desc)
  self._service_descriptors[service_name] = service_desc
  return service_desc
959 return desc
960
961 def _MakeMethodDescriptor(self, method_proto, service_name, package, scope,
962 index):
963 """Creates a method descriptor from a MethodDescriptorProto.
964
965 Args:
966 method_proto: The proto describing the method.
967 service_name: The name of the containing service.
968 package: Optional package name to look up for types.
969 scope: Scope containing available types.
970 index: Index of the method in the service.
971
972 Returns:
973 An initialized MethodDescriptor object.
974 """
975 full_name = '.'.join((service_name, method_proto.name))
976 input_type = self._GetTypeFromScope(
977 package, method_proto.input_type, scope)
978 output_type = self._GetTypeFromScope(
979 package, method_proto.output_type, scope)
980 return descriptor.MethodDescriptor(name=method_proto.name,
981 full_name=full_name,
982 index=index,
983 containing_service=None,
984 input_type=input_type,
985 output_type=output_type,
986 options=_OptionsOrNone(method_proto))
987
def _ExtractSymbols(self, descriptors):
  """Pulls out all the symbols from message descriptors.

  Each message is produced first, then (recursively) everything nested in
  it, then the enums it declares directly.

  Args:
    descriptors: The message descriptors to extract symbols from.

  Yields:
    (dotted full name, descriptor) pairs for every message and enum found.
  """
  for message_desc in descriptors:
    yield (_PrefixWithDot(message_desc.full_name), message_desc)
    for pair in self._ExtractSymbols(message_desc.nested_types):
      yield pair
    enum_pairs = ((_PrefixWithDot(enum_desc.full_name), enum_desc)
                  for enum_desc in message_desc.enum_types)
    for pair in enum_pairs:
      yield pair
1003
def _GetDeps(self, dependencies):
  """Recursively finds dependencies for file protos.

  The walk is depth-first and pre-order: a file is produced before the files
  it imports, and imports are visited in declaration order. Each file is
  produced at most once, even when several imports reach it.

  Args:
    dependencies: The names of the files being depended on.

  Yields:
    Each direct and indirect dependency (file descriptors).
  """
  seen = set()
  for dependency in dependencies:
    if dependency in seen:
      continue
    # An explicit stack instead of recursion, so long import chains cannot
    # hit the interpreter's recursion limit.
    pending = [self.FindFileByName(dependency)]
    while pending:
      file_desc = pending.pop()
      if file_desc.name in seen:
        continue
      seen.add(file_desc.name)
      yield file_desc
      # Push in reverse so the first-listed import is visited first.
      pending.extend(reversed(list(file_desc.dependencies)))
1019
def _GetTypeFromScope(self, package, type_name, scope):
  """Finds a given type name in the current scope.

  An exact match on type_name wins. Otherwise type_name is treated as
  relative to package and qualified with progressively shorter prefixes of
  it. For package '.a.b' and type 'Foo', the candidates tried are
  '.a.b.Foo', '.a.Foo' and then '.Foo'.

  Args:
    package: The package the proto should be located in.
    type_name: The name of the type to be found in the scope.
    scope: Dict mapping short and full symbols to message and enum types.

  Returns:
    The descriptor for the requested type.

  Raises:
    KeyError: If neither type_name nor any qualified candidate is in scope.
  """
  if type_name in scope:
    return scope[type_name]

  parts = _PrefixWithDot(package).split('.')
  # Innermost enclosing scope first, then drop one trailing component.
  for depth in range(len(parts), 0, -1):
    candidate = '.'.join(parts[:depth] + [type_name])
    if candidate in scope:
      return scope[candidate]

  return scope[type_name]
1041
1042
def _PrefixWithDot(name):
  """Returns name prefixed with '.', unless it already starts with one."""
  if name.startswith('.'):
    return name
  return '.' + name
1045
1046
# The module-wide pool handed out by Default(): the C extension's pool when
# C descriptors are in use, otherwise a new pure-Python DescriptorPool.
# TODO(amauryfa): This pool could be constructed from Python code, when we
# support a flag like 'use_cpp_generated_pool=True'.
_DEFAULT = (
    descriptor._message.default_pool  # pylint: disable=protected-access
    if _USE_C_DESCRIPTORS else DescriptorPool())
1054
1055
def Default():
  """Returns the module-level default descriptor pool.

  Every call returns the same object: the _DEFAULT pool selected when this
  module was loaded.
  """
  return _DEFAULT