#! /usr/bin/env python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Tests python protocol buffers against the golden message.

Note that the golden messages exercise every known field type, thus this
test ends up exercising and verifying nearly all of the parsing and
serialization code in the whole library.

TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
sense to call this a test of the "message" module, which only declares an
abstract interface.
"""
43
44__author__ = 'gps@google.com (Gregory P. Smith)'
45
46
47import collections
48import copy
49import math
50import operator
51import pickle
52import six
53import sys
54
55try:
56 import unittest2 as unittest
57except ImportError:
58 import unittest
59from google.protobuf.internal import _parameterized
60from google.protobuf import map_unittest_pb2
61from google.protobuf import unittest_pb2
62from google.protobuf import unittest_proto3_arena_pb2
63from google.protobuf.internal import any_test_pb2
64from google.protobuf.internal import api_implementation
65from google.protobuf.internal import packed_field_test_pb2
66from google.protobuf.internal import test_util
67from google.protobuf import message
68
# Python 3 has no separate `long` type; alias it to `int` so the long(...)
# calls made by the tests in this file run on both major versions.
if six.PY3:
  long = int
71
# Hand-rolled float classification helpers: Python releases before 2.6 lack
# isinf() and isnan(), so this module supplies its own.
def isnan(val):
  """Return True when val is NaN, i.e. when it compares unequal to itself."""
  return val != val


def isinf(val):
  """Return True when val is an infinity of either sign."""
  if isnan(val):
    return False
  # Infinity times zero is NaN, whereas any finite value times zero is not.
  return isnan(val * 0)


def IsPosInf(val):
  """Return True when val is positive infinity."""
  if not isinf(val):
    return False
  return val > 0


def IsNegInf(val):
  """Return True when val is negative infinity."""
  if not isinf(val):
    return False
  return val < 0
84
85
86@_parameterized.Parameters(
87 (unittest_pb2),
88 (unittest_proto3_arena_pb2))
89class MessageTest(unittest.TestCase):
90
91 def testBadUtf8String(self, message_module):
92 if api_implementation.Type() != 'python':
93 self.skipTest("Skipping testBadUtf8String, currently only the python "
94 "api implementation raises UnicodeDecodeError when a "
95 "string field contains bad utf-8.")
96 bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
97 with self.assertRaises(UnicodeDecodeError) as context:
98 message_module.TestAllTypes.FromString(bad_utf8_data)
99 self.assertIn('TestAllTypes.optional_string', str(context.exception))
100
101 def testGoldenMessage(self, message_module):
102 # Proto3 doesn't have the "default_foo" members or foreign enums,
103 # and doesn't preserve unknown fields, so for proto3 we use a golden
104 # message that doesn't have these fields set.
105 if message_module is unittest_pb2:
106 golden_data = test_util.GoldenFileData(
107 'golden_message_oneof_implemented')
108 else:
109 golden_data = test_util.GoldenFileData('golden_message_proto3')
110
111 golden_message = message_module.TestAllTypes()
112 golden_message.ParseFromString(golden_data)
113 if message_module is unittest_pb2:
114 test_util.ExpectAllFieldsSet(self, golden_message)
115 self.assertEqual(golden_data, golden_message.SerializeToString())
116 golden_copy = copy.deepcopy(golden_message)
117 self.assertEqual(golden_data, golden_copy.SerializeToString())
118
119 def testGoldenPackedMessage(self, message_module):
120 golden_data = test_util.GoldenFileData('golden_packed_fields_message')
121 golden_message = message_module.TestPackedTypes()
122 golden_message.ParseFromString(golden_data)
123 all_set = message_module.TestPackedTypes()
124 test_util.SetAllPackedFields(all_set)
125 self.assertEqual(all_set, golden_message)
126 self.assertEqual(golden_data, all_set.SerializeToString())
127 golden_copy = copy.deepcopy(golden_message)
128 self.assertEqual(golden_data, golden_copy.SerializeToString())
129
130 def testPickleSupport(self, message_module):
131 golden_data = test_util.GoldenFileData('golden_message')
132 golden_message = message_module.TestAllTypes()
133 golden_message.ParseFromString(golden_data)
134 pickled_message = pickle.dumps(golden_message)
135
136 unpickled_message = pickle.loads(pickled_message)
137 self.assertEqual(unpickled_message, golden_message)
138
139 def testPositiveInfinity(self, message_module):
140 if message_module is unittest_pb2:
141 golden_data = (b'\x5D\x00\x00\x80\x7F'
142 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
143 b'\xCD\x02\x00\x00\x80\x7F'
144 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
145 else:
146 golden_data = (b'\x5D\x00\x00\x80\x7F'
147 b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
148 b'\xCA\x02\x04\x00\x00\x80\x7F'
149 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
150
151 golden_message = message_module.TestAllTypes()
152 golden_message.ParseFromString(golden_data)
153 self.assertTrue(IsPosInf(golden_message.optional_float))
154 self.assertTrue(IsPosInf(golden_message.optional_double))
155 self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
156 self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
157 self.assertEqual(golden_data, golden_message.SerializeToString())
158
159 def testNegativeInfinity(self, message_module):
160 if message_module is unittest_pb2:
161 golden_data = (b'\x5D\x00\x00\x80\xFF'
162 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
163 b'\xCD\x02\x00\x00\x80\xFF'
164 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
165 else:
166 golden_data = (b'\x5D\x00\x00\x80\xFF'
167 b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
168 b'\xCA\x02\x04\x00\x00\x80\xFF'
169 b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
170
171 golden_message = message_module.TestAllTypes()
172 golden_message.ParseFromString(golden_data)
173 self.assertTrue(IsNegInf(golden_message.optional_float))
174 self.assertTrue(IsNegInf(golden_message.optional_double))
175 self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
176 self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
177 self.assertEqual(golden_data, golden_message.SerializeToString())
178
179 def testNotANumber(self, message_module):
180 golden_data = (b'\x5D\x00\x00\xC0\x7F'
181 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
182 b'\xCD\x02\x00\x00\xC0\x7F'
183 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
184 golden_message = message_module.TestAllTypes()
185 golden_message.ParseFromString(golden_data)
186 self.assertTrue(isnan(golden_message.optional_float))
187 self.assertTrue(isnan(golden_message.optional_double))
188 self.assertTrue(isnan(golden_message.repeated_float[0]))
189 self.assertTrue(isnan(golden_message.repeated_double[0]))
190
191 # The protocol buffer may serialize to any one of multiple different
192 # representations of a NaN. Rather than verify a specific representation,
193 # verify the serialized string can be converted into a correctly
194 # behaving protocol buffer.
195 serialized = golden_message.SerializeToString()
196 message = message_module.TestAllTypes()
197 message.ParseFromString(serialized)
198 self.assertTrue(isnan(message.optional_float))
199 self.assertTrue(isnan(message.optional_double))
200 self.assertTrue(isnan(message.repeated_float[0]))
201 self.assertTrue(isnan(message.repeated_double[0]))
202
203 def testPositiveInfinityPacked(self, message_module):
204 golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
205 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
206 golden_message = message_module.TestPackedTypes()
207 golden_message.ParseFromString(golden_data)
208 self.assertTrue(IsPosInf(golden_message.packed_float[0]))
209 self.assertTrue(IsPosInf(golden_message.packed_double[0]))
210 self.assertEqual(golden_data, golden_message.SerializeToString())
211
212 def testNegativeInfinityPacked(self, message_module):
213 golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
214 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
215 golden_message = message_module.TestPackedTypes()
216 golden_message.ParseFromString(golden_data)
217 self.assertTrue(IsNegInf(golden_message.packed_float[0]))
218 self.assertTrue(IsNegInf(golden_message.packed_double[0]))
219 self.assertEqual(golden_data, golden_message.SerializeToString())
220
221 def testNotANumberPacked(self, message_module):
222 golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
223 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
224 golden_message = message_module.TestPackedTypes()
225 golden_message.ParseFromString(golden_data)
226 self.assertTrue(isnan(golden_message.packed_float[0]))
227 self.assertTrue(isnan(golden_message.packed_double[0]))
228
229 serialized = golden_message.SerializeToString()
230 message = message_module.TestPackedTypes()
231 message.ParseFromString(serialized)
232 self.assertTrue(isnan(message.packed_float[0]))
233 self.assertTrue(isnan(message.packed_double[0]))
234
235 def testExtremeFloatValues(self, message_module):
236 message = message_module.TestAllTypes()
237
238 # Most positive exponent, no significand bits set.
239 kMostPosExponentNoSigBits = math.pow(2, 127)
240 message.optional_float = kMostPosExponentNoSigBits
241 message.ParseFromString(message.SerializeToString())
242 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
243
244 # Most positive exponent, one significand bit set.
245 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
246 message.optional_float = kMostPosExponentOneSigBit
247 message.ParseFromString(message.SerializeToString())
248 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
249
250 # Repeat last two cases with values of same magnitude, but negative.
251 message.optional_float = -kMostPosExponentNoSigBits
252 message.ParseFromString(message.SerializeToString())
253 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
254
255 message.optional_float = -kMostPosExponentOneSigBit
256 message.ParseFromString(message.SerializeToString())
257 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
258
259 # Most negative exponent, no significand bits set.
260 kMostNegExponentNoSigBits = math.pow(2, -127)
261 message.optional_float = kMostNegExponentNoSigBits
262 message.ParseFromString(message.SerializeToString())
263 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
264
265 # Most negative exponent, one significand bit set.
266 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
267 message.optional_float = kMostNegExponentOneSigBit
268 message.ParseFromString(message.SerializeToString())
269 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
270
271 # Repeat last two cases with values of the same magnitude, but negative.
272 message.optional_float = -kMostNegExponentNoSigBits
273 message.ParseFromString(message.SerializeToString())
274 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
275
276 message.optional_float = -kMostNegExponentOneSigBit
277 message.ParseFromString(message.SerializeToString())
278 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
279
280 def testExtremeDoubleValues(self, message_module):
281 message = message_module.TestAllTypes()
282
283 # Most positive exponent, no significand bits set.
284 kMostPosExponentNoSigBits = math.pow(2, 1023)
285 message.optional_double = kMostPosExponentNoSigBits
286 message.ParseFromString(message.SerializeToString())
287 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
288
289 # Most positive exponent, one significand bit set.
290 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
291 message.optional_double = kMostPosExponentOneSigBit
292 message.ParseFromString(message.SerializeToString())
293 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
294
295 # Repeat last two cases with values of same magnitude, but negative.
296 message.optional_double = -kMostPosExponentNoSigBits
297 message.ParseFromString(message.SerializeToString())
298 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
299
300 message.optional_double = -kMostPosExponentOneSigBit
301 message.ParseFromString(message.SerializeToString())
302 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
303
304 # Most negative exponent, no significand bits set.
305 kMostNegExponentNoSigBits = math.pow(2, -1023)
306 message.optional_double = kMostNegExponentNoSigBits
307 message.ParseFromString(message.SerializeToString())
308 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
309
310 # Most negative exponent, one significand bit set.
311 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
312 message.optional_double = kMostNegExponentOneSigBit
313 message.ParseFromString(message.SerializeToString())
314 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
315
316 # Repeat last two cases with values of the same magnitude, but negative.
317 message.optional_double = -kMostNegExponentNoSigBits
318 message.ParseFromString(message.SerializeToString())
319 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
320
321 message.optional_double = -kMostNegExponentOneSigBit
322 message.ParseFromString(message.SerializeToString())
323 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
324
325 def testFloatPrinting(self, message_module):
326 message = message_module.TestAllTypes()
327 message.optional_float = 2.0
328 self.assertEqual(str(message), 'optional_float: 2.0\n')
329
330 def testHighPrecisionFloatPrinting(self, message_module):
331 message = message_module.TestAllTypes()
332 message.optional_double = 0.12345678912345678
333 if sys.version_info >= (3,):
334 self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
335 else:
336 self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
337
338 def testUnknownFieldPrinting(self, message_module):
339 populated = message_module.TestAllTypes()
340 test_util.SetAllNonLazyFields(populated)
341 empty = message_module.TestEmptyMessage()
342 empty.ParseFromString(populated.SerializeToString())
343 self.assertEqual(str(empty), '')
344
345 def testRepeatedNestedFieldIteration(self, message_module):
346 msg = message_module.TestAllTypes()
347 msg.repeated_nested_message.add(bb=1)
348 msg.repeated_nested_message.add(bb=2)
349 msg.repeated_nested_message.add(bb=3)
350 msg.repeated_nested_message.add(bb=4)
351
352 self.assertEqual([1, 2, 3, 4],
353 [m.bb for m in msg.repeated_nested_message])
354 self.assertEqual([4, 3, 2, 1],
355 [m.bb for m in reversed(msg.repeated_nested_message)])
356 self.assertEqual([4, 3, 2, 1],
357 [m.bb for m in msg.repeated_nested_message[::-1]])
358
359 def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
360 """Check some different types with the default comparator."""
361 message = message_module.TestAllTypes()
362
363 # TODO(mattp): would testing more scalar types strengthen test?
364 message.repeated_int32.append(1)
365 message.repeated_int32.append(3)
366 message.repeated_int32.append(2)
367 message.repeated_int32.sort()
368 self.assertEqual(message.repeated_int32[0], 1)
369 self.assertEqual(message.repeated_int32[1], 2)
370 self.assertEqual(message.repeated_int32[2], 3)
371
372 message.repeated_float.append(1.1)
373 message.repeated_float.append(1.3)
374 message.repeated_float.append(1.2)
375 message.repeated_float.sort()
376 self.assertAlmostEqual(message.repeated_float[0], 1.1)
377 self.assertAlmostEqual(message.repeated_float[1], 1.2)
378 self.assertAlmostEqual(message.repeated_float[2], 1.3)
379
380 message.repeated_string.append('a')
381 message.repeated_string.append('c')
382 message.repeated_string.append('b')
383 message.repeated_string.sort()
384 self.assertEqual(message.repeated_string[0], 'a')
385 self.assertEqual(message.repeated_string[1], 'b')
386 self.assertEqual(message.repeated_string[2], 'c')
387
388 message.repeated_bytes.append(b'a')
389 message.repeated_bytes.append(b'c')
390 message.repeated_bytes.append(b'b')
391 message.repeated_bytes.sort()
392 self.assertEqual(message.repeated_bytes[0], b'a')
393 self.assertEqual(message.repeated_bytes[1], b'b')
394 self.assertEqual(message.repeated_bytes[2], b'c')
395
396 def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
397 """Check some different types with custom comparator."""
398 message = message_module.TestAllTypes()
399
400 message.repeated_int32.append(-3)
401 message.repeated_int32.append(-2)
402 message.repeated_int32.append(-1)
403 message.repeated_int32.sort(key=abs)
404 self.assertEqual(message.repeated_int32[0], -1)
405 self.assertEqual(message.repeated_int32[1], -2)
406 self.assertEqual(message.repeated_int32[2], -3)
407
408 message.repeated_string.append('aaa')
409 message.repeated_string.append('bb')
410 message.repeated_string.append('c')
411 message.repeated_string.sort(key=len)
412 self.assertEqual(message.repeated_string[0], 'c')
413 self.assertEqual(message.repeated_string[1], 'bb')
414 self.assertEqual(message.repeated_string[2], 'aaa')
415
416 def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
417 """Check passing a custom comparator to sort a repeated composite field."""
418 message = message_module.TestAllTypes()
419
420 message.repeated_nested_message.add().bb = 1
421 message.repeated_nested_message.add().bb = 3
422 message.repeated_nested_message.add().bb = 2
423 message.repeated_nested_message.add().bb = 6
424 message.repeated_nested_message.add().bb = 5
425 message.repeated_nested_message.add().bb = 4
426 message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
427 self.assertEqual(message.repeated_nested_message[0].bb, 1)
428 self.assertEqual(message.repeated_nested_message[1].bb, 2)
429 self.assertEqual(message.repeated_nested_message[2].bb, 3)
430 self.assertEqual(message.repeated_nested_message[3].bb, 4)
431 self.assertEqual(message.repeated_nested_message[4].bb, 5)
432 self.assertEqual(message.repeated_nested_message[5].bb, 6)
433
434 def testSortingRepeatedCompositeFieldsStable(self, message_module):
435 """Check passing a custom comparator to sort a repeated composite field."""
436 message = message_module.TestAllTypes()
437
438 message.repeated_nested_message.add().bb = 21
439 message.repeated_nested_message.add().bb = 20
440 message.repeated_nested_message.add().bb = 13
441 message.repeated_nested_message.add().bb = 33
442 message.repeated_nested_message.add().bb = 11
443 message.repeated_nested_message.add().bb = 24
444 message.repeated_nested_message.add().bb = 10
445 message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
446 self.assertEqual(
447 [13, 11, 10, 21, 20, 24, 33],
448 [n.bb for n in message.repeated_nested_message])
449
450 # Make sure that for the C++ implementation, the underlying fields
451 # are actually reordered.
452 pb = message.SerializeToString()
453 message.Clear()
454 message.MergeFromString(pb)
455 self.assertEqual(
456 [13, 11, 10, 21, 20, 24, 33],
457 [n.bb for n in message.repeated_nested_message])
458
459 def testRepeatedCompositeFieldSortArguments(self, message_module):
460 """Check sorting a repeated composite field using list.sort() arguments."""
461 message = message_module.TestAllTypes()
462
463 get_bb = operator.attrgetter('bb')
464 cmp_bb = lambda a, b: cmp(a.bb, b.bb)
465 message.repeated_nested_message.add().bb = 1
466 message.repeated_nested_message.add().bb = 3
467 message.repeated_nested_message.add().bb = 2
468 message.repeated_nested_message.add().bb = 6
469 message.repeated_nested_message.add().bb = 5
470 message.repeated_nested_message.add().bb = 4
471 message.repeated_nested_message.sort(key=get_bb)
472 self.assertEqual([k.bb for k in message.repeated_nested_message],
473 [1, 2, 3, 4, 5, 6])
474 message.repeated_nested_message.sort(key=get_bb, reverse=True)
475 self.assertEqual([k.bb for k in message.repeated_nested_message],
476 [6, 5, 4, 3, 2, 1])
477 if sys.version_info >= (3,): return # No cmp sorting in PY3.
478 message.repeated_nested_message.sort(sort_function=cmp_bb)
479 self.assertEqual([k.bb for k in message.repeated_nested_message],
480 [1, 2, 3, 4, 5, 6])
481 message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
482 self.assertEqual([k.bb for k in message.repeated_nested_message],
483 [6, 5, 4, 3, 2, 1])
484
485 def testRepeatedScalarFieldSortArguments(self, message_module):
486 """Check sorting a scalar field using list.sort() arguments."""
487 message = message_module.TestAllTypes()
488
489 message.repeated_int32.append(-3)
490 message.repeated_int32.append(-2)
491 message.repeated_int32.append(-1)
492 message.repeated_int32.sort(key=abs)
493 self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
494 message.repeated_int32.sort(key=abs, reverse=True)
495 self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
496 if sys.version_info < (3,): # No cmp sorting in PY3.
497 abs_cmp = lambda a, b: cmp(abs(a), abs(b))
498 message.repeated_int32.sort(sort_function=abs_cmp)
499 self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
500 message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
501 self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
502
503 message.repeated_string.append('aaa')
504 message.repeated_string.append('bb')
505 message.repeated_string.append('c')
506 message.repeated_string.sort(key=len)
507 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
508 message.repeated_string.sort(key=len, reverse=True)
509 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
510 if sys.version_info < (3,): # No cmp sorting in PY3.
511 len_cmp = lambda a, b: cmp(len(a), len(b))
512 message.repeated_string.sort(sort_function=len_cmp)
513 self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
514 message.repeated_string.sort(cmp=len_cmp, reverse=True)
515 self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
516
517 def testRepeatedFieldsComparable(self, message_module):
518 m1 = message_module.TestAllTypes()
519 m2 = message_module.TestAllTypes()
520 m1.repeated_int32.append(0)
521 m1.repeated_int32.append(1)
522 m1.repeated_int32.append(2)
523 m2.repeated_int32.append(0)
524 m2.repeated_int32.append(1)
525 m2.repeated_int32.append(2)
526 m1.repeated_nested_message.add().bb = 1
527 m1.repeated_nested_message.add().bb = 2
528 m1.repeated_nested_message.add().bb = 3
529 m2.repeated_nested_message.add().bb = 1
530 m2.repeated_nested_message.add().bb = 2
531 m2.repeated_nested_message.add().bb = 3
532
533 if sys.version_info >= (3,): return # No cmp() in PY3.
534
535 # These comparisons should not raise errors.
536 _ = m1 < m2
537 _ = m1.repeated_nested_message < m2.repeated_nested_message
538
539 # Make sure cmp always works. If it wasn't defined, these would be
540 # id() comparisons and would all fail.
541 self.assertEqual(cmp(m1, m2), 0)
542 self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
543 self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
544 self.assertEqual(cmp(m1.repeated_nested_message,
545 m2.repeated_nested_message), 0)
546 with self.assertRaises(TypeError):
547 # Can't compare repeated composite containers to lists.
548 cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
549
550 # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
551
552 def testRepeatedFieldsAreSequences(self, message_module):
553 m = message_module.TestAllTypes()
554 self.assertIsInstance(m.repeated_int32, collections.MutableSequence)
555 self.assertIsInstance(m.repeated_nested_message,
556 collections.MutableSequence)
557
558 def ensureNestedMessageExists(self, msg, attribute):
559 """Make sure that a nested message object exists.
560
561 As soon as a nested message attribute is accessed, it will be present in the
562 _fields dict, without being marked as actually being set.
563 """
564 getattr(msg, attribute)
565 self.assertFalse(msg.HasField(attribute))
566
567 def testOneofGetCaseNonexistingField(self, message_module):
568 m = message_module.TestAllTypes()
569 self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
570
571 def testOneofDefaultValues(self, message_module):
572 m = message_module.TestAllTypes()
573 self.assertIs(None, m.WhichOneof('oneof_field'))
574 self.assertFalse(m.HasField('oneof_uint32'))
575
576 # Oneof is set even when setting it to a default value.
577 m.oneof_uint32 = 0
578 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
579 self.assertTrue(m.HasField('oneof_uint32'))
580 self.assertFalse(m.HasField('oneof_string'))
581
582 m.oneof_string = ""
583 self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
584 self.assertTrue(m.HasField('oneof_string'))
585 self.assertFalse(m.HasField('oneof_uint32'))
586
587 def testOneofSemantics(self, message_module):
588 m = message_module.TestAllTypes()
589 self.assertIs(None, m.WhichOneof('oneof_field'))
590
591 m.oneof_uint32 = 11
592 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
593 self.assertTrue(m.HasField('oneof_uint32'))
594
595 m.oneof_string = u'foo'
596 self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
597 self.assertFalse(m.HasField('oneof_uint32'))
598 self.assertTrue(m.HasField('oneof_string'))
599
600 # Read nested message accessor without accessing submessage.
601 m.oneof_nested_message
602 self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
603 self.assertTrue(m.HasField('oneof_string'))
604 self.assertFalse(m.HasField('oneof_nested_message'))
605
606 # Read accessor of nested message without accessing submessage.
607 m.oneof_nested_message.bb
608 self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
609 self.assertTrue(m.HasField('oneof_string'))
610 self.assertFalse(m.HasField('oneof_nested_message'))
611
612 m.oneof_nested_message.bb = 11
613 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
614 self.assertFalse(m.HasField('oneof_string'))
615 self.assertTrue(m.HasField('oneof_nested_message'))
616
617 m.oneof_bytes = b'bb'
618 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
619 self.assertFalse(m.HasField('oneof_nested_message'))
620 self.assertTrue(m.HasField('oneof_bytes'))
621
622 def testOneofCompositeFieldReadAccess(self, message_module):
623 m = message_module.TestAllTypes()
624 m.oneof_uint32 = 11
625
626 self.ensureNestedMessageExists(m, 'oneof_nested_message')
627 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
628 self.assertEqual(11, m.oneof_uint32)
629
630 def testOneofWhichOneof(self, message_module):
631 m = message_module.TestAllTypes()
632 self.assertIs(None, m.WhichOneof('oneof_field'))
633 if message_module is unittest_pb2:
634 self.assertFalse(m.HasField('oneof_field'))
635
636 m.oneof_uint32 = 11
637 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
638 if message_module is unittest_pb2:
639 self.assertTrue(m.HasField('oneof_field'))
640
641 m.oneof_bytes = b'bb'
642 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
643
644 m.ClearField('oneof_bytes')
645 self.assertIs(None, m.WhichOneof('oneof_field'))
646 if message_module is unittest_pb2:
647 self.assertFalse(m.HasField('oneof_field'))
648
649 def testOneofClearField(self, message_module):
650 m = message_module.TestAllTypes()
651 m.oneof_uint32 = 11
652 m.ClearField('oneof_field')
653 if message_module is unittest_pb2:
654 self.assertFalse(m.HasField('oneof_field'))
655 self.assertFalse(m.HasField('oneof_uint32'))
656 self.assertIs(None, m.WhichOneof('oneof_field'))
657
658 def testOneofClearSetField(self, message_module):
659 m = message_module.TestAllTypes()
660 m.oneof_uint32 = 11
661 m.ClearField('oneof_uint32')
662 if message_module is unittest_pb2:
663 self.assertFalse(m.HasField('oneof_field'))
664 self.assertFalse(m.HasField('oneof_uint32'))
665 self.assertIs(None, m.WhichOneof('oneof_field'))
666
667 def testOneofClearUnsetField(self, message_module):
668 m = message_module.TestAllTypes()
669 m.oneof_uint32 = 11
670 self.ensureNestedMessageExists(m, 'oneof_nested_message')
671 m.ClearField('oneof_nested_message')
672 self.assertEqual(11, m.oneof_uint32)
673 if message_module is unittest_pb2:
674 self.assertTrue(m.HasField('oneof_field'))
675 self.assertTrue(m.HasField('oneof_uint32'))
676 self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
677
678 def testOneofDeserialize(self, message_module):
679 m = message_module.TestAllTypes()
680 m.oneof_uint32 = 11
681 m2 = message_module.TestAllTypes()
682 m2.ParseFromString(m.SerializeToString())
683 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
684
685 def testOneofCopyFrom(self, message_module):
686 m = message_module.TestAllTypes()
687 m.oneof_uint32 = 11
688 m2 = message_module.TestAllTypes()
689 m2.CopyFrom(m)
690 self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
691
692 def testOneofNestedMergeFrom(self, message_module):
693 m = message_module.NestedTestAllTypes()
694 m.payload.oneof_uint32 = 11
695 m2 = message_module.NestedTestAllTypes()
696 m2.payload.oneof_bytes = b'bb'
697 m2.child.payload.oneof_bytes = b'bb'
698 m2.MergeFrom(m)
699 self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
700 self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
701
702 def testOneofMessageMergeFrom(self, message_module):
703 m = message_module.NestedTestAllTypes()
704 m.payload.oneof_nested_message.bb = 11
705 m.child.payload.oneof_nested_message.bb = 12
706 m2 = message_module.NestedTestAllTypes()
707 m2.payload.oneof_uint32 = 13
708 m2.MergeFrom(m)
709 self.assertEqual('oneof_nested_message',
710 m2.payload.WhichOneof('oneof_field'))
711 self.assertEqual('oneof_nested_message',
712 m2.child.payload.WhichOneof('oneof_field'))
713
714 def testOneofNestedMessageInit(self, message_module):
715 m = message_module.TestAllTypes(
716 oneof_nested_message=message_module.TestAllTypes.NestedMessage())
717 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
718
719 def testOneofClear(self, message_module):
720 m = message_module.TestAllTypes()
721 m.oneof_uint32 = 11
722 m.Clear()
723 self.assertIsNone(m.WhichOneof('oneof_field'))
724 m.oneof_bytes = b'bb'
725 self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
726
727 def testAssignByteStringToUnicodeField(self, message_module):
728 """Assigning a byte string to a string field should result
729 in the value being converted to a Unicode string."""
730 m = message_module.TestAllTypes()
731 m.optional_string = str('')
732 self.assertIsInstance(m.optional_string, six.text_type)
733
def testLongValuedSlice(self, message_module):
  """Checks that slices accept long-valued indices.

  A full-range slice built from long() bounds is taken on a repeated scalar
  field and on a repeated composite field; each slice must have the same
  length as the field it came from.

  This didn't use to work in the v2 C++ implementation.
  """
  msg = message_module.TestAllTypes()

  # Repeated scalar field.
  msg.repeated_int32.append(1)
  scalar_slice = msg.repeated_int32[
      long(0):long(len(msg.repeated_int32))]
  self.assertEqual(len(msg.repeated_int32), len(scalar_slice))

  # Repeated composite field.
  msg.repeated_nested_message.add().bb = 3
  composite_slice = msg.repeated_nested_message[
      long(0):long(len(msg.repeated_nested_message))]
  self.assertEqual(len(msg.repeated_nested_message), len(composite_slice))
750
def testExtendShouldNotSwallowExceptions(self, message_module):
  """Checks that extend() propagates errors raised by its input.

  The generator passed to extend() references an undefined name, so
  consuming it raises NameError; that error must reach the caller for both
  a repeated int32 field and a repeated enum field.

  This didn't use to work in the v2 C++ implementation.
  """
  msg = message_module.TestAllTypes()

  with self.assertRaises(NameError):
    msg.repeated_int32.extend(
        a for i in range(10))  # pylint: disable=undefined-variable

  with self.assertRaises(NameError):
    msg.repeated_nested_enum.extend(
        a for i in range(10))  # pylint: disable=undefined-variable
759
# Falsy inputs that the testExtend*WithNothing tests pass to extend().
FALSY_VALUES = [
    None,
    False,
    0,
    0.0,
    b'',
    u'',
    bytearray(),
    [],
    {},
    set(),
]
761
def testExtendInt32WithNothing(self, message_module):
  """Checks that extending repeated int32 with falsy input is a no-op.

  Every entry of MessageTest.FALSY_VALUES, and finally an empty list, is
  passed to extend(); the field must stay empty after each call.
  """
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_int32)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES:
    msg.repeated_int32.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_int32)

  msg.repeated_int32.extend([])
  self.assertSequenceEqual([], msg.repeated_int32)
774
def testExtendFloatWithNothing(self, message_module):
  """Checks that extending repeated float with falsy input is a no-op.

  Every entry of MessageTest.FALSY_VALUES, and finally an empty list, is
  passed to extend(); the field must stay empty after each call.
  """
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_float)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES:
    msg.repeated_float.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_float)

  msg.repeated_float.extend([])
  self.assertSequenceEqual([], msg.repeated_float)
787
def testExtendStringWithNothing(self, message_module):
  """Checks that extending repeated string with falsy input is a no-op.

  Every entry of MessageTest.FALSY_VALUES, and finally an empty list, is
  passed to extend(); the field must stay empty after each call.
  """
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)

  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES:
    msg.repeated_string.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_string)

  msg.repeated_string.extend([])
  self.assertSequenceEqual([], msg.repeated_string)
800
def testExtendInt32WithPythonList(self, message_module):
  """Extends a repeated int32 field with Python lists, batch by batch.

  After each extend() the field must equal all values added so far, in
  order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_int32)

  for batch in ([0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(batch)
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_int32)
811
def testExtendFloatWithPythonList(self, message_module):
  """Extends a repeated float field with Python lists, batch by batch.

  After each extend() the field must equal all values added so far, in
  order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_float)

  for batch in ([0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(batch)
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_float)
822
def testExtendStringWithPythonList(self, message_module):
  """Extends a repeated string field with Python lists, batch by batch.

  After each extend() the field must equal all values added so far, in
  order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_string)

  for batch in ([''], ['11', '22'], ['33', '44']):
    msg.repeated_string.extend(batch)
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_string)
833
def testExtendStringWithString(self, message_module):
  """Extends a repeated string field with a plain string.

  The string 'abc' is passed to extend() and the field must then hold its
  three characters as separate elements.
  """
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)

  msg.repeated_string.extend('abc')

  expected_chars = ['a', 'b', 'c']
  self.assertSequenceEqual(expected_chars, msg.repeated_string)
