blob: 5ad869f49a76da39112a8e7ad30857f9837e53ed [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001# Protocol Buffers - Google's data interchange format
2# Copyright 2008 Google Inc. All rights reserved.
3# https://developers.google.com/protocol-buffers/
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions are
7# met:
8#
9# * Redistributions of source code must retain the above copyright
10# notice, this list of conditions and the following disclaimer.
11# * Redistributions in binary form must reproduce the above
12# copyright notice, this list of conditions and the following disclaimer
13# in the documentation and/or other materials provided with the
14# distribution.
15# * Neither the name of Google Inc. nor the names of its
16# contributors may be used to endorse or promote products derived from
17# this software without specific prior written permission.
18#
19# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31"""A database of Python protocol buffer generated symbols.
32
Austin Schuh40c16522018-10-28 20:27:54 -070033SymbolDatabase is the MessageFactory for messages generated at compile time,
34and makes it easy to create new instances of a registered type, given only the
35type's protocol buffer symbol name.
Brian Silverman9c614bc2016-02-15 20:20:02 -050036
37Example usage:
38
39 db = symbol_database.SymbolDatabase()
40
41 # Register symbols of interest, from one or multiple files.
42 db.RegisterFileDescriptor(my_proto_pb2.DESCRIPTOR)
43 db.RegisterMessage(my_proto_pb2.MyMessage)
44 db.RegisterEnumDescriptor(my_proto_pb2.MyEnum.DESCRIPTOR)
45
46 # The database can be used as a MessageFactory, to generate types based on
47 # their name:
48 types = db.GetMessages(['my_proto.proto'])
49 my_message_instance = types['MyMessage']()
50
51 # The database's underlying descriptor pool can be queried, so it's not
52 # necessary to know a type's filename to be able to generate it:
53 filename = db.pool.FindFileContainingSymbol('MyMessage').name
54 my_message_instance = db.GetMessages([filename])['MyMessage']()
55
56 # This functionality is also provided directly via a convenience method:
57 my_message_instance = db.GetSymbol('MyMessage')()
58"""
59
60
61from google.protobuf import descriptor_pool
Austin Schuh40c16522018-10-28 20:27:54 -070062from google.protobuf import message_factory
Brian Silverman9c614bc2016-02-15 20:20:02 -050063
64
class SymbolDatabase(message_factory.MessageFactory):
  """A database of Python generated symbols.

  Message classes are recorded in ``self._classes`` keyed by their descriptor,
  and every registered descriptor is added to ``self.pool``. (Both attributes
  are presumably provided by the MessageFactory base class; they are not
  defined here.)
  """

  def RegisterMessage(self, message):
    """Registers the given message class in the local database.

    Calls to GetSymbol() and GetMessages() will return messages registered here.

    Args:
      message: a message.Message subclass (the generated class), to be
        registered.

    Returns:
      The provided message class.
    """
    desc = message.DESCRIPTOR
    # Remember the concrete class so GetSymbol()/GetMessages() can map a
    # descriptor back to it.
    self._classes[desc] = message
    self.RegisterMessageDescriptor(desc)
    return message

  def RegisterMessageDescriptor(self, message_descriptor):
    """Registers the given message descriptor in the local database.

    Args:
      message_descriptor: a descriptor.MessageDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddDescriptor(message_descriptor)
    return message_descriptor

  def RegisterEnumDescriptor(self, enum_descriptor):
    """Registers the given enum descriptor in the local database.

    Args:
      enum_descriptor: a descriptor.EnumDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddEnumDescriptor(enum_descriptor)
    return enum_descriptor

  def RegisterServiceDescriptor(self, service_descriptor):
    """Registers the given service descriptor in the local database.

    Args:
      service_descriptor: a descriptor.ServiceDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddServiceDescriptor(service_descriptor)
    # Previously fell through and returned None despite the documented
    # contract; return the argument like RegisterEnumDescriptor does.
    return service_descriptor

  def RegisterFileDescriptor(self, file_descriptor):
    """Registers the given file descriptor in the local database.

    Args:
      file_descriptor: a descriptor.FileDescriptor.

    Returns:
      The provided descriptor.
    """
    self.pool.AddFileDescriptor(file_descriptor)
    # Previously fell through and returned None despite the documented
    # contract; return the argument like RegisterEnumDescriptor does.
    return file_descriptor

  def GetSymbol(self, symbol):
    """Tries to find a symbol in the local database.

    Currently, this method only returns message classes (message.Message
    subclasses); however, it may be extended in the future to support other
    symbol types.

    Args:
      symbol: A str, a fully-qualified protocol buffer symbol name.

    Returns:
      A Python class corresponding to the symbol.

    Raises:
      KeyError: if the symbol could not be found, or if its descriptor is
        known to the pool but no class has been registered for it.
    """
    return self._classes[self.pool.FindMessageTypeByName(symbol)]

  def GetMessages(self, files):
    # TODO(amauryfa): Fix the differences with MessageFactory.
    """Gets all registered messages from a specified file.

    Only messages already created and registered will be returned; (this is the
    case for imported _pb2 modules)
    But unlike MessageFactory, this version also returns already defined nested
    messages, but does not register any message extensions.

    Args:
      files: The file names to extract messages from.

    Returns:
      A dictionary mapping fully-qualified proto names to the message classes.

    Raises:
      KeyError: if a file could not be found.
    """

    def _GetAllMessages(desc):
      """Walks a message Descriptor, recursively yielding it and all nested
      message descriptors (depth-first, parent before children)."""
      yield desc
      for msg_desc in desc.nested_types:
        for nested_desc in _GetAllMessages(msg_desc):
          yield nested_desc

    result = {}
    for file_name in files:
      file_desc = self.pool.FindFileByName(file_name)
      for msg_desc in file_desc.message_types_by_name.values():
        for desc in _GetAllMessages(msg_desc):
          try:
            result[desc.full_name] = self._classes[desc]
          except KeyError:
            # This descriptor has no registered class; it is intentionally
            # omitted from the result.
            pass
    return result
182
Austin Schuh40c16522018-10-28 20:27:54 -0700183
# Module-level singleton: a SymbolDatabase wired to the default descriptor
# pool. It is created once, when this module is first imported.
_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default())


def Default():
  """Returns the default SymbolDatabase.

  Returns:
    The module-level SymbolDatabase instance backed by
    descriptor_pool.Default(). Every call returns the same object.
  """
  shared_database = _DEFAULT
  return shared_database