blob: 61a56a678cbde65d76001781fecb97edab356142 [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001#! /usr/bin/env python
2#
3# Protocol Buffers - Google's data interchange format
4# Copyright 2008 Google Inc. All rights reserved.
5# https://developers.google.com/protocol-buffers/
6#
7# Redistribution and use in source and binary forms, with or without
8# modification, are permitted provided that the following conditions are
9# met:
10#
11# * Redistributions of source code must retain the above copyright
12# notice, this list of conditions and the following disclaimer.
13# * Redistributions in binary form must reproduce the above
14# copyright notice, this list of conditions and the following disclaimer
15# in the documentation and/or other materials provided with the
16# distribution.
17# * Neither the name of Google Inc. nor the names of its
18# contributors may be used to endorse or promote products derived from
19# this software without specific prior written permission.
20#
21# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
33"""Tests python protocol buffers against the golden message.
34
35Note that the golden messages exercise every known field type, thus this
36test ends up exercising and verifying nearly all of the parsing and
37serialization code in the whole library.
38
39TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
40sense to call this a test of the "message" module, which only declares an
41abstract interface.
42"""
43
__author__ = 'gps@google.com (Gregory P. Smith)'  # Original author of these tests.
45
46
47import collections
48import copy
49import math
50import operator
51import pickle
52import six
53import sys
Austin Schuh40c16522018-10-28 20:27:54 -070054import warnings
Brian Silverman9c614bc2016-02-15 20:20:02 -050055
56try:
Austin Schuh40c16522018-10-28 20:27:54 -070057 import unittest2 as unittest # PY26
Brian Silverman9c614bc2016-02-15 20:20:02 -050058except ImportError:
59 import unittest
try:
  cmp  # Python 2 ships cmp() as a builtin.
except NameError:
  def cmp(x, y):
    """Python 3 replacement for the removed cmp() builtin: return -1, 0 or 1."""
    return (x > y) - (x < y)
64
65from google.protobuf import map_proto2_unittest_pb2
Brian Silverman9c614bc2016-02-15 20:20:02 -050066from google.protobuf import map_unittest_pb2
67from google.protobuf import unittest_pb2
68from google.protobuf import unittest_proto3_arena_pb2
Austin Schuh40c16522018-10-28 20:27:54 -070069from google.protobuf import descriptor_pb2
70from google.protobuf import descriptor_pool
71from google.protobuf import message_factory
72from google.protobuf import text_format
Brian Silverman9c614bc2016-02-15 20:20:02 -050073from google.protobuf.internal import api_implementation
Austin Schuh40c16522018-10-28 20:27:54 -070074from google.protobuf.internal import encoder
Brian Silverman9c614bc2016-02-15 20:20:02 -050075from google.protobuf.internal import packed_field_test_pb2
76from google.protobuf.internal import test_util
Austin Schuh40c16522018-10-28 20:27:54 -070077from google.protobuf.internal import testing_refleaks
Brian Silverman9c614bc2016-02-15 20:20:02 -050078from google.protobuf import message
Austin Schuh40c16522018-10-28 20:27:54 -070079from google.protobuf.internal import _parameterized
Brian Silverman9c614bc2016-02-15 20:20:02 -050080
if six.PY3:
  long = int  # Python 3 has no separate long type; alias it to int.
83
Austin Schuh40c16522018-10-28 20:27:54 -070084
Brian Silverman9c614bc2016-02-15 20:20:02 -050085# Python pre-2.6 does not have isinf() or isnan() functions, so we have
86# to provide our own.
87def isnan(val):
88 # NaN is never equal to itself.
89 return val != val
90def isinf(val):
91 # Infinity times zero equals NaN.
92 return not isnan(val) and isnan(val * 0)
93def IsPosInf(val):
94 return isinf(val) and (val > 0)
95def IsNegInf(val):
96 return isinf(val) and (val < 0)
97
98
# Local alias for the shared test base class (presumably adds reference-leak
# checks, judging by the testing_refleaks module name -- confirm there).
BaseTestCase = testing_refleaks.BaseTestCase
100
101
102@_parameterized.named_parameters(
103 ('_proto2', unittest_pb2),
104 ('_proto3', unittest_proto3_arena_pb2))
105class MessageTest(BaseTestCase):
Brian Silverman9c614bc2016-02-15 20:20:02 -0500106
def testBadUtf8String(self, message_module):
  """Parsing invalid UTF-8 into a string field raises UnicodeDecodeError."""
  # Only the pure-python implementation raises here today.
  if api_implementation.Type() != 'python':
    self.skipTest("Skipping testBadUtf8String, currently only the python "
                  "api implementation raises UnicodeDecodeError when a "
                  "string field contains bad utf-8.")
  payload = test_util.GoldenFileData('bad_utf8_string')
  with self.assertRaises(UnicodeDecodeError) as ctx:
    message_module.TestAllTypes.FromString(payload)
  # The error text must name the offending field.
  self.assertIn('TestAllTypes.optional_string', str(ctx.exception))
116
def testGoldenMessage(self, message_module):
  """The golden message parses and re-serializes byte-for-byte."""
  is_proto2 = message_module is unittest_pb2
  # Proto3 has no default_* fields or foreign enums and does not keep unknown
  # fields, so it is checked against a golden file without those fields.
  golden_name = ('golden_message_oneof_implemented' if is_proto2
                 else 'golden_message_proto3')
  golden_data = test_util.GoldenFileData(golden_name)

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(golden_data)
  if is_proto2:
    test_util.ExpectAllFieldsSet(self, parsed)
  self.assertEqual(golden_data, parsed.SerializeToString())

  # A deep copy must serialize to exactly the same bytes.
  self.assertEqual(golden_data, copy.deepcopy(parsed).SerializeToString())
134
def testGoldenPackedMessage(self, message_module):
  """Packed repeated fields parse from, and serialize to, the golden bytes."""
  golden_data = test_util.GoldenFileData('golden_packed_fields_message')
  parsed = message_module.TestPackedTypes()
  parsed.ParseFromString(golden_data)

  expected = message_module.TestPackedTypes()
  test_util.SetAllPackedFields(expected)
  self.assertEqual(expected, parsed)
  self.assertEqual(golden_data, expected.SerializeToString())

  duplicate = copy.deepcopy(parsed)
  self.assertEqual(golden_data, duplicate.SerializeToString())
145
def testParseErrors(self, message_module):
  """FromString rejects malformed input.

  A non-bytes argument raises TypeError. A lone end-group tag raises
  DecodeError in the python implementation; other implementations currently
  emit a RuntimeWarning instead (see the TODO below).
  """
  msg = message_module.TestAllTypes()
  self.assertRaises(TypeError, msg.FromString, 0)
  # Deliberately broad: the exception type for a str argument differs
  # between Python 2 and Python 3.
  self.assertRaises(Exception, msg.FromString, '0')
  # TODO(jieluo): Fix cpp extension to raise error instead of warning.
  # b/27494216
  end_tag = encoder.TagBytes(1, 4)  # Field 1, wire type 4 (end-group).
  if api_implementation.Type() == 'python':
    with self.assertRaises(message.DecodeError) as context:
      msg.FromString(end_tag)
    self.assertEqual('Unexpected end-group tag.', str(context.exception))
  else:
    with warnings.catch_warnings(record=True) as w:
      # Cause all warnings to always be triggered.
      warnings.simplefilter('always')
      msg.FromString(end_tag)
      # unittest assertions instead of bare `assert`: they are not stripped
      # under `python -O` and they report the offending values on failure.
      self.assertEqual(1, len(w))
      self.assertTrue(issubclass(w[-1].category, RuntimeWarning))
      self.assertEqual('Unexpected end-group tag: Not all data was converted',
                       str(w[-1].message))
166
def testDeterminismParameters(self, message_module):
  """Every value of the `deterministic` argument is accepted and honored.

  A message holding only repeated strings always serializes the same way,
  so the output must match whatever determinism is requested.
  """
  golden_data = (b'\xe2\x02\nOne string'
                 b'\xe2\x02\nTwo string'
                 b'\xe2\x02\nRed string'
                 b'\xe2\x02\x0bBlue string')
  golden_message = message_module.TestAllTypes()
  golden_message.repeated_string.extend(
      ['One string', 'Two string', 'Red string', 'Blue string'])
  for flag in (None, False, True):
    self.assertEqual(golden_data,
                     golden_message.SerializeToString(deterministic=flag))

  class BadArgError(Exception):
    pass

  class BadArg(object):
    """Argument whose truth test raises, so evaluation errors must propagate."""

    def __nonzero__(self):  # Python 2 truthiness hook.
      raise BadArgError()

    __bool__ = __nonzero__  # Python 3 truthiness hook.

  with self.assertRaises(BadArgError):
    golden_message.SerializeToString(deterministic=BadArg())
202
def testPickleSupport(self, message_module):
  """A parsed message survives a pickle round trip unchanged."""
  original = message_module.TestAllTypes()
  original.ParseFromString(test_util.GoldenFileData('golden_message'))
  # Only data pickled by this test itself is unpickled here.
  restored = pickle.loads(pickle.dumps(original))
  self.assertEqual(restored, original)
211
def testPositiveInfinity(self, message_module):
  """+inf in float/double fields (singular and repeated) round-trips exactly."""
  if message_module is not unittest_pb2:
    # The proto3 bytes encode the repeated fields length-delimited (packed).
    golden_data = (b'\x5D\x00\x00\x80\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
                   b'\xCA\x02\x04\x00\x00\x80\x7F'
                   b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
  else:
    # The proto2 bytes encode the repeated fields as plain fixed32/fixed64.
    golden_data = (b'\x5D\x00\x00\x80\x7F'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
                   b'\xCD\x02\x00\x00\x80\x7F'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(golden_data)
  self.assertTrue(IsPosInf(parsed.optional_float))
  self.assertTrue(IsPosInf(parsed.optional_double))
  self.assertTrue(IsPosInf(parsed.repeated_float[0]))
  self.assertTrue(IsPosInf(parsed.repeated_double[0]))
  self.assertEqual(golden_data, parsed.SerializeToString())
231
def testNegativeInfinity(self, message_module):
  """-inf in float/double fields (singular and repeated) round-trips exactly."""
  if message_module is not unittest_pb2:
    # The proto3 bytes encode the repeated fields length-delimited (packed).
    golden_data = (b'\x5D\x00\x00\x80\xFF'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                   b'\xCA\x02\x04\x00\x00\x80\xFF'
                   b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
  else:
    # The proto2 bytes encode the repeated fields as plain fixed32/fixed64.
    golden_data = (b'\x5D\x00\x00\x80\xFF'
                   b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
                   b'\xCD\x02\x00\x00\x80\xFF'
                   b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')

  parsed = message_module.TestAllTypes()
  parsed.ParseFromString(golden_data)
  self.assertTrue(IsNegInf(parsed.optional_float))
  self.assertTrue(IsNegInf(parsed.optional_double))
  self.assertTrue(IsNegInf(parsed.repeated_float[0]))
  self.assertTrue(IsNegInf(parsed.repeated_double[0]))
  self.assertEqual(golden_data, parsed.SerializeToString())
251
def testNotANumber(self, message_module):
  """NaN float/double values parse, and still read as NaN after re-encoding."""
  golden_data = (b'\x5D\x00\x00\xC0\x7F'
                 b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
                 b'\xCD\x02\x00\x00\xC0\x7F'
                 b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
  golden_message = message_module.TestAllTypes()
  golden_message.ParseFromString(golden_data)
  self.assertTrue(isnan(golden_message.optional_float))
  self.assertTrue(isnan(golden_message.optional_double))
  self.assertTrue(isnan(golden_message.repeated_float[0]))
  self.assertTrue(isnan(golden_message.repeated_double[0]))

  # NaN has several valid bit patterns, so rather than pinning the serialized
  # bytes, check that whatever is produced parses back into NaN values.
  reparsed = message_module.TestAllTypes()
  reparsed.ParseFromString(golden_message.SerializeToString())
  self.assertTrue(isnan(reparsed.optional_float))
  self.assertTrue(isnan(reparsed.optional_double))
  self.assertTrue(isnan(reparsed.repeated_float[0]))
  self.assertTrue(isnan(reparsed.repeated_double[0]))
275
def testPositiveInfinityPacked(self, message_module):
  """+inf in packed float/double fields parses and re-serializes exactly."""
  golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
                 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
  packed = message_module.TestPackedTypes()
  packed.ParseFromString(golden_data)
  self.assertTrue(IsPosInf(packed.packed_float[0]))
  self.assertTrue(IsPosInf(packed.packed_double[0]))
  self.assertEqual(golden_data, packed.SerializeToString())
284
def testNegativeInfinityPacked(self, message_module):
  """-inf in packed float/double fields parses and re-serializes exactly."""
  golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
                 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
  packed = message_module.TestPackedTypes()
  packed.ParseFromString(golden_data)
  self.assertTrue(IsNegInf(packed.packed_float[0]))
  self.assertTrue(IsNegInf(packed.packed_double[0]))
  self.assertEqual(golden_data, packed.SerializeToString())
293
def testNotANumberPacked(self, message_module):
  """NaN in packed float/double fields survives a serialize/parse cycle."""
  golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
                 b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
  packed = message_module.TestPackedTypes()
  packed.ParseFromString(golden_data)
  self.assertTrue(isnan(packed.packed_float[0]))
  self.assertTrue(isnan(packed.packed_double[0]))

  # Any NaN encoding is acceptable, so check values rather than bytes.
  reparsed = message_module.TestPackedTypes()
  reparsed.ParseFromString(packed.SerializeToString())
  self.assertTrue(isnan(reparsed.packed_float[0]))
  self.assertTrue(isnan(reparsed.packed_double[0]))
307
308 def testExtremeFloatValues(self, message_module):
309 message = message_module.TestAllTypes()
310
311 # Most positive exponent, no significand bits set.
312 kMostPosExponentNoSigBits = math.pow(2, 127)
313 message.optional_float = kMostPosExponentNoSigBits
314 message.ParseFromString(message.SerializeToString())
315 self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
316
317 # Most positive exponent, one significand bit set.
318 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
319 message.optional_float = kMostPosExponentOneSigBit
320 message.ParseFromString(message.SerializeToString())
321 self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
322
323 # Repeat last two cases with values of same magnitude, but negative.
324 message.optional_float = -kMostPosExponentNoSigBits
325 message.ParseFromString(message.SerializeToString())
326 self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
327
328 message.optional_float = -kMostPosExponentOneSigBit
329 message.ParseFromString(message.SerializeToString())
330 self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
331
332 # Most negative exponent, no significand bits set.
333 kMostNegExponentNoSigBits = math.pow(2, -127)
334 message.optional_float = kMostNegExponentNoSigBits
335 message.ParseFromString(message.SerializeToString())
336 self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
337
338 # Most negative exponent, one significand bit set.
339 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
340 message.optional_float = kMostNegExponentOneSigBit
341 message.ParseFromString(message.SerializeToString())
342 self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
343
344 # Repeat last two cases with values of the same magnitude, but negative.
345 message.optional_float = -kMostNegExponentNoSigBits
346 message.ParseFromString(message.SerializeToString())
347 self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
348
349 message.optional_float = -kMostNegExponentOneSigBit
350 message.ParseFromString(message.SerializeToString())
351 self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
352
353 def testExtremeDoubleValues(self, message_module):
354 message = message_module.TestAllTypes()
355
356 # Most positive exponent, no significand bits set.
357 kMostPosExponentNoSigBits = math.pow(2, 1023)
358 message.optional_double = kMostPosExponentNoSigBits
359 message.ParseFromString(message.SerializeToString())
360 self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
361
362 # Most positive exponent, one significand bit set.
363 kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
364 message.optional_double = kMostPosExponentOneSigBit
365 message.ParseFromString(message.SerializeToString())
366 self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
367
368 # Repeat last two cases with values of same magnitude, but negative.
369 message.optional_double = -kMostPosExponentNoSigBits
370 message.ParseFromString(message.SerializeToString())
371 self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
372
373 message.optional_double = -kMostPosExponentOneSigBit
374 message.ParseFromString(message.SerializeToString())
375 self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
376
377 # Most negative exponent, no significand bits set.
378 kMostNegExponentNoSigBits = math.pow(2, -1023)
379 message.optional_double = kMostNegExponentNoSigBits
380 message.ParseFromString(message.SerializeToString())
381 self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
382
383 # Most negative exponent, one significand bit set.
384 kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
385 message.optional_double = kMostNegExponentOneSigBit
386 message.ParseFromString(message.SerializeToString())
387 self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
388
389 # Repeat last two cases with values of the same magnitude, but negative.
390 message.optional_double = -kMostNegExponentNoSigBits
391 message.ParseFromString(message.SerializeToString())
392 self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
393
394 message.optional_double = -kMostNegExponentOneSigBit
395 message.ParseFromString(message.SerializeToString())
396 self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
397
def testFloatPrinting(self, message_module):
  """A whole-valued float field prints as '2.0' in the text representation."""
  msg = message_module.TestAllTypes()
  msg.optional_float = 2.0
  expected_text = 'optional_float: 2.0\n'
  self.assertEqual(str(msg), expected_text)
402
def testHighPrecisionFloatPrinting(self, message_module):
  """A high-precision double prints as expected for this Python major version."""
  msg = message_module.TestAllTypes()
  msg.optional_double = 0.12345678912345678
  # The expected text is shorter on Python 2 than on Python 3.
  expected = ('optional_double: 0.12345678912345678\n'
              if sys.version_info >= (3,)
              else 'optional_double: 0.123456789123\n')
  self.assertEqual(str(msg), expected)
410
def testUnknownFieldPrinting(self, message_module):
  """A message whose parsed fields are all unknown prints as an empty string."""
  source = message_module.TestAllTypes()
  test_util.SetAllNonLazyFields(source)
  target = message_module.TestEmptyMessage()
  target.ParseFromString(source.SerializeToString())
  self.assertEqual(str(target), '')
417
def testRepeatedNestedFieldIteration(self, message_module):
  """Repeated message fields support forward, reversed() and [::-1] traversal."""
  msg = message_module.TestAllTypes()
  for bb in (1, 2, 3, 4):
    msg.repeated_nested_message.add(bb=bb)

  items = msg.repeated_nested_message
  self.assertEqual([1, 2, 3, 4], [m.bb for m in items])
  self.assertEqual([4, 3, 2, 1], [m.bb for m in reversed(items)])
  self.assertEqual([4, 3, 2, 1], [m.bb for m in items[::-1]])
431
def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
  """Sort int32, float, string and bytes repeated fields in natural order."""
  msg = message_module.TestAllTypes()

  # TODO(mattp): would testing more scalar types strengthen test?
  ints = msg.repeated_int32
  for value in (1, 3, 2):
    ints.append(value)
  ints.sort()
  for index, expected in enumerate((1, 2, 3)):
    self.assertEqual(ints[index], expected)
  self.assertEqual(str(ints), str([1, 2, 3]))

  floats = msg.repeated_float
  for value in (1.1, 1.3, 1.2):
    floats.append(value)
  floats.sort()
  # Compared approximately: stored values may not equal the Python literals
  # exactly (presumably single-precision storage).
  for index, expected in enumerate((1.1, 1.2, 1.3)):
    self.assertAlmostEqual(floats[index], expected)

  strings = msg.repeated_string
  for value in ('a', 'c', 'b'):
    strings.append(value)
  strings.sort()
  for index, expected in enumerate(('a', 'b', 'c')):
    self.assertEqual(strings[index], expected)
  self.assertEqual(str(strings), str([u'a', u'b', u'c']))

  blobs = msg.repeated_bytes
  for value in (b'a', b'c', b'b'):
    blobs.append(value)
  blobs.sort()
  for index, expected in enumerate((b'a', b'b', b'c')):
    self.assertEqual(blobs[index], expected)
  self.assertEqual(str(blobs), str([b'a', b'b', b'c']))
Brian Silverman9c614bc2016-02-15 20:20:02 -0500471
def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
  """Sort repeated scalar fields with a custom key function."""
  msg = message_module.TestAllTypes()

  for value in (-3, -2, -1):
    msg.repeated_int32.append(value)
  msg.repeated_int32.sort(key=abs)
  for index, expected in enumerate((-1, -2, -3)):
    self.assertEqual(msg.repeated_int32[index], expected)

  for value in ('aaa', 'bb', 'c'):
    msg.repeated_string.append(value)
  msg.repeated_string.sort(key=len)
  for index, expected in enumerate(('c', 'bb', 'aaa')):
    self.assertEqual(msg.repeated_string[index], expected)
491
def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
  """Sort a repeated message field by a key extracted from each element."""
  msg = message_module.TestAllTypes()
  nested = msg.repeated_nested_message
  for bb in (1, 3, 2, 6, 5, 4):
    nested.add().bb = bb

  nested.sort(key=operator.attrgetter('bb'))
  for index, expected in enumerate(range(1, 7)):
    self.assertEqual(nested[index].bb, expected)
  self.assertEqual(str(nested),
                   '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]')
Brian Silverman9c614bc2016-02-15 20:20:02 -0500511
def testSortingRepeatedCompositeFieldsStable(self, message_module):
  """Sorting a repeated message field by key is stable.

  Elements with equal keys (the same tens digit of `bb`) must keep their
  original relative order, and the sorted order must also be what gets
  serialized.
  """
  msg = message_module.TestAllTypes()
  for bb in (21, 20, 13, 33, 11, 24, 10):
    msg.repeated_nested_message.add().bb = bb
  msg.repeated_nested_message.sort(key=lambda z: z.bb // 10)
  expected = [13, 11, 10, 21, 20, 24, 33]
  self.assertEqual(expected, [n.bb for n in msg.repeated_nested_message])

  # Make sure that for the C++ implementation, the underlying fields
  # are actually reordered.
  pb = msg.SerializeToString()
  msg.Clear()
  msg.MergeFromString(pb)
  self.assertEqual(expected, [n.bb for n in msg.repeated_nested_message])
536
def testRepeatedCompositeFieldSortArguments(self, message_module):
  """Check sorting a repeated composite field using list.sort() arguments."""
  msg = message_module.TestAllTypes()
  nested = msg.repeated_nested_message
  for bb in (1, 3, 2, 6, 5, 4):
    nested.add().bb = bb

  by_bb = operator.attrgetter('bb')
  nested.sort(key=by_bb)
  self.assertEqual([k.bb for k in nested], [1, 2, 3, 4, 5, 6])
  nested.sort(key=by_bb, reverse=True)
  self.assertEqual([k.bb for k in nested], [6, 5, 4, 3, 2, 1])

  # Comparator-function sorting only exists on Python 2.
  if sys.version_info < (3,):
    compare_bb = lambda a, b: cmp(a.bb, b.bb)
    nested.sort(sort_function=compare_bb)
    self.assertEqual([k.bb for k in nested], [1, 2, 3, 4, 5, 6])
    nested.sort(cmp=compare_bb, reverse=True)
    self.assertEqual([k.bb for k in nested], [6, 5, 4, 3, 2, 1])
562
def testRepeatedScalarFieldSortArguments(self, message_module):
  """Check sorting a scalar field using list.sort() arguments."""
  msg = message_module.TestAllTypes()
  has_cmp_sort = sys.version_info < (3,)  # cmp-style sorting is Python 2 only.

  ints = msg.repeated_int32
  for value in (-3, -2, -1):
    ints.append(value)
  ints.sort(key=abs)
  self.assertEqual(list(ints), [-1, -2, -3])
  ints.sort(key=abs, reverse=True)
  self.assertEqual(list(ints), [-3, -2, -1])
  if has_cmp_sort:
    abs_cmp = lambda a, b: cmp(abs(a), abs(b))
    ints.sort(sort_function=abs_cmp)
    self.assertEqual(list(ints), [-1, -2, -3])
    ints.sort(cmp=abs_cmp, reverse=True)
    self.assertEqual(list(ints), [-3, -2, -1])

  strings = msg.repeated_string
  for value in ('aaa', 'bb', 'c'):
    strings.append(value)
  strings.sort(key=len)
  self.assertEqual(list(strings), ['c', 'bb', 'aaa'])
  strings.sort(key=len, reverse=True)
  self.assertEqual(list(strings), ['aaa', 'bb', 'c'])
  if has_cmp_sort:
    len_cmp = lambda a, b: cmp(len(a), len(b))
    strings.sort(sort_function=len_cmp)
    self.assertEqual(list(strings), ['c', 'bb', 'aaa'])
    strings.sort(cmp=len_cmp, reverse=True)
    self.assertEqual(list(strings), ['aaa', 'bb', 'c'])
594
def testRepeatedFieldsComparable(self, message_module):
  """Messages and their repeated fields support ordering and cmp() (Python 2)."""
  m1 = message_module.TestAllTypes()
  m2 = message_module.TestAllTypes()
  for m in (m1, m2):
    for value in (0, 1, 2):
      m.repeated_int32.append(value)
  for m in (m1, m2):
    for bb in (1, 2, 3):
      m.repeated_nested_message.add().bb = bb

  if sys.version_info >= (3,):
    return  # No cmp() in PY3.

  # These comparisons should not raise errors.
  _ = m1 < m2
  _ = m1.repeated_nested_message < m2.repeated_nested_message

  # cmp() must be implemented; otherwise these would fall back to id()
  # comparisons and all fail.
  self.assertEqual(cmp(m1, m2), 0)
  self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
  self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
  self.assertEqual(cmp(m1.repeated_nested_message,
                       m2.repeated_nested_message), 0)
  with self.assertRaises(TypeError):
    # Can't compare repeated composite containers to lists.
    cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
627
628 # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
629
630 def testRepeatedFieldsAreSequences(self, message_module):
631 m = message_module.TestAllTypes()
632 self.assertIsInstance(m.repeated_int32, collections.MutableSequence)
633 self.assertIsInstance(m.repeated_nested_message,
634 collections.MutableSequence)
635
def testRepeatedFieldsNotHashable(self, message_module):
  """Repeated containers are mutable and refuse to be hashed."""
  m = message_module.TestAllTypes()
  for container in (m.repeated_int32, m.repeated_nested_message):
    with self.assertRaises(TypeError):
      hash(container)
642
def testRepeatedFieldInsideNestedMessage(self, message_module):
  """Extending a submessage's repeated field, even with nothing, sets the parent."""
  outer = message_module.NestedTestAllTypes()
  outer.payload.repeated_int32.extend([])
  self.assertTrue(outer.HasField('payload'))
647
def ensureNestedMessageExists(self, msg, attribute):
  """Read a nested message attribute and assert the read did not set it.

  Accessing a submessage attribute materializes the submessage object, but
  that access alone must not mark the field as present.
  """
  _ = getattr(msg, attribute)
  self.assertFalse(msg.HasField(attribute))
656
def testOneofGetCaseNonexistingField(self, message_module):
  """WhichOneof rejects unknown oneof names and non-string arguments."""
  m = message_module.TestAllTypes()
  with self.assertRaises(ValueError):
    m.WhichOneof('no_such_oneof_field')
  with self.assertRaises(Exception):
    m.WhichOneof(0)
Brian Silverman9c614bc2016-02-15 20:20:02 -0500661
def testOneofDefaultValues(self, message_module):
  """Assigning a oneof member its default value still selects that member."""
  m = message_module.TestAllTypes()
  self.assertIs(None, m.WhichOneof('oneof_field'))
  self.assertFalse(m.HasField('oneof_uint32'))

  # Setting 0 / "" (the defaults) must still make the member the active case.
  for field, default, other in (('oneof_uint32', 0, 'oneof_string'),
                                ('oneof_string', "", 'oneof_uint32')):
    setattr(m, field, default)
    self.assertEqual(field, m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField(field))
    self.assertFalse(m.HasField(other))
677
def testOneofSemantics(self, message_module):
  """Setting one oneof member clears the others; mere reads change nothing."""
  m = message_module.TestAllTypes()

  def expect_case(active, unset=None):
    # `active` must be the selected, present member; `unset` must be absent.
    self.assertEqual(active, m.WhichOneof('oneof_field'))
    self.assertTrue(m.HasField(active))
    if unset is not None:
      self.assertFalse(m.HasField(unset))

  self.assertIs(None, m.WhichOneof('oneof_field'))

  m.oneof_uint32 = 11
  expect_case('oneof_uint32')

  m.oneof_string = u'foo'
  expect_case('oneof_string', unset='oneof_uint32')

  # Reading the submessage accessor must not switch the active member.
  _ = m.oneof_nested_message
  expect_case('oneof_string', unset='oneof_nested_message')

  # Neither must reading a field of that submessage.
  _ = m.oneof_nested_message.bb
  expect_case('oneof_string', unset='oneof_nested_message')

  # Writing into the submessage does select it.
  m.oneof_nested_message.bb = 11
  expect_case('oneof_nested_message', unset='oneof_string')

  m.oneof_bytes = b'bb'
  expect_case('oneof_bytes', unset='oneof_nested_message')
712
def testOneofCompositeFieldReadAccess(self, message_module):
  """ensureNestedMessageExists on another oneof member keeps the set one."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11

  self.ensureNestedMessageExists(msg, 'oneof_nested_message')

  # The previously selected scalar member and its value are unaffected.
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
  self.assertEqual(11, msg.oneof_uint32)
720
def testOneofWhichOneof(self, message_module):
  """WhichOneof follows the active member through set and clear.

  For unittest_pb2, HasField(<oneof name>) is also checked.
  """
  msg = message_module.TestAllTypes()
  # HasField on the oneof name itself is only asserted for unittest_pb2.
  check_oneof_presence = message_module is unittest_pb2

  self.assertIsNone(msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))

  msg.oneof_uint32 = 11
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertTrue(msg.HasField('oneof_field'))

  msg.oneof_bytes = b'bb'
  self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))

  # Clearing the active member leaves the oneof empty.
  msg.ClearField('oneof_bytes')
  self.assertIsNone(msg.WhichOneof('oneof_field'))
  if check_oneof_presence:
    self.assertFalse(msg.HasField('oneof_field'))
739
def testOneofClearField(self, message_module):
  """ClearField on the oneof name clears whichever member is set."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.ClearField('oneof_field')

  unset_names = ['oneof_uint32']
  if message_module is unittest_pb2:
    unset_names.insert(0, 'oneof_field')
  for name in unset_names:
    self.assertFalse(msg.HasField(name))
  self.assertIsNone(msg.WhichOneof('oneof_field'))
748
def testOneofClearSetField(self, message_module):
  """ClearField on the active member empties the oneof."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.ClearField('oneof_uint32')

  unset_names = ['oneof_uint32']
  if message_module is unittest_pb2:
    unset_names.insert(0, 'oneof_field')
  for name in unset_names:
    self.assertFalse(msg.HasField(name))
  self.assertIsNone(msg.WhichOneof('oneof_field'))
757
def testOneofClearUnsetField(self, message_module):
  """Clearing a oneof member that is not the active one keeps the active one."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  self.ensureNestedMessageExists(msg, 'oneof_nested_message')
  msg.ClearField('oneof_nested_message')

  self.assertEqual(11, msg.oneof_uint32)
  still_set = ['oneof_uint32']
  if message_module is unittest_pb2:
    still_set.insert(0, 'oneof_field')
  for name in still_set:
    self.assertTrue(msg.HasField(name))
  self.assertEqual('oneof_uint32', msg.WhichOneof('oneof_field'))
768
def testOneofDeserialize(self, message_module):
  """The active oneof member survives a serialize/parse round trip."""
  source = message_module.TestAllTypes()
  source.oneof_uint32 = 11
  parsed = message_module.TestAllTypes()
  wire_bytes = source.SerializeToString()
  parsed.ParseFromString(wire_bytes)
  self.assertEqual('oneof_uint32', parsed.WhichOneof('oneof_field'))
775
def testOneofCopyFrom(self, message_module):
  """CopyFrom carries over which oneof member is set."""
  source = message_module.TestAllTypes()
  source.oneof_uint32 = 11
  target = message_module.TestAllTypes()
  target.CopyFrom(source)
  self.assertEqual('oneof_uint32', target.WhichOneof('oneof_field'))
782
def testOneofNestedMergeFrom(self, message_module):
  """MergeFrom replaces the oneof choice only where the source sets one."""
  src = message_module.NestedTestAllTypes()
  src.payload.oneof_uint32 = 11

  dst = message_module.NestedTestAllTypes()
  dst.payload.oneof_bytes = b'bb'
  dst.child.payload.oneof_bytes = b'bb'

  dst.MergeFrom(src)

  self.assertEqual('oneof_uint32', dst.payload.WhichOneof('oneof_field'))
  # src never touched its child, so dst.child keeps its own selection.
  self.assertEqual('oneof_bytes', dst.child.payload.WhichOneof('oneof_field'))
792
def testOneofMessageMergeFrom(self, message_module):
  """Merging a message-typed oneof member overrides a scalar member."""
  src = message_module.NestedTestAllTypes()
  src.payload.oneof_nested_message.bb = 11
  src.child.payload.oneof_nested_message.bb = 12

  dst = message_module.NestedTestAllTypes()
  dst.payload.oneof_uint32 = 13
  dst.MergeFrom(src)

  for payload in (dst.payload, dst.child.payload):
    self.assertEqual('oneof_nested_message',
                     payload.WhichOneof('oneof_field'))
804
def testOneofNestedMessageInit(self, message_module):
  """Passing a sub-message to the constructor selects that oneof member."""
  nested = message_module.TestAllTypes.NestedMessage()
  msg = message_module.TestAllTypes(oneof_nested_message=nested)
  self.assertEqual('oneof_nested_message', msg.WhichOneof('oneof_field'))
809
def testOneofClear(self, message_module):
  """Clear() empties the oneof, and the oneof remains usable afterwards."""
  msg = message_module.TestAllTypes()
  msg.oneof_uint32 = 11
  msg.Clear()
  self.assertIsNone(msg.WhichOneof('oneof_field'))

  msg.oneof_bytes = b'bb'
  self.assertEqual('oneof_bytes', msg.WhichOneof('oneof_field'))
817
def testAssignByteStringToUnicodeField(self, message_module):
  """Assigning a byte string to a string field should result
  in the value being converted to a Unicode string."""
  msg = message_module.TestAllTypes()
  # Use explicit bytes literals: on Python 3, str('') is already a text
  # string, so it would never exercise the bytes -> text conversion that
  # this test is about.
  msg.optional_string = b''
  self.assertIsInstance(msg.optional_string, six.text_type)

  # Non-empty UTF-8 bytes are decoded to the equivalent text value.
  msg.optional_string = b'abc'
  self.assertIsInstance(msg.optional_string, six.text_type)
  self.assertEqual(u'abc', msg.optional_string)
824
def testLongValuedSlice(self, message_module):
  """It should be possible to use long-valued indices in slices.

  This did not work in the v2 C++ implementation.
  """
  msg = message_module.TestAllTypes()

  def check_long_slice(container):
    stop = long(len(container))
    sliced = container[long(0):stop]
    self.assertEqual(len(container), len(sliced))

  # Repeated scalar.
  msg.repeated_int32.append(1)
  check_long_slice(msg.repeated_int32)

  # Repeated composite.
  msg.repeated_nested_message.add().bb = 3
  check_long_slice(msg.repeated_nested_message)
841
def testExtendShouldNotSwallowExceptions(self, message_module):
  """An error raised while iterating the argument must escape extend().

  This did not use to work in the v2 C++ implementation.
  """
  msg = message_module.TestAllTypes()
  # The generator references the undefined name `a`, so consuming it raises
  # NameError; extend() has to let that propagate.
  for field_name in ('repeated_int32', 'repeated_nested_enum'):
    with self.assertRaises(NameError):
      getattr(msg, field_name).extend(
          a for i in range(10))  # pylint: disable=undefined-variable
850
# Values that are falsy in a boolean context, covering singletons, numbers,
# byte/text strings and empty containers.
FALSY_VALUES = [
    None, False, 0, 0.0,
    b'', u'', bytearray(),
    [], {}, set(),
]
852
def testExtendInt32WithNothing(self, message_module):
  """Extending repeated int32 with falsy values or [] leaves it empty."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_int32)

  # Falsy non-iterables are silently ignored.
  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_int32.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_int32)
865
def testExtendFloatWithNothing(self, message_module):
  """Extending repeated float with falsy values or [] leaves it empty."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_float)

  # Falsy non-iterables are silently ignored.
  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_float.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_float)
878
def testExtendStringWithNothing(self, message_module):
  """Extending repeated string with falsy values or [] leaves it empty."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)

  # Falsy non-iterables are silently ignored.
  # TODO(ptucker): Deprecate this behavior. b/18413862
  for nothing in MessageTest.FALSY_VALUES + [[]]:
    msg.repeated_string.extend(nothing)
    self.assertSequenceEqual([], msg.repeated_string)
891
def testExtendInt32WithPythonList(self, message_module):
  """Extending repeated int32 with Python lists appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_int32)
  expected = []
  for chunk in ([0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(chunk)
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_int32)
902
def testExtendFloatWithPythonList(self, message_module):
  """Extending repeated float with Python lists appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_float)
  expected = []
  for chunk in ([0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(chunk)
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_float)
913
def testExtendStringWithPythonList(self, message_module):
  """Extending repeated string with Python lists appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)
  expected = []
  for chunk in ([''], ['11', '22'], ['33', '44']):
    msg.repeated_string.extend(chunk)
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_string)
924
def testExtendStringWithString(self, message_module):
  """Extending repeated string with a str adds one element per character."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)
  # A str is itself iterable, so each character becomes its own element.
  msg.repeated_string.extend('abc')
  self.assertSequenceEqual(list('abc'), msg.repeated_string)
931
class TestIterable(object):
  """This iterable object mimics the behavior of numpy.array.

  Truth-testing fails for length > 1, returns bool(item[0]) for length == 1,
  and is False when empty. Both __nonzero__ (Python 2) and __bool__
  (Python 3) are defined. Without __bool__, Python 3 would fall back to
  __len__ and never raise, so the fixture would no longer catch code that
  truth-tests its argument.
  """

  def __init__(self, values=None):
    self._list = values or []

  def __nonzero__(self):
    size = len(self._list)
    if size == 0:
      return False
    if size == 1:
      return bool(self._list[0])
    raise ValueError('Truth value is ambiguous.')

  # Python 3 consults __bool__ instead of __nonzero__.
  __bool__ = __nonzero__

  def __len__(self):
    return len(self._list)

  def __iter__(self):
    return self._list.__iter__()
955
def testExtendInt32WithIterable(self, message_module):
  """Extending repeated int32 with a numpy-like iterable appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_int32)
  expected = []
  for chunk in ([], [0], [1, 2], [3, 4]):
    msg.repeated_int32.extend(MessageTest.TestIterable(chunk))
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_int32)
968
def testExtendFloatWithIterable(self, message_module):
  """Extending repeated float with a numpy-like iterable appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_float)
  expected = []
  for chunk in ([], [0.0], [1.0, 2.0], [3.0, 4.0]):
    msg.repeated_float.extend(MessageTest.TestIterable(chunk))
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_float)
981
def testExtendStringWithIterable(self, message_module):
  """Extending repeated string with a numpy-like iterable appends in order."""
  msg = message_module.TestAllTypes()
  self.assertSequenceEqual([], msg.repeated_string)
  expected = []
  for chunk in ([], [''], ['1', '2'], ['3', '4']):
    msg.repeated_string.extend(MessageTest.TestIterable(chunk))
    expected.extend(chunk)
    self.assertSequenceEqual(expected, msg.repeated_string)
994
def testPickleRepeatedScalarContainer(self, message_module):
  # TODO(tibell): The pure-Python implementation supports pickling of
  # scalar containers in *some* cases. For now the cpp2 version
  # throws an exception to avoid a segfault. Investigate if we
  # want to support pickling of these fields.
  #
  # For more information see: https://b2.corp.google.com/u/0/issues/18677897
  #
  # The check below only runs for the 'cpp' implementation when Version()
  # is not 2; otherwise the test does nothing.
  if (api_implementation.Type() == 'cpp' and
      api_implementation.Version() != 2):
    msg = message_module.TestAllTypes()
    with self.assertRaises(pickle.PickleError):
      pickle.dumps(msg.repeated_int32, pickle.HIGHEST_PROTOCOL)
1008
def testSortEmptyRepeatedCompositeContainer(self, message_module):
  """Exercise a scenario that has led to segfaults in the past.

  The test passes as long as sorting an empty repeated composite field
  returns without crashing.
  """
  msg = message_module.TestAllTypes()
  empty_container = msg.repeated_nested_message
  empty_container.sort()
1014
def testHasFieldOnRepeatedField(self, message_module):
  """Using HasField on a repeated field should raise ValueError."""
  msg = message_module.TestAllTypes()
  self.assertRaises(ValueError, msg.HasField, 'repeated_int32')
1021
def testRepeatedScalarFieldPop(self, message_module):
  """pop() on a repeated scalar field removes and returns like list.pop()."""
  msg = message_module.TestAllTypes()
  self.assertRaises(IndexError, msg.repeated_int32.pop)

  msg.repeated_int32.extend(range(5))  # [0, 1, 2, 3, 4]
  # (positional args for pop, value expected back)
  for pop_args, expected in (((), 4), ((0,), 0), ((1,), 2)):
    self.assertEqual(expected, msg.repeated_int32.pop(*pop_args))
  self.assertEqual([1, 3], msg.repeated_int32)
1031
def testRepeatedCompositeFieldPop(self, message_module):
  """pop() on a repeated message field returns the removed sub-message."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(IndexError):
    msg.repeated_nested_message.pop()
  # A non-integer index is rejected as a type error.
  with self.assertRaises(TypeError):
    msg.repeated_nested_message.pop('0')

  for value in range(5):
    msg.repeated_nested_message.add().bb = value

  popped = [msg.repeated_nested_message.pop().bb,
            msg.repeated_nested_message.pop(0).bb,
            msg.repeated_nested_message.pop(1).bb]
  self.assertEqual([4, 0, 2], popped)
  self.assertEqual([1, 3],
                   [item.bb for item in msg.repeated_nested_message])
1045
def testRepeatedCompareWithSelf(self, message_module):
  """A repeated container compares equal to itself."""
  msg = message_module.TestAllTypes()
  for value in range(5):
    msg.repeated_int32.insert(value, value)
    msg.repeated_nested_message.add().bb = value
  self.assertSequenceEqual(msg.repeated_int32, msg.repeated_int32)
  self.assertEqual(msg.repeated_nested_message, msg.repeated_nested_message)
1054
def testReleasedNestedMessages(self, message_module):
  """A case that led to a segfault when a message detached from its parent
  container has itself a child container.
  """
  node = message_module.NestedTestAllTypes()
  # Each rebinding drops the last local reference to the previous level.
  node = node.repeated_child.add()
  node = node.child
  node = node.repeated_child.add()
  self.assertEqual(node.payload.optional_int32, 0)
1064
def testSetRepeatedComposite(self, message_module):
  """Direct assignment to a repeated field is rejected."""
  msg = message_module.TestAllTypes()
  with self.assertRaises(AttributeError):
    msg.repeated_int32 = []

  msg.repeated_int32.append(1)
  # For test coverage: cpp has a different path if the composite field is
  # already in its cache, and reports that case as a TypeError.
  if api_implementation.Type() == 'cpp':
    expected_error = TypeError
  else:
    expected_error = AttributeError
  with self.assertRaises(expected_error):
    msg.repeated_int32 = []
1078
Brian Silverman9c614bc2016-02-15 20:20:02 -05001079
1080# Class to test proto2-only features (required, extensions, etc.)
class Proto2Test(BaseTestCase):
  """Tests for proto2-only behavior: presence, closed enums, extensions."""

  def testFieldPresence(self):
    msg = unittest_pb2.TestAllTypes()
    tracked = ('optional_int32', 'optional_bool', 'optional_nested_message')

    def expect_present(present):
      check = self.assertTrue if present else self.assertFalse
      for name in tracked:
        check(msg.HasField(name))

    def expect_defaults():
      self.assertEqual(0, msg.optional_int32)
      self.assertEqual(False, msg.optional_bool)
      self.assertEqual(0, msg.optional_nested_message.bb)

    expect_present(False)

    # HasField rejects unknown field names and repeated fields.
    for bad_name in ('field_doesnt_exist', 'repeated_int32',
                     'repeated_nested_message'):
      with self.assertRaises(ValueError):
        msg.HasField(bad_name)

    expect_defaults()

    # Fields are set even when setting the values to default values.
    msg.optional_int32 = 0
    msg.optional_bool = False
    msg.optional_nested_message.bb = 0
    expect_present(True)

    # Set the fields to non-default values.
    msg.optional_int32 = 5
    msg.optional_bool = True
    msg.optional_nested_message.bb = 15
    expect_present(True)

    # Clearing the fields unsets them and resets their value to default.
    for name in tracked:
      msg.ClearField(name)
    expect_present(False)
    expect_defaults()

  def testAssignInvalidEnum(self):
    """Assigning an invalid enum number is not allowed in proto2."""
    msg = unittest_pb2.TestAllTypes()
    bogus = 1234567

    # Proto2 can not assign unknown enum, neither directly nor via append.
    with self.assertRaises(ValueError):
      msg.optional_nested_enum = bogus
    with self.assertRaises(ValueError):
      msg.repeated_nested_enum.append(bogus)

    # Assignment is a different code path than append for the C++ impl.
    msg.repeated_nested_enum.append(2)
    msg.repeated_nested_enum[0] = 2
    with self.assertRaises(ValueError):
      msg.repeated_nested_enum[0] = 123456

    # Unknown enum value can be parsed but is ignored.
    open_msg = unittest_proto3_arena_pb2.TestAllTypes()
    open_msg.optional_nested_enum = 1234567
    open_msg.repeated_nested_enum.append(7654321)
    wire_bytes = open_msg.SerializeToString()

    closed_msg = unittest_pb2.TestAllTypes()
    closed_msg.ParseFromString(wire_bytes)
    self.assertFalse(closed_msg.HasField('optional_nested_enum'))
    # 1 is the default value for optional_nested_enum.
    self.assertEqual(1, closed_msg.optional_nested_enum)
    self.assertEqual(0, len(closed_msg.repeated_nested_enum))

    # Re-serializing the proto2 message still emits the unknown values.
    open_msg.Clear()
    open_msg.ParseFromString(closed_msg.SerializeToString())
    self.assertEqual(1234567, open_msg.optional_nested_enum)
    self.assertEqual(7654321, open_msg.repeated_nested_enum[0])

  def testUnknownEnumMap(self):
    enum_map = map_proto2_unittest_pb2.TestEnumMap()
    enum_map.known_map_field[123] = 0
    with self.assertRaises(ValueError):
      enum_map.unknown_map_field[1] = 123

  def testExtensionsErrors(self):
    msg = unittest_pb2.TestAllTypes()
    with self.assertRaises(AttributeError):
      getattr(msg, 'Extensions')

  def testGoldenExtensions(self):
    expected_bytes = test_util.GoldenFileData('golden_message')
    parsed = unittest_pb2.TestAllExtensions()
    parsed.ParseFromString(expected_bytes)

    populated = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(populated)
    self.assertEqual(populated, parsed)

    self.assertEqual(expected_bytes, parsed.SerializeToString())
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(expected_bytes, duplicate.SerializeToString())

  def testGoldenPackedExtensions(self):
    expected_bytes = test_util.GoldenFileData('golden_packed_fields_message')
    parsed = unittest_pb2.TestPackedExtensions()
    parsed.ParseFromString(expected_bytes)

    populated = unittest_pb2.TestPackedExtensions()
    test_util.SetAllPackedExtensions(populated)
    self.assertEqual(populated, parsed)

    self.assertEqual(expected_bytes, populated.SerializeToString())
    duplicate = copy.deepcopy(parsed)
    self.assertEqual(expected_bytes, duplicate.SerializeToString())

  def testPickleIncompleteProto(self):
    partial = unittest_pb2.TestRequired(a=1)
    restored = pickle.loads(pickle.dumps(partial))

    self.assertEqual(restored, partial)
    self.assertEqual(restored.a, 1)
    # This is still an incomplete proto - so serializing should fail
    self.assertRaises(message.EncodeError, restored.SerializeToString)

  # TODO(haberman): this isn't really a proto2-specific test except that this
  # message has a required field in it. Should probably be factored out so
  # that we can test the other parts with proto3.
  def testParsingMerge(self):
    """Check the merge behavior when a required or optional field appears
    multiple times in the input."""
    parts = [unittest_pb2.TestAllTypes() for _ in range(3)]
    parts[0].optional_int32 = 1
    parts[1].optional_int64 = 2
    parts[2].optional_int32 = 3
    parts[2].optional_string = 'hello'

    expected = unittest_pb2.TestAllTypes()
    expected.optional_int32 = 3
    expected.optional_int64 = 2
    expected.optional_string = 'hello'

    generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
    generator.field1.extend(parts)
    generator.field2.extend(parts)
    generator.field3.extend(parts)
    generator.ext1.extend(parts)
    generator.ext2.extend(parts)
    for group in (generator.group1, generator.group2):
      for part in parts:
        group.add().field1.MergeFrom(part)

    parsing_merge = unittest_pb2.TestParsingMerge()
    parsing_merge.ParseFromString(generator.SerializeToString())

    # Required and optional fields should be merged.
    self.assertEqual(parsing_merge.required_all_types, expected)
    self.assertEqual(parsing_merge.optional_all_types, expected)
    self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
                     expected)
    optional_ext = unittest_pb2.TestParsingMerge.optional_ext
    self.assertEqual(parsing_merge.Extensions[optional_ext], expected)

    # Repeated fields should not be merged.
    repeated_ext = unittest_pb2.TestParsingMerge.repeated_ext
    self.assertEqual(len(parsing_merge.repeated_all_types), 3)
    self.assertEqual(len(parsing_merge.repeatedgroup), 3)
    self.assertEqual(len(parsing_merge.Extensions[repeated_ext]), 3)

  def testPythonicInit(self):
    msg = unittest_pb2.TestAllTypes(
        optional_int32=100,
        optional_fixed32=200,
        optional_float=300.5,
        optional_bytes=b'x',
        optionalgroup={'a': 400},
        optional_nested_message={'bb': 500},
        optional_foreign_message={},
        optional_nested_enum='BAZ',
        repeatedgroup=[{'a': 600}, {'a': 700}],
        repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
        default_int32=800,
        oneof_string='y')
    self.assertIsInstance(msg, unittest_pb2.TestAllTypes)

    # Scalars and bytes are stored as given.
    self.assertEqual(100, msg.optional_int32)
    self.assertEqual(200, msg.optional_fixed32)
    self.assertEqual(300.5, msg.optional_float)
    self.assertEqual(b'x', msg.optional_bytes)

    # Dicts become groups / sub-messages.
    self.assertEqual(400, msg.optionalgroup.a)
    self.assertIsInstance(msg.optional_nested_message,
                          unittest_pb2.TestAllTypes.NestedMessage)
    self.assertEqual(500, msg.optional_nested_message.bb)
    # Even an empty dict marks the sub-message as present.
    self.assertTrue(msg.HasField('optional_foreign_message'))
    self.assertEqual(msg.optional_foreign_message,
                     unittest_pb2.ForeignMessage())

    # Enum labels are resolved to their numbers.
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ, msg.optional_nested_enum)

    self.assertEqual([600, 700], [group.a for group in msg.repeatedgroup])
    self.assertEqual([unittest_pb2.TestAllTypes.FOO,
                      unittest_pb2.TestAllTypes.BAR],
                     list(msg.repeated_nested_enum))
    self.assertEqual(800, msg.default_int32)
    self.assertEqual('y', msg.oneof_string)

    # Fields that were not passed keep their unset / default state.
    self.assertFalse(msg.HasField('optional_int64'))
    self.assertEqual(0, len(msg.repeated_float))
    self.assertEqual(42, msg.default_int64)

    msg = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
    self.assertEqual(unittest_pb2.TestAllTypes.BAZ, msg.optional_nested_enum)

    # (expected exception, constructor kwargs) for invalid initializers.
    invalid_inits = [
        (ValueError,
         {'optional_nested_message': {'INVALID_NESTED_FIELD': 17}}),
        (TypeError,
         {'optional_nested_message': {'bb': 'INVALID_VALUE_TYPE'}}),
        (ValueError, {'optional_nested_enum': 'INVALID_LABEL'}),
        (ValueError, {'repeated_nested_enum': 'FOO'}),
    ]
    for expected_error, kwargs in invalid_inits:
      with self.assertRaises(expected_error):
        unittest_pb2.TestAllTypes(**kwargs)