840
class TestIterable(object):
  """This iterable object mimics the behavior of numpy.array.

  Its truth value is False when empty, bool(item[0]) when it holds exactly
  one item, and raises ValueError when it holds more than one item. The
  same rule is exposed as __nonzero__ (Python 2) and __bool__ (Python 3).
  """

  def __init__(self, values=None):
    """Wraps `values`; a falsy argument is replaced by a new empty list."""
    self._list = values or []

  def __nonzero__(self):
    """Returns the truth value, or raises ValueError if it is ambiguous."""
    size = len(self._list)
    if size == 0:
      return False
    if size == 1:
      return bool(self._list[0])
    raise ValueError('Truth value is ambiguous.')

  # Python 3 looks up __bool__ instead of __nonzero__. Without this alias
  # truth testing would silently fall back to __len__ and never raise.
  __bool__ = __nonzero__

  def __len__(self):
    """Returns the number of wrapped values."""
    return len(self._list)

  def __iter__(self):
    """Returns an iterator over the wrapped values."""
    return self._list.__iter__()
864
def testExtendInt32WithIterable(self, message_module):
  """Extends a repeated int32 field with MessageTest.TestIterable objects.

  An empty iterable is tried first, then non-empty batches. After each
  extend() the field must equal all values added so far, in order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_int32)

  for batch in ([], [0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(MessageTest.TestIterable(batch))
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_int32)
877
def testExtendFloatWithIterable(self, message_module):
  """Extends a repeated float field with MessageTest.TestIterable objects.

  An empty iterable is tried first, then non-empty batches. After each
  extend() the field must equal all values added so far, in order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_float)

  for batch in ([], [0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(MessageTest.TestIterable(batch))
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_float)
890
def testExtendStringWithIterable(self, message_module):
  """Extends a repeated string field with MessageTest.TestIterable objects.

  An empty iterable is tried first, then non-empty batches. After each
  extend() the field must equal all values added so far, in order.
  """
  msg = message_module.TestAllTypes()
  expected = []
  self.assertSequenceEqual(expected, msg.repeated_string)

  for batch in ([], [''], ['1', '2'], ['3', '4']):
    msg.repeated_string.extend(MessageTest.TestIterable(batch))
    expected = expected + batch
    self.assertSequenceEqual(expected, msg.repeated_string)
903
def testPickleRepeatedScalarContainer(self, message_module):
  """Checks that pickling a repeated scalar container raises PickleError.

  The check only runs when api_implementation.Type() is 'cpp' and
  api_implementation.Version() is not 2; otherwise the test returns early
  without asserting anything.
  """
  # TODO(tibell): The pure-Python implementation support pickling of
  #   scalar containers in *some* cases. For now the cpp2 version
  #   throws an exception to avoid a segfault. Investigate if we
  #   want to support pickling of these fields.
  #
  # For more information see: https://b2.corp.google.com/u/0/issues/18677897
  #
  # NOTE(review): the TODO above talks about the cpp2 version, yet the
  # guard below returns early when Version() == 2 - confirm which is meant.
  applies_here = (api_implementation.Type() == 'cpp' and
                  api_implementation.Version() != 2)
  if not applies_here:
    return

  msg = message_module.TestAllTypes()
  with self.assertRaises(pickle.PickleError):
    pickle.dumps(msg.repeated_int32, pickle.HIGHEST_PROTOCOL)
917
def testSortEmptyRepeatedCompositeContainer(self, message_module):
  """Sorts an empty repeated composite field.

  This exercises a scenario that has led to segfaults in the past; the
  test passes as long as sort() returns without raising.
  """
  msg = message_module.TestAllTypes()
  empty_container = msg.repeated_nested_message
  empty_container.sort()
923
def testHasFieldOnRepeatedField(self, message_module):
  """Checks that HasField() on a repeated field raises ValueError."""
  msg = message_module.TestAllTypes()
  repeated_field_name = 'repeated_int32'
  with self.assertRaises(ValueError):
    msg.HasField(repeated_field_name)
930
def testRepeatedScalarFieldPop(self, message_module):
  """Checks pop() on a repeated scalar field.

  Popping from the empty field must raise IndexError. After the field is
  filled with range(5), pop() / pop(0) / pop(1) must return 4, 0 and 2 in
  turn and leave [1, 3] behind.
  """
  msg = message_module.TestAllTypes()
  with self.assertRaises(IndexError):
    msg.repeated_int32.pop()

  msg.repeated_int32.extend(range(5))

  popped_last = msg.repeated_int32.pop()
  self.assertEqual(4, popped_last)
  popped_first = msg.repeated_int32.pop(0)
  self.assertEqual(0, popped_first)
  popped_middle = msg.repeated_int32.pop(1)
  self.assertEqual(2, popped_middle)

  self.assertEqual([1, 3], msg.repeated_int32)
940
def testRepeatedCompositeFieldPop(self, message_module):
  """Checks pop() on a repeated composite field.

  Popping from the empty field must raise IndexError. After five
  submessages with bb = 0..4 are added, pop() / pop(0) / pop(1) must return
  the ones with bb 4, 0 and 2 in turn and leave bb values [1, 3] behind.
  """
  msg = message_module.TestAllTypes()
  with self.assertRaises(IndexError):
    msg.repeated_nested_message.pop()

  for value in range(5):
    added = msg.repeated_nested_message.add()
    added.bb = value

  popped_last = msg.repeated_nested_message.pop()
  self.assertEqual(4, popped_last.bb)
  popped_first = msg.repeated_nested_message.pop(0)
  self.assertEqual(0, popped_first.bb)
  popped_middle = msg.repeated_nested_message.pop(1)
  self.assertEqual(2, popped_middle.bb)

  remaining = [item.bb for item in msg.repeated_nested_message]
  self.assertEqual([1, 3], remaining)
952
953
954# Class to test proto2-only features (required, extensions, etc.)
955class Proto2Test(unittest.TestCase):
956
957 def testFieldPresence(self):
958 message = unittest_pb2.TestAllTypes()
959
960 self.assertFalse(message.HasField("optional_int32"))
961 self.assertFalse(message.HasField("optional_bool"))
962 self.assertFalse(message.HasField("optional_nested_message"))
963
964 with self.assertRaises(ValueError):
965 message.HasField("field_doesnt_exist")
966
967 with self.assertRaises(ValueError):
968 message.HasField("repeated_int32")
969 with self.assertRaises(ValueError):
970 message.HasField("repeated_nested_message")
971
972 self.assertEqual(0, message.optional_int32)
973 self.assertEqual(False, message.optional_bool)
974 self.assertEqual(0, message.optional_nested_message.bb)
975
976 # Fields are set even when setting the values to default values.
977 message.optional_int32 = 0
978 message.optional_bool = False
979 message.optional_nested_message.bb = 0
980 self.assertTrue(message.HasField("optional_int32"))
981 self.assertTrue(message.HasField("optional_bool"))
982 self.assertTrue(message.HasField("optional_nested_message"))
983
984 # Set the fields to non-default values.
985 message.optional_int32 = 5
986 message.optional_bool = True
987 message.optional_nested_message.bb = 15
988
989 self.assertTrue(message.HasField("optional_int32"))
990 self.assertTrue(message.HasField("optional_bool"))
991 self.assertTrue(message.HasField("optional_nested_message"))
992
993 # Clearing the fields unsets them and resets their value to default.
994 message.ClearField("optional_int32")
995 message.ClearField("optional_bool")
996 message.ClearField("optional_nested_message")
997
998 self.assertFalse(message.HasField("optional_int32"))
999 self.assertFalse(message.HasField("optional_bool"))
1000 self.assertFalse(message.HasField("optional_nested_message"))
1001 self.assertEqual(0, message.optional_int32)
1002 self.assertEqual(False, message.optional_bool)
1003 self.assertEqual(0, message.optional_nested_message.bb)
1004
1005 # TODO(tibell): The C++ implementations actually allows assignment
1006 # of unknown enum values to *scalar* fields (but not repeated
1007 # fields). Once checked enum fields becomes the default in the
1008 # Python implementation, the C++ implementation should follow suit.
1009 def testAssignInvalidEnum(self):
1010 """It should not be possible to assign an invalid enum number to an
1011 enum field."""
1012 m = unittest_pb2.TestAllTypes()
1013
1014 with self.assertRaises(ValueError) as _:
1015 m.optional_nested_enum = 1234567
1016 self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
1017
1018 def testGoldenExtensions(self):
1019 golden_data = test_util.GoldenFileData('golden_message')
1020 golden_message = unittest_pb2.TestAllExtensions()
1021 golden_message.ParseFromString(golden_data)
1022 all_set = unittest_pb2.TestAllExtensions()
1023 test_util.SetAllExtensions(all_set)
1024 self.assertEqual(all_set, golden_message)
1025 self.assertEqual(golden_data, golden_message.SerializeToString())
1026 golden_copy = copy.deepcopy(golden_message)
1027 self.assertEqual(golden_data, golden_copy.SerializeToString())
1028
1029 def testGoldenPackedExtensions(self):
1030 golden_data = test_util.GoldenFileData('golden_packed_fields_message')
1031 golden_message = unittest_pb2.TestPackedExtensions()
1032 golden_message.ParseFromString(golden_data)
1033 all_set = unittest_pb2.TestPackedExtensions()
1034 test_util.SetAllPackedExtensions(all_set)
1035 self.assertEqual(all_set, golden_message)
1036 self.assertEqual(golden_data, all_set.SerializeToString())
1037 golden_copy = copy.deepcopy(golden_message)
1038 self.assertEqual(golden_data, golden_copy.SerializeToString())
1039
1040 def testPickleIncompleteProto(self):
1041 golden_message = unittest_pb2.TestRequired(a=1)
1042 pickled_message = pickle.dumps(golden_message)
1043
1044 unpickled_message = pickle.loads(pickled_message)
1045 self.assertEqual(unpickled_message, golden_message)
1046 self.assertEqual(unpickled_message.a, 1)
1047 # This is still an incomplete proto - so serializing should fail
1048 self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
1049
1050
1051 # TODO(haberman): this isn't really a proto2-specific test except that this
1052 # message has a required field in it. Should probably be factored out so
1053 # that we can test the other parts with proto3.
1054 def testParsingMerge(self):
1055 """Check the merge behavior when a required or optional field appears
1056 multiple times in the input."""
1057 messages = [
1058 unittest_pb2.TestAllTypes(),
1059 unittest_pb2.TestAllTypes(),
1060 unittest_pb2.TestAllTypes() ]
1061 messages[0].optional_int32 = 1
1062 messages[1].optional_int64 = 2
1063 messages[2].optional_int32 = 3
1064 messages[2].optional_string = 'hello'
1065
1066 merged_message = unittest_pb2.TestAllTypes()
1067 merged_message.optional_int32 = 3
1068 merged_message.optional_int64 = 2
1069 merged_message.optional_string = 'hello'
1070
1071 generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
1072 generator.field1.extend(messages)
1073 generator.field2.extend(messages)
1074 generator.field3.extend(messages)
1075 generator.ext1.extend(messages)
1076 generator.ext2.extend(messages)
1077 generator.group1.add().field1.MergeFrom(messages[0])
1078 generator.group1.add().field1.MergeFrom(messages[1])
1079 generator.group1.add().field1.MergeFrom(messages[2])
1080 generator.group2.add().field1.MergeFrom(messages[0])
1081 generator.group2.add().field1.MergeFrom(messages[1])
1082 generator.group2.add().field1.MergeFrom(messages[2])
1083
1084 data = generator.SerializeToString()
1085 parsing_merge = unittest_pb2.TestParsingMerge()
1086 parsing_merge.ParseFromString(data)
1087
1088 # Required and optional fields should be merged.
1089 self.assertEqual(parsing_merge.required_all_types, merged_message)
1090 self.assertEqual(parsing_merge.optional_all_types, merged_message)
1091 self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
1092 merged_message)
1093 self.assertEqual(parsing_merge.Extensions[
1094 unittest_pb2.TestParsingMerge.optional_ext],
1095 merged_message)
1096
1097 # Repeated fields should not be merged.
1098 self.assertEqual(len(parsing_merge.repeated_all_types), 3)
1099 self.assertEqual(len(parsing_merge.repeatedgroup), 3)
1100 self.assertEqual(len(parsing_merge.Extensions[
1101 unittest_pb2.TestParsingMerge.repeated_ext]), 3)
1102
1103 def testPythonicInit(self):
1104 message = unittest_pb2.TestAllTypes(
1105 optional_int32=100,
1106 optional_fixed32=200,
1107 optional_float=300.5,
1108 optional_bytes=b'x',
1109 optionalgroup={'a': 400},
1110 optional_nested_message={'bb': 500},
1111 optional_nested_enum='BAZ',
1112 repeatedgroup=[{'a': 600},
1113 {'a': 700}],
1114 repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
1115 default_int32=800,
1116 oneof_string='y')
1117 self.assertIsInstance(message, unittest_pb2.TestAllTypes)
1118 self.assertEqual(100, message.optional_int32)
1119 self.assertEqual(200, message.optional_fixed32)
1120 self.assertEqual(300.5, message.optional_float)
1121 self.assertEqual(b'x', message.optional_bytes)
1122 self.assertEqual(400, message.optionalgroup.a)
1123 self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage)
1124 self.assertEqual(500, message.optional_nested_message.bb)
1125 self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1126 message.optional_nested_enum)
1127 self.assertEqual(2, len(message.repeatedgroup))
1128 self.assertEqual(600, message.repeatedgroup[0].a)
1129 self.assertEqual(700, message.repeatedgroup[1].a)
1130 self.assertEqual(2, len(message.repeated_nested_enum))
1131 self.assertEqual(unittest_pb2.TestAllTypes.FOO,
1132 message.repeated_nested_enum[0])
1133 self.assertEqual(unittest_pb2.TestAllTypes.BAR,
1134 message.repeated_nested_enum[1])
1135 self.assertEqual(800, message.default_int32)
1136 self.assertEqual('y', message.oneof_string)
1137 self.assertFalse(message.HasField('optional_int64'))
1138 self.assertEqual(0, len(message.repeated_float))
1139 self.assertEqual(42, message.default_int64)
1140
1141 message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
1142 self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
1143 message.optional_nested_enum)
1144
1145 with self.assertRaises(ValueError):
1146 unittest_pb2.TestAllTypes(
1147 optional_nested_message={'INVALID_NESTED_FIELD': 17})
1148
1149 with self.assertRaises(TypeError):
1150 unittest_pb2.TestAllTypes(
1151 optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
1152
1153 with self.assertRaises(ValueError):
1154 unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
1155
1156 with self.assertRaises(ValueError):
1157 unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
1158
1159
1160# Class to test proto3-only features/behavior (updated field presence & enums)
1161class Proto3Test(unittest.TestCase):
1162
# Utility method for comparing equality with a map.
def assertMapIterEquals(self, map_iter, dict_value):
  """Asserts that an iterable of (key, value) pairs matches a dict.

  Args:
    map_iter: iterable yielding (key, value) pairs.
    dict_value: mapping holding the expected entries; it is copied, so the
      caller's object is left untouched.

  Raises:
    KeyError: if map_iter yields a key that is not (or no longer) expected.
  """
  remaining = dict(dict_value)

  for key, value in map_iter:
    self.assertEqual(value, remaining.pop(key))

  # Every expected entry must have been seen exactly once.
  self.assertEqual({}, remaining)
1173
def testFieldPresence(self):
  """Checks proto3 presence rules on unittest_proto3_arena_pb2.TestAllTypes.

  HasField() must raise ValueError for singular scalar fields, unknown
  fields and repeated fields, but works for the submessage field. All
  fields must read as their defaults initially and again after ClearField().
  """
  msg = unittest_proto3_arena_pb2.TestAllTypes()
  scalar_fields = (
      'optional_int32', 'optional_float', 'optional_string', 'optional_bool')

  def check_defaults():
    self.assertEqual(0, msg.optional_int32)
    self.assertEqual(0, msg.optional_float)
    self.assertEqual('', msg.optional_string)
    self.assertEqual(False, msg.optional_bool)
    self.assertEqual(0, msg.optional_nested_message.bb)

  # We can't test presence of non-repeated, non-submessage fields.
  for name in scalar_fields:
    with self.assertRaises(ValueError):
      msg.HasField(name)

  # But we can still test presence of submessage fields.
  self.assertFalse(msg.HasField('optional_nested_message'))

  # As with proto2, we can't test presence of fields that don't exist, or
  # repeated fields.
  for name in ('field_doesnt_exist', 'repeated_int32',
               'repeated_nested_message'):
    with self.assertRaises(ValueError):
      msg.HasField(name)

  # Fields should default to their type-specific default.
  check_defaults()

  # Setting a submessage should still return proper presence information.
  msg.optional_nested_message.bb = 0
  self.assertTrue(msg.HasField('optional_nested_message'))

  # Set the fields to non-default values.
  msg.optional_int32 = 5
  msg.optional_float = 1.1
  msg.optional_string = 'abc'
  msg.optional_bool = True
  msg.optional_nested_message.bb = 15

  # Clearing the fields unsets them and resets their value to default.
  for name in scalar_fields + ('optional_nested_message',):
    msg.ClearField(name)

  check_defaults()
1230
def testAssignUnknownEnum(self):
  """Assigning an unknown enum value is allowed and preserves the value.

  Covers singular assignment, append() and item assignment on the repeated
  field, and a serialize/parse round trip.
  """
  original = unittest_proto3_arena_pb2.TestAllTypes()

  original.optional_nested_enum = 1234567
  self.assertEqual(1234567, original.optional_nested_enum)

  original.repeated_nested_enum.append(22334455)
  self.assertEqual(22334455, original.repeated_nested_enum[0])

  # Assignment is a different code path than append for the C++ impl.
  original.repeated_nested_enum[0] = 7654321
  self.assertEqual(7654321, original.repeated_nested_enum[0])

  wire_data = original.SerializeToString()
  parsed = unittest_proto3_arena_pb2.TestAllTypes()
  parsed.ParseFromString(wire_data)

  self.assertEqual(1234567, parsed.optional_nested_enum)
  self.assertEqual(7654321, parsed.repeated_nested_enum[0])
1248
1249 # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1250 # of google/protobuf/map_unittest.proto right now, so it's not easy to
1251 # test both with the same test like we do for the other proto2/proto3 tests.
1252 # (google/protobuf/map_protobuf_unittest.proto is very different in the set
1253 # of messages and fields it contains).
def testScalarMapDefaults(self):
  """Checks the default-on-read behavior of scalar map fields.

  Keys are absent at first; reading a missing key returns the value type's
  default and, as the later membership checks show, inserts the key. A key
  of the wrong type raises TypeError on lookup and on membership tests.
  """
  msg = map_unittest_pb2.TestMap()

  # Scalars start out unset.
  self.assertNotIn(-123, msg.map_int32_int32)
  self.assertNotIn(-2**33, msg.map_int64_int64)
  self.assertNotIn(123, msg.map_uint32_uint32)
  self.assertNotIn(2**33, msg.map_uint64_uint64)
  self.assertNotIn('abc', msg.map_string_string)
  self.assertNotIn(888, msg.map_int32_enum)

  # Accessing an unset key returns the default.
  self.assertEqual(0, msg.map_int32_int32[-123])
  self.assertEqual(0, msg.map_int64_int64[-2**33])
  self.assertEqual(0, msg.map_uint32_uint32[123])
  self.assertEqual(0, msg.map_uint64_uint64[2**33])
  self.assertEqual('', msg.map_string_string['abc'])
  self.assertEqual(0, msg.map_int32_enum[888])

  # It also sets the value in the map
  self.assertIn(-123, msg.map_int32_int32)
  self.assertIn(-2**33, msg.map_int64_int64)
  self.assertIn(123, msg.map_uint32_uint32)
  self.assertIn(2**33, msg.map_uint64_uint64)
  self.assertIn('abc', msg.map_string_string)
  self.assertIn(888, msg.map_int32_enum)

  default_string = msg.map_string_string['abc']
  self.assertIsInstance(default_string, six.text_type)

  # Accessing an unset key still throws TypeError if the type of the key
  # is incorrect.
  with self.assertRaises(TypeError):
    msg.map_string_string[123]

  with self.assertRaises(TypeError):
    123 in msg.map_string_string
1290
def testMapGet(self):
  """Checks get() on scalar and message map fields.

  get() must return None, or the supplied fallback, for a missing key even
  though item access on these maps has defaultdict-like semantics. Once a
  key exists, get() must return the stored value / the same submessage.
  """
  msg = map_unittest_pb2.TestMap()

  # Scalar map: missing key, with and without an explicit fallback.
  self.assertIsNone(msg.map_int32_int32.get(5))
  self.assertEqual(10, msg.map_int32_int32.get(5, 10))
  self.assertIsNone(msg.map_int32_int32.get(5))

  msg.map_int32_int32[5] = 15
  self.assertEqual(15, msg.map_int32_int32.get(5))

  # Message map: missing key, with and without an explicit fallback.
  self.assertIsNone(msg.map_int32_foreign_message.get(5))
  self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))

  created = msg.map_int32_foreign_message[5]
  self.assertIs(created, msg.map_int32_foreign_message.get(5))
