diff --git a/evennia/utils/dbserialize.py b/evennia/utils/dbserialize.py index 0fbec4c9e3..11321d8dfd 100644 --- a/evennia/utils/dbserialize.py +++ b/evennia/utils/dbserialize.py @@ -18,19 +18,20 @@ in-situ, e.g `obj.db.mynestedlist[3][5] = 3` would never be saved and be out of sync with the database. """ +from collections import OrderedDict, defaultdict, deque +from collections.abc import MutableMapping, MutableSequence, MutableSet from functools import update_wrapper -from collections import deque, OrderedDict, defaultdict -from collections.abc import MutableSequence, MutableSet, MutableMapping try: - from pickle import dumps, loads, UnpicklingError + from pickle import UnpicklingError, dumps, loads except ImportError: from pickle import dumps, loads -from django.core.exceptions import ObjectDoesNotExist + from django.contrib.contenttypes.models import ContentType +from django.core.exceptions import ObjectDoesNotExist from django.utils.safestring import SafeString -from evennia.utils.utils import uses_database, is_iter, to_bytes from evennia.utils import logger +from evennia.utils.utils import is_iter, to_bytes, uses_database __all__ = ("to_pickle", "from_pickle", "do_pickle", "do_unpickle", "dbserialize", "dbunserialize") @@ -786,6 +787,13 @@ def from_pickle(data, db_obj=None): dat = _SaverList(_parent=parent) dat._data.extend(process_tree(val, dat) for val in item) return dat + + if hasattr(item, "__deserialize_dbobjs__"): + try: + item.__deserialize_dbobjs__() + except (TypeError, UnpicklingError): + pass + return item if db_obj: diff --git a/evennia/utils/tests/test_dbserialize.py b/evennia/utils/tests/test_dbserialize.py index d4fd101ef0..8d7950baef 100644 --- a/evennia/utils/tests/test_dbserialize.py +++ b/evennia/utils/tests/test_dbserialize.py @@ -2,10 +2,11 @@ Tests for dbserialize module """ -from collections import deque +from collections import defaultdict, deque + from django.test import TestCase -from evennia.utils import dbserialize from evennia.objects.objects import DefaultObject +from evennia.utils import dbserialize from parameterized import parameterized @@ -93,8 +94,6 @@ class TestDbSerialize(TestCase): self.assertEqual(self.obj.db.test, {"a": True, "b": False}) def test_defaultdict(self): - from collections import defaultdict - # baseline behavior for a defaultdict _dd = defaultdict(list) _dd["a"] @@ -164,12 +163,45 @@ class DbObjWrappers(TestCase): con = _ValidContainer(self.dbobj2) self.dbobj1.db.testarg = con - # accessing the same data twice + # accessing the same data multiple times res1 = self.dbobj1.db.testarg res2 = self.dbobj1.db.testarg + res3 = self.dbobj1.db.testarg self.assertEqual(res1, res2) + self.assertEqual(res1, res3) self.assertEqual(res1, con) self.assertEqual(res2, con) self.assertEqual(res1.hidden_obj, self.dbobj2) self.assertEqual(res2.hidden_obj, self.dbobj2) + self.assertEqual(res3.hidden_obj, self.dbobj2) + + def test_dbobj_hidden_dict(self): + con1 = _ValidContainer(self.dbobj2) + con2 = _ValidContainer(self.dbobj2) + + self.dbobj1.db.dict = {} + + self.dbobj1.db.dict["key1"] = con1 + self.dbobj1.db.dict["key2"] = con2 + + self.assertEqual(self.dbobj1.db.dict["key1"].hidden_obj, self.dbobj2) + self.assertEqual(self.dbobj1.db.dict["key1"].hidden_obj, self.dbobj2) + self.assertEqual(self.dbobj1.db.dict["key2"].hidden_obj, self.dbobj2) + self.assertEqual(self.dbobj1.db.dict["key2"].hidden_obj, self.dbobj2) + + def test_dbobj_hidden_defaultdict(self): + + con1 = _ValidContainer(self.dbobj2) + con2 = _ValidContainer(self.dbobj2) + + self.dbobj1.db.dfdict = defaultdict(dict) + + self.dbobj1.db.dfdict["key"]["con1"] = con1 + self.dbobj1.db.dfdict["key"]["con2"] = con2 + + self.assertEqual(self.dbobj1.db.dfdict["key"]["con1"].hidden_obj, self.dbobj2) + + self.assertEqual(self.dbobj1.db.dfdict["key"]["con1"].hidden_obj, self.dbobj2) + self.assertEqual(self.dbobj1.db.dfdict["key"]["con2"].hidden_obj, self.dbobj2) + self.assertEqual(self.dbobj1.db.dfdict["key"]["con2"].hidden_obj, self.dbobj2)