1317
1318
Austin Schuh40c16522018-10-28 20:27:54 -07001319
Brian Silverman9c614bc2016-02-15 20:20:02 -05001320# Class to test proto3-only features/behavior (updated field presence & enums)
Austin Schuh40c16522018-10-28 20:27:54 -07001321class Proto3Test(BaseTestCase):
Brian Silverman9c614bc2016-02-15 20:20:02 -05001322
1323 # Utility method for comparing equality with a map.
def assertMapIterEquals(self, map_iter, dict_value):
  """Asserts that map_iter yields exactly the (key, value) pairs of dict_value.

  Every yielded pair must match an entry of dict_value, and every entry must
  be yielded. An unexpected key raises KeyError. The caller's dict is not
  modified.
  """
  pending = dict(dict_value)  # private copy; consumed as pairs are matched
  for key, value in map_iter:
    self.assertEqual(value, pending.pop(key))
  self.assertEqual({}, pending)
1333
def testFieldPresence(self):
  msg = unittest_proto3_arena_pb2.TestAllTypes()
  scalars = ('optional_int32', 'optional_float', 'optional_string',
             'optional_bool')

  # HasField is not supported for proto3 singular scalar fields.
  for name in scalars:
    with self.assertRaises(ValueError):
      msg.HasField(name)

  # Sub-message presence is still reported.
  self.assertFalse(msg.HasField('optional_nested_message'))

  # As with proto2, unknown names and repeated fields are rejected.
  for name in ('field_doesnt_exist', 'repeated_int32',
               'repeated_nested_message'):
    with self.assertRaises(ValueError):
      msg.HasField(name)

  def expect_defaults():
    self.assertEqual(0, msg.optional_int32)
    self.assertEqual(0, msg.optional_float)
    self.assertEqual('', msg.optional_string)
    self.assertEqual(False, msg.optional_bool)
    self.assertEqual(0, msg.optional_nested_message.bb)

  # Fields should default to their type-specific default.
  expect_defaults()

  # Setting a submessage should still return proper presence information.
  msg.optional_nested_message.bb = 0
  self.assertTrue(msg.HasField('optional_nested_message'))

  # Set the fields to non-default values.
  msg.optional_int32 = 5
  msg.optional_float = 1.1
  msg.optional_string = 'abc'
  msg.optional_bool = True
  msg.optional_nested_message.bb = 15

  # Clearing the fields unsets them and resets their value to default.
  for name in scalars + ('optional_nested_message',):
    msg.ClearField(name)
  expect_defaults()