1308
def testScalarMap(self):
  """Fills scalar map fields, round-trips them and checks the values.

  Also checks that a key or value of the wrong type raises TypeError, both
  on the original message and on the parsed copy, and that a rejected key
  does not add an entry.
  """
  original = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(original.map_int32_int32))
  self.assertNotIn(5, original.map_int32_int32)

  original.map_int32_int32[-123] = -456
  original.map_int64_int64[-2**33] = -2**34
  original.map_uint32_uint32[123] = 456
  original.map_uint64_uint64[2**33] = 2**34
  original.map_string_string['abc'] = '123'
  original.map_int32_enum[888] = 2

  self.assertEqual([], original.FindInitializationErrors())

  self.assertEqual(1, len(original.map_string_string))

  # Bad key.
  with self.assertRaises(TypeError):
    original.map_string_string[123] = '123'

  # Verify that trying to assign a bad key doesn't actually add a member to
  # the map.
  self.assertEqual(1, len(original.map_string_string))

  # Bad value.
  with self.assertRaises(TypeError):
    original.map_string_string['123'] = 123

  wire_data = original.SerializeToString()
  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(wire_data)

  # Bad key.
  with self.assertRaises(TypeError):
    parsed.map_string_string[123] = '123'

  # Bad value.
  with self.assertRaises(TypeError):
    parsed.map_string_string['123'] = 123

  self.assertEqual(-456, parsed.map_int32_int32[-123])
  self.assertEqual(-2**34, parsed.map_int64_int64[-2**33])
  self.assertEqual(456, parsed.map_uint32_uint32[123])
  self.assertEqual(2**34, parsed.map_uint64_uint64[2**33])
  self.assertEqual('123', parsed.map_string_string['abc'])
  self.assertEqual(2, parsed.map_int32_enum[888])
