blob: df9138a4d8349fdd83cc5255e88a8823ee315a3b [file] [log] [blame]
Brian Silverman9c614bc2016-02-15 20:20:02 -05001// Protocol Buffers - Google's data interchange format
2// Copyright 2008 Google Inc. All rights reserved.
3// https://developers.google.com/protocol-buffers/
4//
5// Redistribution and use in source and binary forms, with or without
6// modification, are permitted provided that the following conditions are
7// met:
8//
9// * Redistributions of source code must retain the above copyright
10// notice, this list of conditions and the following disclaimer.
11// * Redistributions in binary form must reproduce the above
12// copyright notice, this list of conditions and the following disclaimer
13// in the documentation and/or other materials provided with the
14// distribution.
15// * Neither the name of Google Inc. nor the names of its
16// contributors may be used to endorse or promote products derived from
17// this software without specific prior written permission.
18//
19// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
23// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
24// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
25// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
26// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
27// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
31// Author: haberman@google.com (Josh Haberman)
32
33#include <google/protobuf/pyext/map_container.h>
34
35#include <google/protobuf/stubs/logging.h>
36#include <google/protobuf/stubs/common.h>
37#include <google/protobuf/stubs/scoped_ptr.h>
38#include <google/protobuf/map_field.h>
39#include <google/protobuf/map.h>
40#include <google/protobuf/message.h>
41#include <google/protobuf/pyext/message.h>
42#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
43
44#if PY_MAJOR_VERSION >= 3
45 #define PyInt_FromLong PyLong_FromLong
46 #define PyInt_FromSize_t PyLong_FromSize_t
47#endif
48
49namespace google {
50namespace protobuf {
51namespace python {
52
53// Functions that need access to map reflection functionality.
54// They need to be contained in this class because it is friended.
55class MapReflectionFriend {
56 public:
57 // Methods that are in common between the map types.
58 static PyObject* Contains(PyObject* _self, PyObject* key);
59 static Py_ssize_t Length(PyObject* _self);
60 static PyObject* GetIterator(PyObject *_self);
61 static PyObject* IterNext(PyObject* _self);
62
63 // Methods that differ between the map types.
64 static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key);
65 static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
66 static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
67 static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
68};
69
70struct MapIterator {
71 PyObject_HEAD;
72
73 scoped_ptr< ::google::protobuf::MapIterator> iter;
74
75 // A pointer back to the container, so we can notice changes to the version.
76 // We own a ref on this.
77 MapContainer* container;
78
79 // We need to keep a ref on the Message* too, because
80 // MapIterator::~MapIterator() accesses it. Normally this would be ok because
81 // the ref on container (above) would guarantee outlive semantics. However in
82 // the case of ClearField(), InitializeAndCopyToParentContainer() resets the
83 // message pointer (and the owner) to a different message, a copy of the
84 // original. But our iterator still points to the original, which could now
85 // get deleted before us.
86 //
87 // To prevent this, we ensure that the Message will always stay alive as long
88 // as this iterator does. This is solely for the benefit of the MapIterator
89 // destructor -- we should never actually access the iterator in this state
90 // except to delete it.
91 shared_ptr<Message> owner;
92
93 // The version of the map when we took the iterator to it.
94 //
95 // We store this so that if the map is modified during iteration we can throw
96 // an error.
97 uint64 version;
98
99 // True if the container is empty. We signal this separately to avoid calling
100 // any of the iteration methods, which are non-const.
101 bool empty;
102};
103
104Message* MapContainer::GetMutableMessage() {
105 cmessage::AssureWritable(parent);
106 return const_cast<Message*>(message);
107}
108
109// Consumes a reference on the Python string object.
110static bool PyStringToSTL(PyObject* py_string, string* stl_string) {
111 char *value;
112 Py_ssize_t value_len;
113
114 if (!py_string) {
115 return false;
116 }
117 if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
118 Py_DECREF(py_string);
119 return false;
120 } else {
121 stl_string->assign(value, value_len);
122 Py_DECREF(py_string);
123 return true;
124 }
125}
126
127static bool PythonToMapKey(PyObject* obj,
128 const FieldDescriptor* field_descriptor,
129 MapKey* key) {
130 switch (field_descriptor->cpp_type()) {
131 case FieldDescriptor::CPPTYPE_INT32: {
132 GOOGLE_CHECK_GET_INT32(obj, value, false);
133 key->SetInt32Value(value);
134 break;
135 }
136 case FieldDescriptor::CPPTYPE_INT64: {
137 GOOGLE_CHECK_GET_INT64(obj, value, false);
138 key->SetInt64Value(value);
139 break;
140 }
141 case FieldDescriptor::CPPTYPE_UINT32: {
142 GOOGLE_CHECK_GET_UINT32(obj, value, false);
143 key->SetUInt32Value(value);
144 break;
145 }
146 case FieldDescriptor::CPPTYPE_UINT64: {
147 GOOGLE_CHECK_GET_UINT64(obj, value, false);
148 key->SetUInt64Value(value);
149 break;
150 }
151 case FieldDescriptor::CPPTYPE_BOOL: {
152 GOOGLE_CHECK_GET_BOOL(obj, value, false);
153 key->SetBoolValue(value);
154 break;
155 }
156 case FieldDescriptor::CPPTYPE_STRING: {
157 string str;
158 if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
159 return false;
160 }
161 key->SetStringValue(str);
162 break;
163 }
164 default:
165 PyErr_Format(
166 PyExc_SystemError, "Type %d cannot be a map key",
167 field_descriptor->cpp_type());
168 return false;
169 }
170 return true;
171}
172
173static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
174 const MapKey& key) {
175 switch (field_descriptor->cpp_type()) {
176 case FieldDescriptor::CPPTYPE_INT32:
177 return PyInt_FromLong(key.GetInt32Value());
178 case FieldDescriptor::CPPTYPE_INT64:
179 return PyLong_FromLongLong(key.GetInt64Value());
180 case FieldDescriptor::CPPTYPE_UINT32:
181 return PyInt_FromSize_t(key.GetUInt32Value());
182 case FieldDescriptor::CPPTYPE_UINT64:
183 return PyLong_FromUnsignedLongLong(key.GetUInt64Value());
184 case FieldDescriptor::CPPTYPE_BOOL:
185 return PyBool_FromLong(key.GetBoolValue());
186 case FieldDescriptor::CPPTYPE_STRING:
187 return ToStringObject(field_descriptor, key.GetStringValue());
188 default:
189 PyErr_Format(
190 PyExc_SystemError, "Couldn't convert type %d to value",
191 field_descriptor->cpp_type());
192 return NULL;
193 }
194}
195
196// This is only used for ScalarMap, so we don't need to handle the
197// CPPTYPE_MESSAGE case.
198PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
199 MapValueRef* value) {
200 switch (field_descriptor->cpp_type()) {
201 case FieldDescriptor::CPPTYPE_INT32:
202 return PyInt_FromLong(value->GetInt32Value());
203 case FieldDescriptor::CPPTYPE_INT64:
204 return PyLong_FromLongLong(value->GetInt64Value());
205 case FieldDescriptor::CPPTYPE_UINT32:
206 return PyInt_FromSize_t(value->GetUInt32Value());
207 case FieldDescriptor::CPPTYPE_UINT64:
208 return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
209 case FieldDescriptor::CPPTYPE_FLOAT:
210 return PyFloat_FromDouble(value->GetFloatValue());
211 case FieldDescriptor::CPPTYPE_DOUBLE:
212 return PyFloat_FromDouble(value->GetDoubleValue());
213 case FieldDescriptor::CPPTYPE_BOOL:
214 return PyBool_FromLong(value->GetBoolValue());
215 case FieldDescriptor::CPPTYPE_STRING:
216 return ToStringObject(field_descriptor, value->GetStringValue());
217 case FieldDescriptor::CPPTYPE_ENUM:
218 return PyInt_FromLong(value->GetEnumValue());
219 default:
220 PyErr_Format(
221 PyExc_SystemError, "Couldn't convert type %d to value",
222 field_descriptor->cpp_type());
223 return NULL;
224 }
225}
226
227// This is only used for ScalarMap, so we don't need to handle the
228// CPPTYPE_MESSAGE case.
229static bool PythonToMapValueRef(PyObject* obj,
230 const FieldDescriptor* field_descriptor,
231 bool allow_unknown_enum_values,
232 MapValueRef* value_ref) {
233 switch (field_descriptor->cpp_type()) {
234 case FieldDescriptor::CPPTYPE_INT32: {
235 GOOGLE_CHECK_GET_INT32(obj, value, false);
236 value_ref->SetInt32Value(value);
237 return true;
238 }
239 case FieldDescriptor::CPPTYPE_INT64: {
240 GOOGLE_CHECK_GET_INT64(obj, value, false);
241 value_ref->SetInt64Value(value);
242 return true;
243 }
244 case FieldDescriptor::CPPTYPE_UINT32: {
245 GOOGLE_CHECK_GET_UINT32(obj, value, false);
246 value_ref->SetUInt32Value(value);
247 return true;
248 }
249 case FieldDescriptor::CPPTYPE_UINT64: {
250 GOOGLE_CHECK_GET_UINT64(obj, value, false);
251 value_ref->SetUInt64Value(value);
252 return true;
253 }
254 case FieldDescriptor::CPPTYPE_FLOAT: {
255 GOOGLE_CHECK_GET_FLOAT(obj, value, false);
256 value_ref->SetFloatValue(value);
257 return true;
258 }
259 case FieldDescriptor::CPPTYPE_DOUBLE: {
260 GOOGLE_CHECK_GET_DOUBLE(obj, value, false);
261 value_ref->SetDoubleValue(value);
262 return true;
263 }
264 case FieldDescriptor::CPPTYPE_BOOL: {
265 GOOGLE_CHECK_GET_BOOL(obj, value, false);
266 value_ref->SetBoolValue(value);
267 return true;;
268 }
269 case FieldDescriptor::CPPTYPE_STRING: {
270 string str;
271 if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) {
272 return false;
273 }
274 value_ref->SetStringValue(str);
275 return true;
276 }
277 case FieldDescriptor::CPPTYPE_ENUM: {
278 GOOGLE_CHECK_GET_INT32(obj, value, false);
279 if (allow_unknown_enum_values) {
280 value_ref->SetEnumValue(value);
281 return true;
282 } else {
283 const EnumDescriptor* enum_descriptor = field_descriptor->enum_type();
284 const EnumValueDescriptor* enum_value =
285 enum_descriptor->FindValueByNumber(value);
286 if (enum_value != NULL) {
287 value_ref->SetEnumValue(value);
288 return true;
289 } else {
290 PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value);
291 return false;
292 }
293 }
294 break;
295 }
296 default:
297 PyErr_Format(
298 PyExc_SystemError, "Setting value to a field of unknown type %d",
299 field_descriptor->cpp_type());
300 return false;
301 }
302}
303
304// Map methods common to ScalarMap and MessageMap //////////////////////////////
305
306static MapContainer* GetMap(PyObject* obj) {
307 return reinterpret_cast<MapContainer*>(obj);
308}
309
310Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
311 MapContainer* self = GetMap(_self);
312 const google::protobuf::Message* message = self->message;
313 return message->GetReflection()->MapSize(*message,
314 self->parent_field_descriptor);
315}
316
317PyObject* Clear(PyObject* _self) {
318 MapContainer* self = GetMap(_self);
319 Message* message = self->GetMutableMessage();
320 const Reflection* reflection = message->GetReflection();
321
322 reflection->ClearField(message, self->parent_field_descriptor);
323
324 Py_RETURN_NONE;
325}
326
327PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) {
328 MapContainer* self = GetMap(_self);
329
330 const Message* message = self->message;
331 const Reflection* reflection = message->GetReflection();
332 MapKey map_key;
333
334 if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
335 return NULL;
336 }
337
338 if (reflection->ContainsMapKey(*message, self->parent_field_descriptor,
339 map_key)) {
340 Py_RETURN_TRUE;
341 } else {
342 Py_RETURN_FALSE;
343 }
344}
345
346// Initializes the underlying Message object of "to" so it becomes a new parent
347// repeated scalar, and copies all the values from "from" to it. A child scalar
348// container can be released by passing it as both from and to (e.g. making it
349// the recipient of the new parent message and copying the values from itself).
350static int InitializeAndCopyToParentContainer(MapContainer* from,
351 MapContainer* to) {
352 // For now we require from == to, re-evaluate if we want to support deep copy
353 // as in repeated_scalar_container.cc.
354 GOOGLE_DCHECK(from == to);
355 Message* new_message = from->message->New();
356
357 if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) {
358 // A somewhat roundabout way of copying just one field from old_message to
359 // new_message. This is the best we can do with what Reflection gives us.
360 Message* mutable_old = from->GetMutableMessage();
361 vector<const FieldDescriptor*> fields;
362 fields.push_back(from->parent_field_descriptor);
363
364 // Move the map field into the new message.
365 mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields);
366
367 // If/when we support from != to, this will be required also to copy the
368 // map field back into the existing message:
369 // mutable_old->MergeFrom(*new_message);
370 }
371
372 // If from == to this could delete old_message.
373 to->owner.reset(new_message);
374
375 to->parent = NULL;
376 to->parent_field_descriptor = from->parent_field_descriptor;
377 to->message = new_message;
378
379 // Invalidate iterators, since they point to the old copy of the field.
380 to->version++;
381
382 return 0;
383}
384
385int MapContainer::Release() {
386 return InitializeAndCopyToParentContainer(this, this);
387}
388
389
390// ScalarMap ///////////////////////////////////////////////////////////////////
391
392PyObject *NewScalarMapContainer(
393 CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
394 if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
395 return NULL;
396 }
397
398#if PY_MAJOR_VERSION >= 3
399 ScopedPyObjectPtr obj(PyType_GenericAlloc(
400 reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0));
401#else
402 ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0));
403#endif
404 if (obj.get() == NULL) {
405 return PyErr_Format(PyExc_RuntimeError,
406 "Could not allocate new container.");
407 }
408
409 MapContainer* self = GetMap(obj.get());
410
411 self->message = parent->message;
412 self->parent = parent;
413 self->parent_field_descriptor = parent_field_descriptor;
414 self->owner = parent->owner;
415 self->version = 0;
416
417 self->key_field_descriptor =
418 parent_field_descriptor->message_type()->FindFieldByName("key");
419 self->value_field_descriptor =
420 parent_field_descriptor->message_type()->FindFieldByName("value");
421
422 if (self->key_field_descriptor == NULL ||
423 self->value_field_descriptor == NULL) {
424 return PyErr_Format(PyExc_KeyError,
425 "Map entry descriptor did not have key/value fields");
426 }
427
428 return obj.release();
429}
430
431PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
432 PyObject* key) {
433 MapContainer* self = GetMap(_self);
434
435 Message* message = self->GetMutableMessage();
436 const Reflection* reflection = message->GetReflection();
437 MapKey map_key;
438 MapValueRef value;
439
440 if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
441 return NULL;
442 }
443
444 if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
445 map_key, &value)) {
446 self->version++;
447 }
448
449 return MapValueRefToPython(self->value_field_descriptor, &value);
450}
451
452int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
453 PyObject* v) {
454 MapContainer* self = GetMap(_self);
455
456 Message* message = self->GetMutableMessage();
457 const Reflection* reflection = message->GetReflection();
458 MapKey map_key;
459 MapValueRef value;
460
461 if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
462 return -1;
463 }
464
465 self->version++;
466
467 if (v) {
468 // Set item to v.
469 reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
470 map_key, &value);
471
472 return PythonToMapValueRef(v, self->value_field_descriptor,
473 reflection->SupportsUnknownEnumValues(), &value)
474 ? 0
475 : -1;
476 } else {
477 // Delete key from map.
478 if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
479 map_key)) {
480 return 0;
481 } else {
482 PyErr_Format(PyExc_KeyError, "Key not present in map");
483 return -1;
484 }
485 }
486}
487
488static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
489 PyObject* key;
490 PyObject* default_value = NULL;
491 if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
492 return NULL;
493 }
494
495 ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
496 if (is_present.get() == NULL) {
497 return NULL;
498 }
499
500 if (PyObject_IsTrue(is_present.get())) {
501 return MapReflectionFriend::ScalarMapGetItem(self, key);
502 } else {
503 if (default_value != NULL) {
504 Py_INCREF(default_value);
505 return default_value;
506 } else {
507 Py_RETURN_NONE;
508 }
509 }
510}
511
512static void ScalarMapDealloc(PyObject* _self) {
513 MapContainer* self = GetMap(_self);
514 self->owner.reset();
515 Py_TYPE(_self)->tp_free(_self);
516}
517
518static PyMethodDef ScalarMapMethods[] = {
519 { "__contains__", MapReflectionFriend::Contains, METH_O,
520 "Tests whether a key is a member of the map." },
521 { "clear", (PyCFunction)Clear, METH_NOARGS,
522 "Removes all elements from the map." },
523 { "get", ScalarMapGet, METH_VARARGS,
524 "Gets the value for the given key if present, or otherwise a default" },
525 /*
526 { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
527 "Makes a deep copy of the class." },
528 { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
529 "Outputs picklable representation of the repeated field." },
530 */
531 {NULL, NULL},
532};
533
534#if PY_MAJOR_VERSION >= 3
535 static PyType_Slot ScalarMapContainer_Type_slots[] = {
536 {Py_tp_dealloc, (void *)ScalarMapDealloc},
537 {Py_mp_length, (void *)MapReflectionFriend::Length},
538 {Py_mp_subscript, (void *)MapReflectionFriend::ScalarMapGetItem},
539 {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
540 {Py_tp_methods, (void *)ScalarMapMethods},
541 {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
542 {0, 0},
543 };
544
545 PyType_Spec ScalarMapContainer_Type_spec = {
546 FULL_MODULE_NAME ".ScalarMapContainer",
547 sizeof(MapContainer),
548 0,
549 Py_TPFLAGS_DEFAULT,
550 ScalarMapContainer_Type_slots
551 };
552 PyObject *ScalarMapContainer_Type;
553#else
554 static PyMappingMethods ScalarMapMappingMethods = {
555 MapReflectionFriend::Length, // mp_length
556 MapReflectionFriend::ScalarMapGetItem, // mp_subscript
557 MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript
558 };
559
560 PyTypeObject ScalarMapContainer_Type = {
561 PyVarObject_HEAD_INIT(&PyType_Type, 0)
562 FULL_MODULE_NAME ".ScalarMapContainer", // tp_name
563 sizeof(MapContainer), // tp_basicsize
564 0, // tp_itemsize
565 ScalarMapDealloc, // tp_dealloc
566 0, // tp_print
567 0, // tp_getattr
568 0, // tp_setattr
569 0, // tp_compare
570 0, // tp_repr
571 0, // tp_as_number
572 0, // tp_as_sequence
573 &ScalarMapMappingMethods, // tp_as_mapping
574 0, // tp_hash
575 0, // tp_call
576 0, // tp_str
577 0, // tp_getattro
578 0, // tp_setattro
579 0, // tp_as_buffer
580 Py_TPFLAGS_DEFAULT, // tp_flags
581 "A scalar map container", // tp_doc
582 0, // tp_traverse
583 0, // tp_clear
584 0, // tp_richcompare
585 0, // tp_weaklistoffset
586 MapReflectionFriend::GetIterator, // tp_iter
587 0, // tp_iternext
588 ScalarMapMethods, // tp_methods
589 0, // tp_members
590 0, // tp_getset
591 0, // tp_base
592 0, // tp_dict
593 0, // tp_descr_get
594 0, // tp_descr_set
595 0, // tp_dictoffset
596 0, // tp_init
597 };
598#endif
599
600
601// MessageMap //////////////////////////////////////////////////////////////////
602
603static MessageMapContainer* GetMessageMap(PyObject* obj) {
604 return reinterpret_cast<MessageMapContainer*>(obj);
605}
606
607static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
608 // Get or create the CMessage object corresponding to this message.
609 ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
610 PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
611
612 if (ret == NULL) {
613 CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init,
614 message->GetDescriptor());
615 ret = reinterpret_cast<PyObject*>(cmsg);
616
617 if (cmsg == NULL) {
618 return NULL;
619 }
620 cmsg->owner = self->owner;
621 cmsg->message = message;
622 cmsg->parent = self->parent;
623
624 if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
625 Py_DECREF(ret);
626 return NULL;
627 }
628 } else {
629 Py_INCREF(ret);
630 }
631
632 return ret;
633}
634
635PyObject* NewMessageMapContainer(
636 CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
637 PyObject* concrete_class) {
638 if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
639 return NULL;
640 }
641
642#if PY_MAJOR_VERSION >= 3
643 PyObject* obj = PyType_GenericAlloc(
644 reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0);
645#else
646 PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0);
647#endif
648 if (obj == NULL) {
649 return PyErr_Format(PyExc_RuntimeError,
650 "Could not allocate new container.");
651 }
652
653 MessageMapContainer* self = GetMessageMap(obj);
654
655 self->message = parent->message;
656 self->parent = parent;
657 self->parent_field_descriptor = parent_field_descriptor;
658 self->owner = parent->owner;
659 self->version = 0;
660
661 self->key_field_descriptor =
662 parent_field_descriptor->message_type()->FindFieldByName("key");
663 self->value_field_descriptor =
664 parent_field_descriptor->message_type()->FindFieldByName("value");
665
666 self->message_dict = PyDict_New();
667 if (self->message_dict == NULL) {
668 return PyErr_Format(PyExc_RuntimeError,
669 "Could not allocate message dict.");
670 }
671
672 Py_INCREF(concrete_class);
673 self->subclass_init = concrete_class;
674
675 if (self->key_field_descriptor == NULL ||
676 self->value_field_descriptor == NULL) {
677 Py_DECREF(obj);
678 return PyErr_Format(PyExc_KeyError,
679 "Map entry descriptor did not have key/value fields");
680 }
681
682 return obj;
683}
684
685int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key,
686 PyObject* v) {
687 if (v) {
688 PyErr_Format(PyExc_ValueError,
689 "Direct assignment of submessage not allowed");
690 return -1;
691 }
692
693 // Now we know that this is a delete, not a set.
694
695 MessageMapContainer* self = GetMessageMap(_self);
696 Message* message = self->GetMutableMessage();
697 const Reflection* reflection = message->GetReflection();
698 MapKey map_key;
699 MapValueRef value;
700
701 self->version++;
702
703 if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
704 return -1;
705 }
706
707 // Delete key from map.
708 if (reflection->DeleteMapValue(message, self->parent_field_descriptor,
709 map_key)) {
710 return 0;
711 } else {
712 PyErr_Format(PyExc_KeyError, "Key not present in map");
713 return -1;
714 }
715}
716
717PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
718 PyObject* key) {
719 MessageMapContainer* self = GetMessageMap(_self);
720
721 Message* message = self->GetMutableMessage();
722 const Reflection* reflection = message->GetReflection();
723 MapKey map_key;
724 MapValueRef value;
725
726 if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) {
727 return NULL;
728 }
729
730 if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor,
731 map_key, &value)) {
732 self->version++;
733 }
734
735 return GetCMessage(self, value.MutableMessageValue());
736}
737
738PyObject* MessageMapGet(PyObject* self, PyObject* args) {
739 PyObject* key;
740 PyObject* default_value = NULL;
741 if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) {
742 return NULL;
743 }
744
745 ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key));
746 if (is_present.get() == NULL) {
747 return NULL;
748 }
749
750 if (PyObject_IsTrue(is_present.get())) {
751 return MapReflectionFriend::MessageMapGetItem(self, key);
752 } else {
753 if (default_value != NULL) {
754 Py_INCREF(default_value);
755 return default_value;
756 } else {
757 Py_RETURN_NONE;
758 }
759 }
760}
761
762static void MessageMapDealloc(PyObject* _self) {
763 MessageMapContainer* self = GetMessageMap(_self);
764 self->owner.reset();
765 Py_DECREF(self->message_dict);
766 Py_TYPE(_self)->tp_free(_self);
767}
768
769static PyMethodDef MessageMapMethods[] = {
770 { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O,
771 "Tests whether the map contains this element."},
772 { "clear", (PyCFunction)Clear, METH_NOARGS,
773 "Removes all elements from the map."},
774 { "get", MessageMapGet, METH_VARARGS,
775 "Gets the value for the given key if present, or otherwise a default" },
776 { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O,
777 "Alias for getitem, useful to make explicit that the map is mutated." },
778 /*
779 { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS,
780 "Makes a deep copy of the class." },
781 { "__reduce__", (PyCFunction)Reduce, METH_NOARGS,
782 "Outputs picklable representation of the repeated field." },
783 */
784 {NULL, NULL},
785};
786
787#if PY_MAJOR_VERSION >= 3
788 static PyType_Slot MessageMapContainer_Type_slots[] = {
789 {Py_tp_dealloc, (void *)MessageMapDealloc},
790 {Py_mp_length, (void *)MapReflectionFriend::Length},
791 {Py_mp_subscript, (void *)MapReflectionFriend::MessageMapGetItem},
792 {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
793 {Py_tp_methods, (void *)MessageMapMethods},
794 {Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
795 {0, 0}
796 };
797
798 PyType_Spec MessageMapContainer_Type_spec = {
799 FULL_MODULE_NAME ".MessageMapContainer",
800 sizeof(MessageMapContainer),
801 0,
802 Py_TPFLAGS_DEFAULT,
803 MessageMapContainer_Type_slots
804 };
805
806 PyObject *MessageMapContainer_Type;
807#else
808 static PyMappingMethods MessageMapMappingMethods = {
809 MapReflectionFriend::Length, // mp_length
810 MapReflectionFriend::MessageMapGetItem, // mp_subscript
811 MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript
812 };
813
814 PyTypeObject MessageMapContainer_Type = {
815 PyVarObject_HEAD_INIT(&PyType_Type, 0)
816 FULL_MODULE_NAME ".MessageMapContainer", // tp_name
817 sizeof(MessageMapContainer), // tp_basicsize
818 0, // tp_itemsize
819 MessageMapDealloc, // tp_dealloc
820 0, // tp_print
821 0, // tp_getattr
822 0, // tp_setattr
823 0, // tp_compare
824 0, // tp_repr
825 0, // tp_as_number
826 0, // tp_as_sequence
827 &MessageMapMappingMethods, // tp_as_mapping
828 0, // tp_hash
829 0, // tp_call
830 0, // tp_str
831 0, // tp_getattro
832 0, // tp_setattro
833 0, // tp_as_buffer
834 Py_TPFLAGS_DEFAULT, // tp_flags
835 "A map container for message", // tp_doc
836 0, // tp_traverse
837 0, // tp_clear
838 0, // tp_richcompare
839 0, // tp_weaklistoffset
840 MapReflectionFriend::GetIterator, // tp_iter
841 0, // tp_iternext
842 MessageMapMethods, // tp_methods
843 0, // tp_members
844 0, // tp_getset
845 0, // tp_base
846 0, // tp_dict
847 0, // tp_descr_get
848 0, // tp_descr_set
849 0, // tp_dictoffset
850 0, // tp_init
851 };
852#endif
853
854// MapIterator /////////////////////////////////////////////////////////////////
855
856static MapIterator* GetIter(PyObject* obj) {
857 return reinterpret_cast<MapIterator*>(obj);
858}
859
860PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
861 MapContainer* self = GetMap(_self);
862
863 ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0));
864 if (obj == NULL) {
865 return PyErr_Format(PyExc_KeyError, "Could not allocate iterator");
866 }
867
868 MapIterator* iter = GetIter(obj.get());
869
870 Py_INCREF(self);
871 iter->container = self;
872 iter->version = self->version;
873 iter->owner = self->owner;
874
875 if (MapReflectionFriend::Length(_self) > 0) {
876 Message* message = self->GetMutableMessage();
877 const Reflection* reflection = message->GetReflection();
878
879 iter->iter.reset(new ::google::protobuf::MapIterator(
880 reflection->MapBegin(message, self->parent_field_descriptor)));
881 }
882
883 return obj.release();
884}
885
886PyObject* MapReflectionFriend::IterNext(PyObject* _self) {
887 MapIterator* self = GetIter(_self);
888
889 // This won't catch mutations to the map performed by MergeFrom(); no easy way
890 // to address that.
891 if (self->version != self->container->version) {
892 return PyErr_Format(PyExc_RuntimeError,
893 "Map modified during iteration.");
894 }
895
896 if (self->iter.get() == NULL) {
897 return NULL;
898 }
899
900 Message* message = self->container->GetMutableMessage();
901 const Reflection* reflection = message->GetReflection();
902
903 if (*self->iter ==
904 reflection->MapEnd(message, self->container->parent_field_descriptor)) {
905 return NULL;
906 }
907
908 PyObject* ret = MapKeyToPython(self->container->key_field_descriptor,
909 self->iter->GetKey());
910
911 ++(*self->iter);
912
913 return ret;
914}
915
916static void DeallocMapIterator(PyObject* _self) {
917 MapIterator* self = GetIter(_self);
918 self->iter.reset();
919 self->owner.reset();
920 Py_XDECREF(self->container);
921 Py_TYPE(_self)->tp_free(_self);
922}
923
924PyTypeObject MapIterator_Type = {
925 PyVarObject_HEAD_INIT(&PyType_Type, 0)
926 FULL_MODULE_NAME ".MapIterator", // tp_name
927 sizeof(MapIterator), // tp_basicsize
928 0, // tp_itemsize
929 DeallocMapIterator, // tp_dealloc
930 0, // tp_print
931 0, // tp_getattr
932 0, // tp_setattr
933 0, // tp_compare
934 0, // tp_repr
935 0, // tp_as_number
936 0, // tp_as_sequence
937 0, // tp_as_mapping
938 0, // tp_hash
939 0, // tp_call
940 0, // tp_str
941 0, // tp_getattro
942 0, // tp_setattro
943 0, // tp_as_buffer
944 Py_TPFLAGS_DEFAULT, // tp_flags
945 "A scalar map iterator", // tp_doc
946 0, // tp_traverse
947 0, // tp_clear
948 0, // tp_richcompare
949 0, // tp_weaklistoffset
950 PyObject_SelfIter, // tp_iter
951 MapReflectionFriend::IterNext, // tp_iternext
952 0, // tp_methods
953 0, // tp_members
954 0, // tp_getset
955 0, // tp_base
956 0, // tp_dict
957 0, // tp_descr_get
958 0, // tp_descr_set
959 0, // tp_dictoffset
960 0, // tp_init
961};
962
963} // namespace python
964} // namespace protobuf
965} // namespace google