1390
def testAssignUnknownEnum(self):
  """Assigning an unknown enum value is allowed and preserves the value."""
  original = unittest_proto3_arena_pb2.TestAllTypes()

  # Proto3 can assign unknown enums, both singular and repeated.
  original.optional_nested_enum = 1234567
  self.assertEqual(1234567, original.optional_nested_enum)
  original.repeated_nested_enum.append(22334455)
  self.assertEqual(22334455, original.repeated_nested_enum[0])
  # Assignment is a different code path than append for the C++ impl.
  original.repeated_nested_enum[0] = 7654321
  self.assertEqual(7654321, original.repeated_nested_enum[0])

  # The values also survive a serialize/parse round trip.
  parsed = unittest_proto3_arena_pb2.TestAllTypes()
  parsed.ParseFromString(original.SerializeToString())
  self.assertEqual(1234567, parsed.optional_nested_enum)
  self.assertEqual(7654321, parsed.repeated_nested_enum[0])
1409
1410 # Map isn't really a proto3-only feature. But there is no proto2 equivalent
1411 # of google/protobuf/map_unittest.proto right now, so it's not easy to
1412 # test both with the same test like we do for the other proto2/proto3 tests.
Austin Schuh40c16522018-10-28 20:27:54 -07001413 # (google/protobuf/map_proto2_unittest.proto is very different in the set
Brian Silverman9c614bc2016-02-15 20:20:02 -05001414 # of messages and fields it contains).
def testScalarMapDefaults(self):
  msg = map_unittest_pb2.TestMap()
  # (map field name, probe key, expected default value) per scalar map.
  cases = (
      ('map_int32_int32', -123, 0),
      ('map_int64_int64', -2**33, 0),
      ('map_uint32_uint32', 123, 0),
      ('map_uint64_uint64', 2**33, 0),
      ('map_int32_double', 123, 0.0),
      ('map_bool_bool', False, False),
      ('map_string_string', 'abc', ''),
      ('map_int32_bytes', 111, b''),
      ('map_int32_enum', 888, 0),
  )

  # Scalars start out unset.
  for field, key, _ in cases:
    self.assertNotIn(key, getattr(msg, field))

  # Accessing an unset key returns the default.
  for field, key, default in cases:
    self.assertEqual(default, getattr(msg, field)[key])
  self.assertIsInstance(msg.map_int32_double[123], float)
  self.assertIsInstance(msg.map_bool_bool[False], bool)

  # It also sets the value in the map.
  for field, key, _ in cases:
    self.assertIn(key, getattr(msg, field))

  self.assertIsInstance(msg.map_string_string['abc'], six.text_type)

  # Accessing an unset key still throws TypeError if the type of the key
  # is incorrect.
  with self.assertRaises(TypeError):
    msg.map_string_string[123]

  with self.assertRaises(TypeError):
    123 in msg.map_string_string