1356
def testStringUnicodeConversionInMap(self):
  """Stores UTF-8 bytes as key and value of a string map.

  Both must read back equal to the original unicode object and be instances
  of six.text_type.
  """
  msg = map_unittest_pb2.TestMap()

  text = u'\u1234'
  encoded = text.encode('utf8')

  msg.map_string_string[encoded] = encoded

  entries = list(msg.map_string_string.items())
  stored_key, stored_value = entries[0]

  self.assertEqual(stored_key, text)
  self.assertEqual(stored_value, text)

  self.assertIsInstance(stored_key, six.text_type)
  self.assertIsInstance(stored_value, six.text_type)
1372
def testMessageMap(self):
  """Checks entry creation and round trip of a message-valued map.

  Entries are created by item access and by get_or_create(). A key of the
  wrong type raises TypeError, direct assignment of a submessage raises
  ValueError, and neither failure changes the number of entries.
  """
  original = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(original.map_int32_foreign_message))
  self.assertNotIn(5, original.map_int32_foreign_message)

  # Plain item access creates the entry.
  original.map_int32_foreign_message[123]
  # get_or_create() is an alias for getitem.
  original.map_int32_foreign_message.get_or_create(-456)

  self.assertEqual(2, len(original.map_int32_foreign_message))
  self.assertIn(123, original.map_int32_foreign_message)
  self.assertIn(-456, original.map_int32_foreign_message)
  self.assertEqual(2, len(original.map_int32_foreign_message))

  # Bad key.
  with self.assertRaises(TypeError):
    original.map_int32_foreign_message['123']

  # Can't assign directly to submessage.
  with self.assertRaises(ValueError):
    original.map_int32_foreign_message[999] = (
        original.map_int32_foreign_message[123])

  # Verify that trying to assign a bad key doesn't actually add a member to
  # the map.
  self.assertEqual(2, len(original.map_int32_foreign_message))

  wire_data = original.SerializeToString()
  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(wire_data)

  self.assertEqual(2, len(parsed.map_int32_foreign_message))
  self.assertIn(123, parsed.map_int32_foreign_message)
  self.assertIn(-456, parsed.map_int32_foreign_message)
  self.assertEqual(2, len(parsed.map_int32_foreign_message))
