diff --git a/evennia/typeclasses/tests.py b/evennia/typeclasses/tests.py index d97e7c3b5e..41ddaa41d5 100644 --- a/evennia/typeclasses/tests.py +++ b/evennia/typeclasses/tests.py @@ -2,9 +2,10 @@ Unit tests for typeclass base system """ + from django.test import override_settings -from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTestCase from evennia.typeclasses import attributes +from evennia.utils.test_resources import BaseEvenniaTest, EvenniaTestCase from mock import patch from parameterized import parameterized @@ -13,6 +14,10 @@ from parameterized import parameterized # ------------------------------------------------------------ +class DictSubclass(dict): + pass + + class TestAttributes(BaseEvenniaTest): def test_attrhandler(self): key = "testattr" @@ -22,6 +27,25 @@ class TestAttributes(BaseEvenniaTest): self.obj1.db.testattr = value self.assertEqual(self.obj1.db.testattr, value) + # "plain" subclasses + value = DictSubclass({"fo": "foo", "bar": "bar"}) + self.obj1.db.testattr = value + self.assertEqual(self.obj1.db.testattr, value) + + self.obj1.db.testattr["fo"] = "foo2" + value.update({"fo": "foo2"}) + self.assertEqual(self.obj1.db.testattr, value) + self.assertEqual(self.obj1.attributes.get("testattr"), value) + + # nested subclasses + value = DictSubclass({"nested": True, "deep": DictSubclass({"fo": "foo", "bar": "bar"})}) + self.obj1.db.testattr = value + + self.obj1.db.testattr["deep"]["fo"] = "nemo" + value["deep"].update({"fo": "nemo"}) + self.assertEqual(self.obj1.db.testattr, value) + self.assertEqual(self.obj1.attributes.get("testattr"), value) + @override_settings(TYPECLASS_AGGRESSIVE_CACHE=False) @patch("evennia.typeclasses.attributes._TYPECLASS_AGGRESSIVE_CACHE", False) def test_attrhandler_nocache(self): @@ -35,6 +59,27 @@ class TestAttributes(BaseEvenniaTest): self.assertEqual(self.obj1.db.testattr, value) self.assertFalse(self.obj1.attributes.backend._cache) + # "plain" subclasses + value = DictSubclass({"fo": "foo", "bar": "bar"}) + self.obj1.db.testattr = value + self.assertEqual(self.obj1.db.testattr, value) + + self.obj1.db.testattr["fo"] = "foo2" + value.update({"fo": "foo2"}) + self.assertEqual(self.obj1.db.testattr, value) + self.assertEqual(self.obj1.attributes.get("testattr"), value) + self.assertFalse(self.obj1.attributes.backend._cache) + + # nested subclasses + value = DictSubclass({"nested": True, "deep": DictSubclass({"fo": "foo", "bar": "bar"})}) + self.obj1.db.testattr = value + + self.obj1.db.testattr["deep"]["fo"] = "nemo" + value["deep"].update({"fo": "nemo"}) + self.assertEqual(self.obj1.db.testattr, value) + self.assertEqual(self.obj1.attributes.get("testattr"), value) + self.assertFalse(self.obj1.attributes.backend._cache) + def test_weird_text_save(self): "test 'weird' text type (different in py2 vs py3)" from django.utils.safestring import SafeText diff --git a/evennia/utils/dbserialize.py b/evennia/utils/dbserialize.py index 11321d8dfd..136d182a7b 100644 --- a/evennia/utils/dbserialize.py +++ b/evennia/utils/dbserialize.py @@ -263,7 +263,7 @@ class _SaverList(_SaverMutable, MutableSequence): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._data = list() + self._data = kwargs.pop("_class", list)() @_save def __iadd__(self, otherlist): @@ -307,7 +307,7 @@ class _SaverDict(_SaverMutable, MutableMapping): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._data = dict() + self._data = kwargs.pop("_class", dict)() def has_key(self, key): return key in self._data @@ -645,11 +645,20 @@ def to_pickle(data): pass if hasattr(item, "__iter__"): - # we try to conserve the iterable class, if not convert to list try: - return item.__class__([process_item(val) for val in item]) - except (AttributeError, TypeError): - return [process_item(val) for val in item] + # we try to conserve the iterable class, if not convert to dict + try: + return item.__class__( + (process_item(key), process_item(val)) for key, val in item.items() + ) + except (AttributeError, TypeError): + return {process_item(key): process_item(val) for key, val in item.items()} + except: + # we try to conserve the iterable class, if not convert to list + try: + return item.__class__([process_item(val) for val in item]) + except (AttributeError, TypeError): + return [process_item(val) for val in item] elif hasattr(item, "sessid") and hasattr(item, "conn_time"): return pack_session(item) try: @@ -714,11 +723,20 @@ def from_pickle(data, db_obj=None): return deque(process_item(val) for val in item) elif hasattr(item, "__iter__"): try: - # we try to conserve the iterable class if - # it accepts an iterator - return item.__class__(process_item(val) for val in item) - except (AttributeError, TypeError): - return [process_item(val) for val in item] + # we try to conserve the iterable class, if not convert to dict + try: + return item.__class__( + (process_item(key), process_item(val)) for key, val in item.items() + ) + except (AttributeError, TypeError): + return {process_item(key): process_item(val) for key, val in item.items()} + except: + try: + # we try to conserve the iterable class if + # it accepts an iterator + return item.__class__(process_item(val) for val in item) + except (AttributeError, TypeError): + return [process_item(val) for val in item] if hasattr(item, "__deserialize_dbobjs__"): # this allows the object to custom-deserialize any embedded dbobjs @@ -780,13 +798,30 @@ def from_pickle(data, db_obj=None): return dat elif hasattr(item, "__iter__"): try: - # we try to conserve the iterable class if it - # accepts an iterator - return item.__class__(process_tree(val, parent) for val in item) - except (AttributeError, TypeError): - dat = _SaverList(_parent=parent) - dat._data.extend(process_tree(val, dat) for val in item) - return dat + # we try to conserve the iterable class, if not convert to dict + try: + dat = _SaverDict(_parent=parent, _class=item.__class__) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in item.items() + ) + return dat + except (AttributeError, TypeError): + dat = _SaverDict(_parent=parent) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in item.items() + ) + return dat + except: + try: + # we try to conserve the iterable class if it + # accepts an iterator + dat = _SaverList(_parent=parent, _class=item.__class__) + dat._data.extend(process_tree(val, dat) for val in item) + return dat + except (AttributeError, TypeError): + dat = _SaverList(_parent=parent) + dat._data.extend(process_tree(val, dat) for val in item) + return dat if hasattr(item, "__deserialize_dbobjs__"): try: @@ -800,7 +835,9 @@ def from_pickle(data, db_obj=None): # convert lists, dicts and sets to their Saved* counterparts. It # is only relevant if the "root" is an iterable of the right type. dtype = type(data) - if dtype == list: + if dtype in (str, int, float, bool, bytes, SafeString, tuple): + return process_item(data) + elif dtype == list: dat = _SaverList(_db_obj=db_obj) dat._data.extend(process_tree(val, dat) for val in data) return dat @@ -830,6 +867,34 @@ def from_pickle(data, db_obj=None): dat = _SaverDeque(_db_obj=db_obj) dat._data.extend(process_item(val) for val in data) return dat + elif hasattr(data, "__iter__"): + try: + # we try to conserve the iterable class, if not convert to dict + try: + dat = _SaverDict(_db_obj=db_obj, _class=data.__class__) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in data.items() + ) + return dat + except (AttributeError, TypeError): + dat = _SaverDict(_db_obj=db_obj) + dat._data.update( + (process_item(key), process_tree(val, dat)) for key, val in data.items() + ) + return dat + except: + try: + # we try to conserve the iterable class if it + # accepts an iterator + dat = _SaverList(_db_obj=db_obj, _class=data.__class__) + dat._data.extend(process_tree(val, dat) for val in data) + return dat + + except (AttributeError, TypeError): + dat = _SaverList(_db_obj=db_obj) + dat._data.extend(process_tree(val, dat) for val in data) + return dat + return process_item(data)