1462
def testMapGet(self):
  """get() honours its default argument and never inserts missing keys.

  Map containers have defaultdict-like __getitem__ semantics; get() must not
  share them.
  """
  msg = map_unittest_pb2.TestMap()

  self.assertIsNone(msg.map_int32_int32.get(5))
  self.assertEqual(10, msg.map_int32_int32.get(5, 10))
  # Repeating the lookup proves the previous get() calls did not insert 5.
  self.assertIsNone(msg.map_int32_int32.get(5))

  msg.map_int32_int32[5] = 15
  self.assertEqual(15, msg.map_int32_int32.get(5))
  # Once the key is present, the supplied default must be ignored.
  self.assertEqual(15, msg.map_int32_int32.get(5, 10))
  with self.assertRaises(TypeError):
    msg.map_int32_int32.get('')

  self.assertIsNone(msg.map_int32_foreign_message.get(5))
  self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))

  submsg = msg.map_int32_foreign_message[5]
  self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
  self.assertIs(submsg, msg.map_int32_foreign_message.get(5, 10))
  with self.assertRaises(TypeError):
    msg.map_int32_foreign_message.get('')
Brian Silverman9c614bc2016-02-15 20:20:02 -05001485
def testScalarMap(self):
  """Round-trips every scalar map kind and checks key/value type errors."""
  msg = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(msg.map_int32_int32))
  self.assertFalse(5 in msg.map_int32_int32)

  msg.map_int32_int32[-123] = -456
  msg.map_int64_int64[-2**33] = -2**34
  msg.map_uint32_uint32[123] = 456
  msg.map_uint64_uint64[2**33] = 2**34
  msg.map_int32_float[2] = 1.2
  msg.map_int32_double[1] = 3.3
  msg.map_string_string['abc'] = '123'
  msg.map_bool_bool[True] = True
  msg.map_int32_enum[888] = 2
  # Unknown numeric enum is supported in proto3.
  msg.map_int32_enum[123] = 456

  self.assertEqual([], msg.FindInitializationErrors())

  self.assertEqual(1, len(msg.map_string_string))

  # Bad key.
  with self.assertRaises(TypeError):
    msg.map_string_string[123] = '123'

  # Verify that trying to assign a bad key doesn't actually add a member to
  # the map.
  self.assertEqual(1, len(msg.map_string_string))

  # Bad value.
  with self.assertRaises(TypeError):
    msg.map_string_string['123'] = 123

  serialized = msg.SerializeToString()
  msg2 = map_unittest_pb2.TestMap()
  msg2.ParseFromString(serialized)

  # Bad key.
  with self.assertRaises(TypeError):
    msg2.map_string_string[123] = '123'

  # Bad value.
  with self.assertRaises(TypeError):
    msg2.map_string_string['123'] = 123

  # Every value is read back from the parsed copy, so these checks verify
  # the serialization round trip rather than the in-memory original.
  self.assertEqual(-456, msg2.map_int32_int32[-123])
  self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
  self.assertEqual(456, msg2.map_uint32_uint32[123])
  self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
  # map_int32_float presumably holds 32-bit floats, so 1.2 only round-trips
  # approximately.
  self.assertAlmostEqual(1.2, msg2.map_int32_float[2])
  self.assertEqual(3.3, msg2.map_int32_double[1])
  self.assertEqual('123', msg2.map_string_string['abc'])
  self.assertEqual(True, msg2.map_bool_bool[True])
  self.assertEqual(2, msg2.map_int32_enum[888])
  self.assertEqual(456, msg2.map_int32_enum[123])
  # TODO(jieluo): Add cpp extension support.
  if api_implementation.Type() == 'python':
    self.assertEqual('{-123: -456}',
                     str(msg2.map_int32_int32))