1408
def testMergeFrom(self):
  """Merges one TestMap into another and checks the combined maps.

  For keys present on both sides the merged-in value must win; keys that
  exist on only one side must be kept. Iteration must yield each key once,
  and deleting a key after the merge must really remove it.
  """
  source = map_unittest_pb2.TestMap()
  source.map_int32_int32[12] = 34
  source.map_int32_int32[56] = 78
  source.map_int64_int64[22] = 33
  source.map_int32_foreign_message[111].c = 5
  source.map_int32_foreign_message[222].c = 10

  target = map_unittest_pb2.TestMap()
  target.map_int32_int32[12] = 55
  target.map_int64_int64[88] = 99
  target.map_int32_foreign_message[222].c = 15

  target.MergeFrom(source)

  self.assertEqual(34, target.map_int32_int32[12])
  self.assertEqual(78, target.map_int32_int32[56])
  self.assertEqual(33, target.map_int64_int64[22])
  self.assertEqual(99, target.map_int64_int64[88])
  self.assertEqual(5, target.map_int32_foreign_message[111].c)
  self.assertEqual(10, target.map_int32_foreign_message[222].c)

  # Verify that there is only one entry per key, even though the MergeFrom
  # may have internally created multiple entries for a single key in the
  # list representation.
  seen = {}
  for key in target.map_int32_foreign_message:
    self.assertNotIn(key, seen)
    seen[key] = target.map_int32_foreign_message[key].c

  self.assertEqual({111: 5, 222: 10}, seen)

  # Special case: test that delete of item really removes the item, even if
  # there might have physically been duplicate keys due to the previous merge.
  # This is only a special case for the C++ implementation which stores the
  # map as an array.
  del target.map_int32_int32[12]
  self.assertNotIn(12, target.map_int32_int32)

  del target.map_int32_foreign_message[222]
  self.assertNotIn(222, target.map_int32_foreign_message)