1546
def testMapEntryAlwaysSerialized(self):
  """Entries whose key and value are both defaults are still written out."""
  msg = map_unittest_pb2.TestMap()
  msg.map_int32_int32[0] = 0
  msg.map_string_string[''] = ''

  # Each entry carries explicit key (field 1) and value (field 2) records
  # even though both hold default values: 6 bytes per entry, 12 in total.
  expected_wire = (b'\n\x04\x08\x00\x10\x00'
                   b'r\x04\n\x00\x12\x00')
  self.assertEqual(12, msg.ByteSize())
  self.assertEqual(expected_wire, msg.SerializeToString())
Brian Silverman9c614bc2016-02-15 20:20:02 -05001554
def testStringUnicodeConversionInMap(self):
  """UTF-8 bytes stored in a string map come back as text (key and value)."""
  msg = map_unittest_pb2.TestMap()

  text = u'\u1234'
  encoded = text.encode('utf8')
  msg.map_string_string[encoded] = encoded

  [(stored_key, stored_value)] = list(msg.map_string_string.items())

  for stored in (stored_key, stored_value):
    self.assertEqual(stored, text)
  for stored in (stored_key, stored_value):
    self.assertIsInstance(stored, six.text_type)
1570
def testMessageMap(self):
  """Exercises basic operations on a map whose values are submessages."""
  msg = map_unittest_pb2.TestMap()

  self.assertEqual(0, len(msg.map_int32_foreign_message))
  self.assertFalse(5 in msg.map_int32_foreign_message)

  # Reading a missing key creates an empty submessage for it.
  msg.map_int32_foreign_message[123]
  # get_or_create() is an alias for getitem.
  msg.map_int32_foreign_message.get_or_create(-456)

  self.assertEqual(2, len(msg.map_int32_foreign_message))
  self.assertIn(123, msg.map_int32_foreign_message)
  self.assertIn(-456, msg.map_int32_foreign_message)

  # Bad key.
  with self.assertRaises(TypeError):
    msg.map_int32_foreign_message['123']

  # Can't assign directly to submessage.
  with self.assertRaises(ValueError):
    msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]

  # Verify that neither the bad key nor the rejected submessage assignment
  # added a member to the map.
  self.assertNotIn(999, msg.map_int32_foreign_message)
  self.assertEqual(2, len(msg.map_int32_foreign_message))

  serialized = msg.SerializeToString()
  msg2 = map_unittest_pb2.TestMap()
  msg2.ParseFromString(serialized)

  self.assertEqual(2, len(msg2.map_int32_foreign_message))
  self.assertIn(123, msg2.map_int32_foreign_message)
  self.assertIn(-456, msg2.map_int32_foreign_message)
  # TODO(jieluo): Fix text format for message map.
  # TODO(jieluo): Add cpp extension support.
  if api_implementation.Type() == 'python':
    self.assertEqual(15,
                     len(str(msg2.map_int32_foreign_message)))
1611
def testNestedMessageMapItemDelete(self):
  """Deleting message-valued map entries leaves the map consistent."""
  msg = map_unittest_pb2.TestMap()
  msg.map_int32_all_types[1].optional_nested_message.bb = 1
  del msg.map_int32_all_types[1]
  msg.map_int32_all_types[2].optional_nested_message.bb = 2
  self.assertEqual(1, len(msg.map_int32_all_types))
  msg.map_int32_all_types[1].optional_nested_message.bb = 1
  self.assertEqual(2, len(msg.map_int32_all_types))

  serialized = msg.SerializeToString()
  msg2 = map_unittest_pb2.TestMap()
  msg2.ParseFromString(serialized)
  keys = [1, 2]
  # The loop triggers PyErr_Occurred() in c extension.
  for key in keys:
    del msg2.map_int32_all_types[key]
  # Both keys were present after parsing, so the map must now be empty.
  self.assertEqual(0, len(msg2.map_int32_all_types))
1628
def testMapByteSize(self):
  """ByteSize() tracks in-place changes to map values.

  Growing a varint-encoded value from 1 to 128 needs one extra byte, so the
  reported size must grow by exactly one in both cases.
  """
  msg = map_unittest_pb2.TestMap()

  msg.map_int32_int32[1] = 1
  before = msg.ByteSize()
  msg.map_int32_int32[1] = 128
  self.assertEqual(before + 1, msg.ByteSize())

  entry = msg.map_int32_foreign_message[19]
  entry.c = 1
  before = msg.ByteSize()
  entry.c = 128
  self.assertEqual(before + 1, msg.ByteSize())