1450
def testIntegerMapWithLongs(self):
  """Uses long() keys and values with the integer map fields.

  After a serialize/parse round trip the entries must be readable with
  plain integer keys and compare equal to plain integer values.
  """
  original = map_unittest_pb2.TestMap()
  original.map_int32_int32[long(-123)] = long(-456)
  original.map_int64_int64[long(-2**33)] = long(-2**34)
  original.map_uint32_uint32[long(123)] = long(456)
  original.map_uint64_uint64[long(2**33)] = long(2**34)

  wire_data = original.SerializeToString()
  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(wire_data)

  self.assertEqual(-456, parsed.map_int32_int32[-123])
  self.assertEqual(-2**34, parsed.map_int64_int64[-2**33])
  self.assertEqual(456, parsed.map_uint32_uint32[123])
  self.assertEqual(2**34, parsed.map_uint64_uint64[2**33])
1466
def testMapAssignmentCausesPresence(self):
  """Checks that scalar-map edits inside a submessage get serialized.

  After each mutation of test_map.map_int32_int32 the message is serialized
  and parsed into a second message, which must compare equal to the first.
  """
  original = map_unittest_pb2.TestMapSubmessage()
  parsed = map_unittest_pb2.TestMapSubmessage()

  def check_round_trip():
    parsed.ParseFromString(original.SerializeToString())
    self.assertEqual(original, parsed)

  original.test_map.map_int32_int32[123] = 456
  check_round_trip()

  # Now test that various mutations of the map properly invalidate the
  # cached size of the submessage.
  original.test_map.map_int32_int32[888] = 999
  check_round_trip()

  original.test_map.map_int32_int32.clear()
  check_round_trip()