Brian Silverman9c614bc2016-02-15 20:20:02 -05001640
def testMergeFrom(self):
  """Message-level MergeFrom() overwrites map entries key by key."""
  src = map_unittest_pb2.TestMap()
  src.map_int32_int32[12] = 34
  src.map_int32_int32[56] = 78
  src.map_int64_int64[22] = 33
  src.map_int32_foreign_message[111].c = 5
  src.map_int32_foreign_message[222].c = 10

  dst = map_unittest_pb2.TestMap()
  dst.map_int32_int32[12] = 55
  dst.map_int64_int64[88] = 99
  dst.map_int32_foreign_message[222].c = 15
  dst.map_int32_foreign_message[222].d = 20
  # Grab the entry the merge is about to replace; must happen before
  # MergeFrom().
  replaced = dst.map_int32_foreign_message[222]

  dst.MergeFrom(src)

  self.assertEqual(34, dst.map_int32_int32[12])
  self.assertEqual(78, dst.map_int32_int32[56])
  self.assertEqual(33, dst.map_int64_int64[22])
  self.assertEqual(99, dst.map_int64_int64[88])
  self.assertEqual(5, dst.map_int32_foreign_message[111].c)
  self.assertEqual(10, dst.map_int32_foreign_message[222].c)
  self.assertFalse(dst.map_int32_foreign_message[222].HasField('d'))
  if api_implementation.Type() != 'cpp':
    # During the call to MergeFrom(), the C++ implementation will have
    # deallocated the underlying message, but this is very difficult to detect
    # properly. Touching `replaced` there is likely to cause a segmentation
    # fault. With the Python implementation, `replaced` is just 'detached'
    # from the main message. Using it will not crash, but since it still has
    # a reference to the parent message there may be ways to cause
    # inconsistencies.
    self.assertEqual(15, replaced.c)

  # Verify that there is only one entry per key, even though the MergeFrom
  # may have internally created multiple entries for a single key in the
  # list representation.
  merged_keys = list(dst.map_int32_foreign_message)
  self.assertEqual(len(merged_keys), len(set(merged_keys)))
  self.assertEqual(
      {111: 5, 222: 10},
      {k: dst.map_int32_foreign_message[k].c for k in merged_keys})

  # Special case: test that delete of item really removes the item, even if
  # there might have physically been duplicate keys due to the previous merge.
  # This is only a special case for the C++ implementation which stores the
  # map as an array.
  del dst.map_int32_int32[12]
  self.assertNotIn(12, dst.map_int32_int32)

  del dst.map_int32_foreign_message[222]
  self.assertNotIn(222, dst.map_int32_foreign_message)
  with self.assertRaises(TypeError):
    del dst.map_int32_foreign_message['']
1696
def testMapMergeFrom(self):
  """Container-level MergeFrom() merges one map field at a time."""
  src = map_unittest_pb2.TestMap()
  src.map_int32_int32[12] = 34
  src.map_int32_int32[56] = 78
  src.map_int64_int64[22] = 33
  src.map_int32_foreign_message[111].c = 5
  src.map_int32_foreign_message[222].c = 10

  dst = map_unittest_pb2.TestMap()
  dst.map_int32_int32[12] = 55
  dst.map_int64_int64[88] = 99
  dst.map_int32_foreign_message[222].c = 15
  dst.map_int32_foreign_message[222].d = 20

  # Scalars: the source value wins for key 12 and key 56 is added.
  dst.map_int32_int32.MergeFrom(src.map_int32_int32)
  self.assertEqual(34, dst.map_int32_int32[12])
  self.assertEqual(78, dst.map_int32_int32[56])

  # Disjoint keys from both sides survive.
  dst.map_int64_int64.MergeFrom(src.map_int64_int64)
  self.assertEqual(33, dst.map_int64_int64[22])
  self.assertEqual(99, dst.map_int64_int64[88])

  # Messages: an existing entry is replaced, so its 'd' field is gone.
  dst.map_int32_foreign_message.MergeFrom(src.map_int32_foreign_message)
  self.assertEqual(5, dst.map_int32_foreign_message[111].c)
  self.assertEqual(10, dst.map_int32_foreign_message[222].c)
  self.assertFalse(dst.map_int32_foreign_message[222].HasField('d'))
1723
def testMergeFromBadType(self):
  """MergeFrom() rejects an argument that is not the same message class."""
  msg = map_unittest_pb2.TestMap()
  # assertRaisesRegexp is a deprecated alias that Python 3.12 removed; use
  # assertRaisesRegex and fall back only where it is missing (Python 2).
  assert_raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                         self.assertRaisesRegexp)
  with assert_raises_regex(
      TypeError,
      r'Parameter to MergeFrom\(\) must be instance of same class: expected '
      r'.*TestMap got int\.'):
    msg.MergeFrom(1)
1731
def testCopyFromBadType(self):
  """CopyFrom() rejects an argument that is not the same message class."""
  msg = map_unittest_pb2.TestMap()
  # assertRaisesRegexp is a deprecated alias that Python 3.12 removed; use
  # assertRaisesRegex and fall back only where it is missing (Python 2).
  assert_raises_regex = (getattr(self, 'assertRaisesRegex', None) or
                         self.assertRaisesRegexp)
  # The message may name either CopyFrom or the method it delegates to.
  with assert_raises_regex(
      TypeError,
      r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
      r'expected .*TestMap got int\.'):
    msg.CopyFrom(1)
Brian Silverman9c614bc2016-02-15 20:20:02 -05001739
def testIntegerMapWithLongs(self):
  """Integer maps accept `long` keys and values and round-trip them."""
  entries = (
      ('map_int32_int32', -123, -456),
      ('map_int64_int64', -2**33, -2**34),
      ('map_uint32_uint32', 123, 456),
      ('map_uint64_uint64', 2**33, 2**34),
  )

  source = map_unittest_pb2.TestMap()
  for field_name, key, value in entries:
    getattr(source, field_name)[long(key)] = long(value)

  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(source.SerializeToString())

  for field_name, key, value in entries:
    self.assertEqual(value, getattr(parsed, field_name)[key])
1755
def testMapAssignmentCausesPresence(self):
  """Writing into a nested scalar map makes the parent field present.

  Every mutation is followed by a serialize/parse round trip, which also
  catches a stale cached size on the submessage.
  """
  msg = map_unittest_pb2.TestMapSubmessage()
  msg2 = map_unittest_pb2.TestMapSubmessage()

  def assert_round_trips():
    msg2.ParseFromString(msg.SerializeToString())
    self.assertEqual(msg, msg2)

  msg.test_map.map_int32_int32[123] = 456
  assert_round_trips()

  msg.test_map.map_int32_int32[888] = 999
  assert_round_trips()

  msg.test_map.map_int32_int32.clear()
  assert_round_trips()
1777
def testMapAssignmentCausesPresenceForSubmessages(self):
  """Writing into a nested message map makes the parent field present.

  Every mutation is followed by a serialize/parse round trip, which also
  catches a stale cached size on the submessage.
  """
  msg = map_unittest_pb2.TestMapSubmessage()
  msg2 = map_unittest_pb2.TestMapSubmessage()

  def assert_round_trips():
    msg2.ParseFromString(msg.SerializeToString())
    self.assertEqual(msg, msg2)

  msg.test_map.map_int32_foreign_message[123].c = 5
  assert_round_trips()

  msg.test_map.map_int32_foreign_message[888].c = 7
  assert_round_trips()

  msg.test_map.map_int32_foreign_message[888].MergeFrom(
      msg.test_map.map_int32_foreign_message[123])
  assert_round_trips()

  msg.test_map.map_int32_foreign_message.clear()
  assert_round_trips()
1805
def testModifyMapWhileIterating(self):
  """Iterators created before a map is modified raise RuntimeError."""
  msg = map_unittest_pb2.TestMap()

  stale_iterators = (iter(msg.map_string_string),
                     iter(msg.map_int32_foreign_message))

  msg.map_string_string['abc'] = '123'
  msg.map_int32_foreign_message[5].c = 5

  for stale in stale_iterators:
    with self.assertRaises(RuntimeError):
      for _ in stale:
        pass
1822
def testSubmessageMap(self):
  """Message map entries are created on access and can't be assigned."""
  msg = map_unittest_pb2.TestMap()

  entry = msg.map_int32_foreign_message[111]
  # Repeated lookups hand back the very same submessage object.
  self.assertIs(entry, msg.map_int32_foreign_message[111])
  self.assertIsInstance(entry, unittest_pb2.ForeignMessage)
  entry.c = 5

  parsed = map_unittest_pb2.TestMap()
  parsed.ParseFromString(msg.SerializeToString())
  self.assertEqual(5, parsed.map_int32_foreign_message[111].c)

  # Doesn't allow direct submessage assignment.
  with self.assertRaises(ValueError):
    msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
1841
def testMapIteration(self):
  """items() yields nothing for an empty map and every pair otherwise."""
  msg = map_unittest_pb2.TestMap()

  # An empty map must not produce any items.
  self.assertEqual([], list(msg.map_int32_int32.items()))

  msg.map_int32_int32[2] = 4
  msg.map_int32_int32[3] = 6
  msg.map_int32_int32[4] = 8
  self.assertEqual(3, len(msg.map_int32_int32))

  matching_dict = {2: 4, 3: 6, 4: 8}
  self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
1856
def testPython2Map(self):
  """Exercises the Python 2 dict API (iter*(), pop, popitem, update...)."""
  if sys.version_info >= (3,):
    # Report a skip instead of silently passing on Python 3.
    self.skipTest('iteritems()/iterkeys()/itervalues() exist only on Python 2')

  msg = map_unittest_pb2.TestMap()
  msg.map_int32_int32[2] = 4
  msg.map_int32_int32[3] = 6
  msg.map_int32_int32[4] = 8
  msg.map_int32_int32[5] = 10
  map_int32 = msg.map_int32_int32
  self.assertEqual(4, len(map_int32))
  msg2 = map_unittest_pb2.TestMap()
  msg2.ParseFromString(msg.SerializeToString())

  def CheckItems(seq, iterator):
    self.assertEqual(next(iterator), seq[0])
    self.assertEqual(list(iterator), seq[1:])

  CheckItems(map_int32.items(), map_int32.iteritems())
  CheckItems(map_int32.keys(), map_int32.iterkeys())
  CheckItems(map_int32.values(), map_int32.itervalues())

  self.assertEqual(6, map_int32.get(3))
  self.assertIsNone(map_int32.get(999))
  self.assertEqual(6, map_int32.pop(3))
  # Popping a key that is no longer present returns 0 rather than raising,
  # and does not re-insert the key (the length check below).
  self.assertEqual(0, map_int32.pop(3))
  self.assertEqual(3, len(map_int32))
  key, value = map_int32.popitem()
  self.assertEqual(2 * key, value)
  self.assertEqual(2, len(map_int32))
  map_int32.clear()
  self.assertEqual(0, len(map_int32))

  with self.assertRaises(KeyError):
    map_int32.popitem()

  self.assertEqual(0, map_int32.setdefault(2))
  self.assertEqual(1, len(map_int32))

  map_int32.update(msg2.map_int32_int32)
  self.assertEqual(4, len(map_int32))

  with self.assertRaises(TypeError):
    map_int32.update(msg2.map_int32_int32,
                     msg2.map_int32_int32)
  with self.assertRaises(TypeError):
    map_int32.update(0)
  with self.assertRaises(TypeError):
    map_int32.update(value=12)
1904
def testMapItems(self):
  """Two consecutive items() calls on a string map agree.

  Map items used to behave strangely with the C extension, because [] may
  reorder the map and invalidate any existing iterators.
  TODO(jieluo): Check if [] reordering the map is a bug or intended behavior.
  """
  msg = map_unittest_pb2.TestMap()
  for name in ('local_init_op', 'trainable_variables', 'variables',
               'init_op', 'summaries'):
    msg.map_string_string[name] = ''
  self.assertEqual(msg.map_string_string.items(),
                   msg.map_string_string.items())
1919
def testMapDeterministicSerialization(self):
  """deterministic=True emits string-map entries in sorted key order."""
  # Insertion order deliberately differs from the key-sorted golden order.
  insertions = (
      ('local_init_op', 'a'),
      ('trainable_variables', 'b'),
      ('variables', 'c'),
      ('init_op', 'd'),
      ('summaries', 'e'),
      ('item1', 'e'),
      ('item2', 'f'),
      ('item3', 'g'),
      ('item4', 'QQ'),
  )
  msg = map_unittest_pb2.TestMap()
  for key, value in insertions:
    msg.map_string_string[key] = value

  expected = (b'r\x0c\n\x07init_op\x12\x01d'
              b'r\n\n\x05item1\x12\x01e'
              b'r\n\n\x05item2\x12\x01f'
              b'r\n\n\x05item3\x12\x01g'
              b'r\x0b\n\x05item4\x12\x02QQ'
              b'r\x12\n\rlocal_init_op\x12\x01a'
              b'r\x0e\n\tsummaries\x12\x01e'
              b'r\x18\n\x13trainable_variables\x12\x01b'
              b'r\x0e\n\tvariables\x12\x01c')

  # If deterministic serialization is broken, this would be "flaky" depending
  # on the Python dict hash seed; with this many entries a correct order by
  # accident is extremely unlikely, so the failure should be reliable.
  self.assertEqual(expected, msg.SerializeToString(deterministic=True))
1948
def testMapIterationClearMessage(self):
  """An items() result stays usable after its owning message is deleted."""
  msg = map_unittest_pb2.TestMap()
  for key in (2, 3, 4):
    msg.map_int32_int32[key] = 2 * key

  pairs = msg.map_int32_int32.items()
  del msg

  self.assertMapIterEquals(pairs, {2: 4, 3: 6, 4: 8})
1962
def testMapConstruction(self):
  """Map fields can be initialized from plain dicts in the constructor."""
  scalar_msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
  for key, value in ((1, 2), (3, 4)):
    self.assertEqual(value, scalar_msg.map_int32_int32[key])

  message_msg = map_unittest_pb2.TestMap(
      map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
  self.assertEqual(5, message_msg.map_int32_foreign_message[3].c)
1971
def testMapValidAfterFieldCleared(self):
  """A scalar map obtained before ClearField() keeps working afterwards.

  For the C++ implementation this tests the correctness of
  ScalarMapContainer::Release().
  """
  msg = map_unittest_pb2.TestMap()
  detached = msg.map_int32_int32
  for key in (2, 3, 4):
    detached[key] = 2 * key

  msg.ClearField('map_int32_int32')

  # The parent is empty, yet the container still holds its entries.
  self.assertEqual(b'', msg.SerializeToString())
  self.assertMapIterEquals(detached.items(), {2: 4, 3: 6, 4: 8})
1987
def testMessageMapValidAfterFieldCleared(self):
  """A message map obtained before ClearField() keeps working afterwards.

  For the C++ implementation this exercises releasing the map container
  from its parent (NOTE(review): the previous comment cited
  ScalarMapContainer::Release(); confirm which C++ container applies to
  message maps).
  """
  msg = map_unittest_pb2.TestMap()
  detached = msg.map_int32_foreign_message
  detached[2].c = 5

  msg.ClearField('map_int32_foreign_message')

  # The parent is empty, yet the container still holds key 2.
  self.assertEqual(b'', msg.SerializeToString())
  self.assertIn(2, detached.keys())
2000
def testMapIterInvalidatedByClearField(self):
  """Clearing a map field invalidates live iterators without crashing.

  For the C++ implementation this tests the correctness of
  ScalarMapContainer::Release().
  """
  msg = map_unittest_pb2.TestMap()

  for field_name in ('map_int32_int32', 'map_int32_foreign_message'):
    live_iter = iter(getattr(msg, field_name))
    msg.ClearField(field_name)
    with self.assertRaises(RuntimeError):
      for _ in live_iter:
        pass
2020
def testMapDelete(self):
  """del removes present keys and raises KeyError for absent ones."""
  msg = map_unittest_pb2.TestMap()
  scalars = msg.map_int32_int32

  self.assertEqual(0, len(scalars))

  scalars[4] = 6
  self.assertEqual(1, len(scalars))

  with self.assertRaises(KeyError):
    del scalars[88]

  del scalars[4]
  self.assertEqual(0, len(scalars))

  # Same rule for a message-valued map.
  with self.assertRaises(KeyError):
    del msg.map_int32_all_types[32]
2037
def testMapsAreMapping(self):
  """Map containers register as Mapping and MutableMapping ABCs."""
  # The ABC aliases collections.Mapping / collections.MutableMapping were
  # removed in Python 3.10; they live in collections.abc (3.3+). Python 2
  # only has the collections versions.
  try:
    import collections.abc as collections_abc
  except ImportError:
    collections_abc = collections

  msg = map_unittest_pb2.TestMap()
  for container in (msg.map_int32_int32, msg.map_int32_foreign_message):
    self.assertIsInstance(container, collections_abc.Mapping)
    self.assertIsInstance(container, collections_abc.MutableMapping)
2045
def testMapsCompare(self):
  """Map containers equal themselves and are unequal to non-map values."""
  msg = map_unittest_pb2.TestMap()
  msg.map_int32_int32[-123] = -456

  for container in (msg.map_int32_int32, msg.map_int32_foreign_message):
    self.assertEqual(container, container)
  self.assertNotEqual(msg.map_int32_int32, 0)
2053
def testMapFindInitializationErrorsSmokeTest(self):
  """A TestMap with populated maps reports no initialization errors."""
  msg = map_unittest_pb2.TestMap()
  msg.map_string_string['abc'] = '123'
  msg.map_int32_int32[35] = 64
  msg.map_string_foreign_message['foo'].c = 5
  errors = msg.FindInitializationErrors()
  self.assertFalse(errors)
2060
Brian Silverman9c614bc2016-02-15 20:20:02 -05002061
2062
class ValidTypeNamesTest(BaseTestCase):
  """Checks that repeated-field container types have importable names."""

  def assertImportFromName(self, msg, base_name):
    """Asserts that type(msg) is named 'module.Repeated<base_name>...'.

    Args:
      msg: a repeated-field container instance.
      base_name: 'Scalar' or 'Composite'; the class name must end with
        'Repeated<base_name>Container' or 'Repeated<base_name>FieldContainer'.
    """
    # Parse <type 'module.class_name'> (or <class '...'>) to extract
    # 'module.class_name' as a string.
    tp_name = str(type(msg)).split("'")[1]
    valid_names = ('Repeated%sContainer' % base_name,
                   'Repeated%sFieldContainer' % base_name)
    self.assertTrue(any(tp_name.endswith(v) for v in valid_names),
                    '%r does not end with any of %r' % (tp_name, valid_names))

    parts = tp_name.split('.')
    class_name = parts[-1]
    module_name = '.'.join(parts[:-1])
    __import__(module_name, fromlist=[class_name])

  def testTypeNamesCanBeImported(self):
    # If import doesn't work, pickling won't work either.
    pb = unittest_pb2.TestAllTypes()
    self.assertImportFromName(pb.repeated_int32, 'Scalar')
    self.assertImportFromName(pb.repeated_nested_message, 'Composite')
2083
class PackedFieldTest(BaseTestCase):
  """Compares packed and unpacked repeated-field encodings to golden bytes."""

  def setMessage(self, message):
    """Appends exactly one element to each repeated field of `message`."""
    integer_kinds = ('int32', 'int64', 'uint32', 'uint64', 'sint32', 'sint64',
                     'fixed32', 'fixed64', 'sfixed32', 'sfixed64')
    for kind in integer_kinds:
      getattr(message, 'repeated_' + kind).append(1)
    message.repeated_float.append(1.0)
    message.repeated_double.append(1.0)
    message.repeated_bool.append(True)
    message.repeated_nested_enum.append(1)

  def testPackedFields(self):
    packed = packed_field_test_pb2.TestPackedTypes()
    self.setMessage(packed)
    # One length-delimited record per field: tag, byte count, payload.
    # sint fields encode 1 as 2 (zigzag).
    expected = (b'\x0A\x01\x01'
                b'\x12\x01\x01'
                b'\x1A\x01\x01'
                b'\x22\x01\x01'
                b'\x2A\x01\x02'
                b'\x32\x01\x02'
                b'\x3A\x04\x01\x00\x00\x00'
                b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x4A\x04\x01\x00\x00\x00'
                b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x5A\x04\x00\x00\x80\x3f'
                b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
                b'\x6A\x01\x01'
                b'\x72\x01\x01')
    self.assertEqual(expected, packed.SerializeToString())

  def testUnpackedFields(self):
    unpacked = packed_field_test_pb2.TestUnpackedTypes()
    self.setMessage(unpacked)
    # One tag/value record per element, using each type's native wire type.
    expected = (b'\x08\x01'
                b'\x10\x01'
                b'\x18\x01'
                b'\x20\x01'
                b'\x28\x02'
                b'\x30\x02'
                b'\x3D\x01\x00\x00\x00'
                b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x4D\x01\x00\x00\x00'
                b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
                b'\x5D\x00\x00\x80\x3f'
                b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
                b'\x68\x01'
                b'\x70\x01')
    self.assertEqual(expected, unpacked.SerializeToString())
2139
Austin Schuh40c16522018-10-28 20:27:54 -07002140
@unittest.skipIf(api_implementation.Type() != 'cpp' or
                 sys.version_info < (2, 7),
                 'explicit tests of the C++ implementation for PY27 and above')
class OversizeProtosTest(BaseTestCase):
  """Parses a message slightly over 64 MiB with the C++ implementation."""

  @classmethod
  def setUpClass(cls):
    # At the moment, reference cycles between DescriptorPool and Message classes
    # are not detected and these objects are never freed.
    # To avoid errors with ReferenceLeakChecker, we create the class only once.
    file_desc = """
      name: "f/f.msg2"
      package: "f"
      message_type {
        name: "msg1"
        field {
          name: "payload"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_STRING
        }
      }
      message_type {
        name: "msg2"
        field {
          name: "field"
          number: 1
          label: LABEL_OPTIONAL
          type: TYPE_MESSAGE
          type_name: "msg1"
        }
      }
    """
    pool = descriptor_pool.DescriptorPool()
    desc = descriptor_pb2.FileDescriptorProto()
    text_format.Parse(file_desc, desc)
    pool.Add(desc)
    cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
        pool.FindMessageTypeByName('f.msg2'))

  def setUp(self):
    self.p = self.proto_cls()
    # One byte more than 64 MiB of payload.
    self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
    self.p_serialized = self.p.SerializeToString()

  def testAssertOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(False)
    q = self.proto_cls()
    # assertRaises makes the test fail if parsing unexpectedly succeeds; a
    # plain try/except would let that case pass silently.
    with self.assertRaises(message.DecodeError) as context:
      q.ParseFromString(self.p_serialized)
    self.assertEqual(str(context.exception), 'Error parsing message')

  def testSucceedOversizeProto(self):
    from google.protobuf.pyext._message import SetAllowOversizeProtos
    SetAllowOversizeProtos(True)
    q = self.proto_cls()
    q.ParseFromString(self.p_serialized)
    self.assertEqual(self.p.field.payload, q.field.payload)
2201
if __name__ == '__main__':
  # Discover and run every TestCase defined in this module.
  unittest.main()