1488
def testMapAssignmentCausesPresenceForSubmessages(self):
  """Checks that message-map edits inside a submessage get serialized.

  After each mutation of test_map.map_int32_foreign_message the message is
  serialized and parsed into a second message, which must compare equal to
  the first.
  """
  original = map_unittest_pb2.TestMapSubmessage()
  parsed = map_unittest_pb2.TestMapSubmessage()

  def check_round_trip():
    parsed.ParseFromString(original.SerializeToString())
    self.assertEqual(original, parsed)

  original.test_map.map_int32_foreign_message[123].c = 5
  check_round_trip()

  # Now test that various mutations of the map properly invalidate the
  # cached size of the submessage.
  original.test_map.map_int32_foreign_message[888].c = 7
  check_round_trip()

  original.test_map.map_int32_foreign_message[888].MergeFrom(
      original.test_map.map_int32_foreign_message[123])
  check_round_trip()

  original.test_map.map_int32_foreign_message.clear()
  check_round_trip()
1516
def testModifyMapWhileIterating(self):
  """Checks that map iterators created before a mutation are rejected.

  Iterators are taken from a string map and a message map, both maps are
  then modified, and advancing either iterator must raise RuntimeError.
  """
  msg = map_unittest_pb2.TestMap()

  string_string_iter = iter(msg.map_string_string)
  int32_foreign_iter = iter(msg.map_int32_foreign_message)

  msg.map_string_string['abc'] = '123'
  msg.map_int32_foreign_message[5].c = 5

  for stale_iter in (string_string_iter, int32_foreign_iter):
    with self.assertRaises(RuntimeError):
      for _ in stale_iter:
        pass
1533
def testSubmessageMap(self):
  # Exercises a message-valued map: repeated lookups of one key return the
  # same object, the value survives a serialization round trip, and direct
  # assignment of a message to a key is rejected.
  source = map_unittest_pb2.TestMap()

  entry = source.map_int32_foreign_message[111]
  self.assertIs(entry, source.map_int32_foreign_message[111])
  self.assertIsInstance(entry, unittest_pb2.ForeignMessage)

  entry.c = 5

  wire_bytes = source.SerializeToString()
  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(wire_bytes)

  self.assertEqual(5, parsed.map_int32_foreign_message[111].c)

  # Doesn't allow direct submessage assignment.
  with self.assertRaises(ValueError):
    source.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
1552
def testMapIteration(self):
  # items() of a fresh map must yield nothing; after three insertions it
  # must yield exactly the inserted key/value pairs.
  msg = map_unittest_pb2.TestMap()

  for key, value in msg.map_int32_int32.items():
    # Should not be reached. self.fail() reports the offending entry, which
    # a bare assertTrue(False) would not.
    self.fail('Empty map unexpectedly yielded item (%r, %r)' % (key, value))

  msg.map_int32_int32[2] = 4
  msg.map_int32_int32[3] = 6
  msg.map_int32_int32[4] = 8
  self.assertEqual(3, len(msg.map_int32_int32))

  matching_dict = {2: 4, 3: 6, 4: 8}
  self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
1567
def testMapIterationClearMessage(self):
  # Iterator needs to work even if message and map are deleted.
  expected = {2: 4, 3: 6, 4: 8}

  msg = map_unittest_pb2.TestMap()
  for key in (2, 3, 4):
    msg.map_int32_int32[key] = expected[key]

  # Grab the items first, then drop the only local reference to the message.
  items = msg.map_int32_int32.items()
  del msg

  self.assertMapIterEquals(items, expected)
1581
def testMapConstruction(self):
  # Map fields can be populated through constructor keyword arguments, both
  # for scalar values and for message values.
  scalar_msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
  for key, expected_value in ((1, 2), (3, 4)):
    self.assertEqual(expected_value, scalar_msg.map_int32_int32[key])

  submessage_msg = map_unittest_pb2.TestMap(
      map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
  self.assertEqual(5, submessage_msg.map_int32_foreign_message[3].c)
1590
def testMapValidAfterFieldCleared(self):
  # Map needs to work even if field is cleared.
  # For the C++ implementation this tests the correctness of
  # ScalarMapContainer::Release()
  expected = {2: 4, 3: 6, 4: 8}

  msg = map_unittest_pb2.TestMap()
  detached_map = msg.map_int32_int32
  for key in (2, 3, 4):
    detached_map[key] = expected[key]

  msg.ClearField('map_int32_int32')

  # The message itself is empty again...
  self.assertEqual(b'', msg.SerializeToString())
  # ...while the previously obtained map reference still holds the entries.
  self.assertMapIterEquals(detached_map.items(), expected)
1606
def testMessageMapValidAfterFieldCleared(self):
  # Map needs to work even if field is cleared.
  # For the C++ implementation this tests the correctness of
  # ScalarMapContainer::Release()
  # NOTE(review): this test uses a message-valued map, so the container
  # named above is presumably the message-map counterpart -- confirm.
  msg = map_unittest_pb2.TestMap()
  int32_foreign_message = msg.map_int32_foreign_message

  int32_foreign_message[2].c = 5

  msg.ClearField('map_int32_foreign_message')
  self.assertEqual(b'', msg.SerializeToString())
  # assertIn reports the key and the container on failure instead of the
  # uninformative "False is not true".
  self.assertIn(2, int32_foreign_message.keys())
1619
def testMapIterInvalidatedByClearField(self):
  # Map iterator is invalidated when field is cleared.
  # But this case does need to not crash the interpreter.
  # For the C++ implementation this tests the correctness of
  # ScalarMapContainer::Release()
  msg = map_unittest_pb2.TestMap()

  # Same scenario for a scalar-valued and then a message-valued map: take an
  # iterator, clear the field, and expect RuntimeError when iterating.
  for field_name in ('map_int32_int32', 'map_int32_foreign_message'):
    stale_iter = iter(getattr(msg, field_name))
    msg.ClearField(field_name)
    with self.assertRaises(RuntimeError):
      for _ in stale_iter:
        pass
1639
def testMapDelete(self):
  # Deleting an existing key shrinks the map; deleting a missing key raises
  # KeyError.
  msg = map_unittest_pb2.TestMap()
  int32_map = msg.map_int32_int32

  self.assertEqual(0, len(int32_map))

  int32_map[4] = 6
  self.assertEqual(1, len(int32_map))

  # Key 88 was never inserted.
  with self.assertRaises(KeyError):
    del int32_map[88]

  del int32_map[4]
  self.assertEqual(0, len(int32_map))
1653
def testMapsAreMapping(self):
  # Scalar-valued and message-valued map fields must both be recognised as
  # Mapping and MutableMapping.
  #
  # The container ABCs live in collections.abc on Python 3; the aliases in
  # the top-level collections module were removed in Python 3.10, so
  # collections.Mapping raises AttributeError there. Python 2 only provides
  # them in collections, hence the fallback.
  try:
    from collections import abc as collections_abc
  except ImportError:
    import collections as collections_abc

  msg = map_unittest_pb2.TestMap()
  for map_field in (msg.map_int32_int32, msg.map_int32_foreign_message):
    self.assertIsInstance(map_field, collections_abc.Mapping)
    self.assertIsInstance(map_field, collections_abc.MutableMapping)
1661
def testMapFindInitializationErrorsSmokeTest(self):
  # Smoke test: FindInitializationErrors() reports nothing for a message
  # with populated string, int32 and message-valued maps.
  populated = map_unittest_pb2.TestMap()
  populated.map_string_string['abc'] = '123'
  populated.map_int32_int32[35] = 64
  populated.map_string_foreign_message['foo'].c = 5

  errors = populated.FindInitializationErrors()
  self.assertEqual(0, len(errors))
1668
def testAnyMessage(self):
  # Packs a TestAllTypes message into the Any-typed `value` field of a
  # TestAny message, then checks type_url, payload, Is() and Unpack().
  # Creates and sets message.
  msg = any_test_pb2.TestAny()
  msg_descriptor = msg.DESCRIPTOR
  all_types = unittest_pb2.TestAllTypes()
  all_descriptor = all_types.DESCRIPTOR
  all_types.repeated_string.append(u'\u00fc\ua71f')
  # Packs to Any.
  msg.value.Pack(all_types)
  self.assertEqual(msg.value.type_url,
                   'type.googleapis.com/%s' % all_descriptor.full_name)
  self.assertEqual(msg.value.value,
                   all_types.SerializeToString())
  # Tests Is() method.
  self.assertTrue(msg.value.Is(all_descriptor))
  self.assertFalse(msg.value.Is(msg_descriptor))
  # Unpacks Any.
  unpacked_message = unittest_pb2.TestAllTypes()
  self.assertTrue(msg.value.Unpack(unpacked_message))
  self.assertEqual(all_types, unpacked_message)
  # Unpacks to different type.
  self.assertFalse(msg.value.Unpack(msg))
  # Only Any messages have Pack method.
  try:
    msg.Pack(all_types)
  except AttributeError:
    pass
  else:
    # Use self.fail() so that an unexpected Pack method is reported as a test
    # failure rather than as an error caused by a raised AttributeError.
    self.fail('%s should not have Pack method.' %
              msg_descriptor.full_name)
1699
1700
1701
class ValidTypeNamesTest(unittest.TestCase):
  # Checks that the type names reported by repeated-field containers refer
  # to importable modules.

  def assertImportFromName(self, msg, base_name):
    # Asserts that the type of `msg` is named like a repeated container for
    # `base_name` and that the module part of that type name can be imported.
    #
    # Parse <type 'module.class_name'> to extract 'module.class_name' as a
    # string.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    # str.endswith accepts a tuple of candidate suffixes.
    self.assertTrue(tp_name.endswith(valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    # Everything before the last dot is the module, the rest the class name.
    module_name, _, class_name = tp_name.rpartition('.')
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
1722
class PackedFieldTest(unittest.TestCase):
  # Compares the serialized form of TestPackedTypes and TestUnpackedTypes
  # messages, each holding one element per repeated field, with golden bytes.

  def setMessage(self, message):
    # Appends one element to each repeated field of `message`.
    one_per_field = (
        ('repeated_int32', 1),
        ('repeated_int64', 1),
        ('repeated_uint32', 1),
        ('repeated_uint64', 1),
        ('repeated_sint32', 1),
        ('repeated_sint64', 1),
        ('repeated_fixed32', 1),
        ('repeated_fixed64', 1),
        ('repeated_sfixed32', 1),
        ('repeated_sfixed64', 1),
        ('repeated_float', 1.0),
        ('repeated_double', 1.0),
        ('repeated_bool', True),
        ('repeated_nested_enum', 1),
    )
    for field_name, element in one_per_field:
      getattr(message, field_name).append(element)

  def testPackedFields(self):
    packed = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(packed)
    # One chunk per field; each chunk presumably reads as tag byte, payload
    # length, payload -- verify against the protobuf wire format.
    expected_wire = b''.join((
        b'\x0A\x01\x01',
        b'\x12\x01\x01',
        b'\x1A\x01\x01',
        b'\x22\x01\x01',
        b'\x2A\x01\x02',
        b'\x32\x01\x02',
        b'\x3A\x04\x01\x00\x00\x00',
        b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4A\x04\x01\x00\x00\x00',
        b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5A\x04\x00\x00\x80\x3f',
        b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x6A\x01\x01',
        b'\x72\x01\x01',
    ))
    self.assertEqual(expected_wire, packed.SerializeToString())

  def testUnpackedFields(self):
    unpacked = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(unpacked)
    # One chunk per field; each chunk presumably reads as tag byte followed
    # directly by the value -- verify against the protobuf wire format.
    expected_wire = b''.join((
        b'\x08\x01',
        b'\x10\x01',
        b'\x18\x01',
        b'\x20\x01',
        b'\x28\x02',
        b'\x30\x02',
        b'\x3D\x01\x00\x00\x00',
        b'\x41\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x4D\x01\x00\x00\x00',
        b'\x51\x01\x00\x00\x00\x00\x00\x00\x00',
        b'\x5D\x00\x00\x80\x3f',
        b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f',
        b'\x68\x01',
        b'\x70\x01',
    ))
    self.assertEqual(expected_wire, unpacked.SerializeToString())
1778
# Script entry point: hand control to unittest's command-line runner when this
# file is executed directly rather than imported.
if __name__ == '__main__':
  unittest.main()